Skip to content

Commit

Permalink
Merge pull request #91 from dmmiller612/new_summary_items
Browse files Browse the repository at this point in the history
Added new summary items
  • Loading branch information
dmmiller612 authored Dec 28, 2020
2 parents 78311a9 + 825b0e1 commit e70a508
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 53 deletions.
77 changes: 51 additions & 26 deletions summarizer/bert_parent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Union

import numpy as np
import torch
Expand Down Expand Up @@ -61,72 +61,97 @@ def tokenize_input(self, text: str) -> torch.tensor:
indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
return torch.tensor([indexed_tokens]).to(self.device)

def _pooled_handler(self, hidden: torch.Tensor, reduce_option: str) -> torch.Tensor:
"""
Handles torch tensor.
:param hidden: The hidden torch tensor to process.
:param reduce_option: The reduce option to use, such as mean, etc.
:return: Returns a torch tensor.
"""

if reduce_option == 'max':
return hidden.max(dim=1)[0].squeeze()

elif reduce_option == 'median':
return hidden.median(dim=1)[0].squeeze()

return hidden.mean(dim=1).squeeze()

def extract_embeddings(
self,
text: str,
hidden: int=-2,
reduce_option: str ='mean'
hidden: Union[List[int], int] = -2,
reduce_option: str ='mean',
hidden_concat: bool = False
) -> torch.Tensor:

"""
Extracts the embeddings for the given text
:param text: The text to extract embeddings for.
:param hidden: The hidden layer to use for a readout handler
:param hidden: The hidden layer(s) to use for a readout handler
:param squeeze: If we should squeeze the outputs (required for some layers)
:param reduce_option: How we should reduce the items.
:param hidden_concat: Whether or not to concat multiple hidden layers.
:return: A torch vector.
"""

tokens_tensor = self.tokenize_input(text)
pooled, hidden_states = self.model(tokens_tensor)[-2:]

if -1 > hidden > -12:

if reduce_option == 'max':
pooled = hidden_states[hidden].max(dim=1)[0].squeeze()
# deprecated temporary keyword functions.
if reduce_option == 'concat_last_4':
last_4 = [hidden_states[i] for i in (-1, -2, -3, -4)]
cat_hidden_states = torch.cat(tuple(last_4), dim=-1)
return torch.mean(cat_hidden_states, dim=1).squeeze()

elif reduce_option == 'median':
pooled = hidden_states[hidden].median(dim=1)[0].squeeze()
elif reduce_option == 'reduce_last_4':
last_4 = [hidden_states[i] for i in (-1, -2, -3, -4)]
return torch.cat(tuple(last_4), dim=1).mean(axis=1).squeeze()

elif reduce_option == 'concat_last_4':
last_4 = [hidden_states[i] for i in (-1, -2, -3, -4)]
cat_hidden_states = torch.cat(tuple(last_4), dim=-1)
pooled = torch.mean(cat_hidden_states, dim=1).squeeze()
elif type(hidden) == int:
hidden_s = hidden_states[hidden]
return self._pooled_handler(hidden_s, reduce_option)

elif reduce_option == 'reduce_last_4':
last_4 = [hidden_states[i] for i in (-1, -2, -3, -4)]
pooled = torch.cat(tuple(last_4), dim=1).mean(axis=1).squeeze()
elif hidden_concat:
last_states = [hidden_states[i] for i in hidden]
cat_hidden_states = torch.cat(tuple(last_states), dim=-1)
return torch.mean(cat_hidden_states, dim=1).squeeze()

else:
pooled = hidden_states[hidden].mean(dim=1).squeeze()
last_states = [hidden_states[i] for i in hidden]
hidden_s = torch.cat(tuple(last_states), dim=1)

return pooled
return self._pooled_handler(hidden_s, reduce_option)

def create_matrix(
self,
content: List[str],
hidden: int=-2,
reduce_option: str = 'mean'
hidden: Union[List[int], int] = -2,
reduce_option: str = 'mean',
hidden_concat: bool = False
) -> ndarray:
"""
Create matrix from the embeddings
:param content: The list of sentences
:param hidden: Which hidden layer to use
:param reduce_option: The reduce option to run.
:param hidden_concat: Whether or not to concat multiple hidden layers.
:return: A numpy array matrix of the given content.
"""

return np.asarray([
np.squeeze(self.extract_embeddings(t, hidden=hidden, reduce_option=reduce_option).data.cpu().numpy())
for t in content
np.squeeze(self.extract_embeddings(
t, hidden=hidden, reduce_option=reduce_option, hidden_concat=hidden_concat
).data.cpu().numpy()) for t in content
])

def __call__(
self,
content: List[str],
hidden: int= -2,
reduce_option: str = 'mean'
reduce_option: str = 'mean',
hidden_concat: bool = False
) -> ndarray:
return self.create_matrix(content, hidden, reduce_option)
return self.create_matrix(content, hidden, reduce_option, hidden_concat)
8 changes: 7 additions & 1 deletion summarizer/cluster_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def __get_model(self, k: int):
:param k: amount of clusters
:return: Clustering model
"""

if self.algorithm == 'gmm':
Expand Down Expand Up @@ -112,4 +111,11 @@ def cluster(self, ratio: float = 0.1, num_sentences: int = None) -> List[int]:
return sorted_values

def __call__(self, ratio: float = 0.1, num_sentences: int = None) -> List[int]:
"""
Clusters sentences based on the ratio
:param ratio: Ratio to use for clustering
:param num_sentences: Number of sentences. Overrides ratio.
:return: Sentences index that qualify for summary
"""

return self.cluster(ratio)
50 changes: 25 additions & 25 deletions summarizer/model_processors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import numpy as np
from transformers import *
Expand All @@ -22,21 +22,24 @@ def __init__(
model: str = 'bert-large-uncased',
custom_model: PreTrainedModel = None,
custom_tokenizer: PreTrainedTokenizer = None,
hidden: int = -2,
hidden: Union[List[int], int] = -2,
reduce_option: str = 'mean',
sentence_handler: SentenceHandler = SentenceHandler(),
random_state: int = 12345
random_state: int = 12345,
hidden_concat: bool = False
):
"""
This is the parent Bert Summarizer model. New methods should implement this class
:param model: This parameter is associated with the inherit string parameters from the transformers library.
:param custom_model: If you have a pre-trained model, you can add the model class here.
:param custom_tokenizer: If you have a custom tokenizer, you can add the tokenizer here.
:param hidden: This signifies which layer of the BERT model you would like to use as embeddings.
:param hidden: This signifies which layer(s) of the BERT model you would like to use as embeddings.
:param reduce_option: Given the output of the bert model, this param determines how you want to reduce results.
:param sentence_handler: The handler to process sentences. If want to use coreference, instantiate and pass CoreferenceHandler instance
:param sentence_handler: The handler to process sentences. If want to use coreference, instantiate and pass
CoreferenceHandler instance
:param random_state: The random state to reproduce summarizations.
:param hidden_concat: Whether or not to concat multiple hidden layers.
"""

np.random.seed(random_state)
Expand All @@ -45,19 +48,7 @@ def __init__(
self.reduce_option = reduce_option
self.sentence_handler = sentence_handler
self.random_state = random_state

def process_content_sentences(self, body: str, min_length: int = 40, max_length: int = 600) -> List[str]:
"""
Processes the content sentences with neural coreference.
:param body: The raw string body to process
:param min_length: Minimum length that the sentences must be
:param max_length: Max length that the sentences mus fall under
:return: Returns a list of sentences with coreference applied.
"""

doc = self.nlp(body)._.coref_resolved
doc = self.nlp(doc)
return [c.string.strip() for c in doc.sents if max_length > len(c.string.strip()) > min_length]
self.hidden_concat = hidden_concat

def cluster_runner(
self,
Expand All @@ -81,7 +72,7 @@ def cluster_runner(
if num_sentences is not None:
num_sentences = num_sentences if use_first else num_sentences

hidden = self.model(content, self.hidden, self.reduce_option)
hidden = self.model(content, self.hidden, self.reduce_option, hidden_concat=self.hidden_concat)
hidden_args = ClusterFeatures(hidden, algorithm, random_state=self.random_state).cluster(ratio, num_sentences)

if use_first:
Expand Down Expand Up @@ -241,10 +232,11 @@ def __init__(
model: str = 'bert-large-uncased',
custom_model: PreTrainedModel = None,
custom_tokenizer: PreTrainedTokenizer = None,
hidden: int = -2,
hidden: Union[List[int], int] = -2,
reduce_option: str = 'mean',
sentence_handler: SentenceHandler = SentenceHandler(),
random_state: int = 12345
random_state: int = 12345,
hidden_concat: bool = False
):
"""
This is the main Bert Summarizer class.
Expand All @@ -257,15 +249,20 @@ def __init__(
:param greedyness: associated with the neuralcoref library. Determines how greedy coref should be.
:param language: Which language to use for training.
:param random_state: The random state to reproduce summarizations.
:param hidden_concat: Whether or not to concat multiple hidden layers.
"""

super(Summarizer, self).__init__(
model, custom_model, custom_tokenizer, hidden, reduce_option, sentence_handler, random_state
model, custom_model, custom_tokenizer, hidden, reduce_option, sentence_handler, random_state, hidden_concat
)


class TransformerSummarizer(ModelProcessor):

"""
Newer style that has keywords for models and tokenizers, but allows the user to change the type.
"""

MODEL_DICT = {
'Bert': (BertModel, BertTokenizer),
'OpenAIGPT': (OpenAIGPTModel, OpenAIGPTTokenizer),
Expand All @@ -282,16 +279,19 @@ def __init__(
transformer_type: str = 'Bert',
transformer_model_key: str = 'bert-base-uncased',
transformer_tokenizer_key: str = None,
hidden: int = -2,
hidden: Union[List[int], int] = -2,
reduce_option: str = 'mean',
sentence_handler: SentenceHandler = SentenceHandler(),
random_state: int = 12345
random_state: int = 12345,
hidden_concat: bool = False
):

try:
self.MODEL_DICT['Roberta'] = (RobertaModel, RobertaTokenizer)
self.MODEL_DICT['Albert'] = (AlbertModel, AlbertTokenizer)
self.MODEL_DICT['Camembert'] = (CamembertModel, CamembertTokenizer)
self.MODEL_DICT['Bart'] = (BartModel, BartTokenizer)
self.MODEL_DICT['Longformer'] = (LongformerModel, LongformerTokenizer)
except Exception as e:
pass # older transformer version

Expand All @@ -303,5 +303,5 @@ def __init__(
)

super().__init__(
None, model, tokenizer, hidden, reduce_option, sentence_handler, random_state
None, model, tokenizer, hidden, reduce_option, sentence_handler, random_state, hidden_concat
)
18 changes: 17 additions & 1 deletion tests/test_summary_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ def summarizer():
return Summarizer('distilbert-base-uncased')


@pytest.fixture()
def summarizer_multi_hidden():
return Summarizer('distilbert-base-uncased', hidden=[-1,-2,-3])


@pytest.fixture()
def passage():
return '''
Expand Down Expand Up @@ -45,8 +50,19 @@ def passage():
'''


def test_multi_hidden(summarizer_multi_hidden, passage):
res = summarizer_multi_hidden(passage, num_sentences=5, min_length=40, max_length=500)
assert len(res) > 10


def test_multi_hidden_concat(summarizer_multi_hidden: Summarizer, passage):
summarizer_multi_hidden.hidden_concat = True
res = summarizer_multi_hidden(passage, num_sentences=5, min_length=40, max_length=500)
assert len(res) > 10


def test_summary_creation(summarizer, passage):
res = summarizer(passage, ratio=0.15, min_length=25, max_length=500)
res = summarizer(passage, ratio=0.15, min_length=40, max_length=500)
assert len(res) > 10


Expand Down

0 comments on commit e70a508

Please sign in to comment.