Merge pull request #90 from dmmiller612/bert-last-4
Added 2 new aggregation algorithms ('concat_last_4' and 'reduce_last_4')
dmmiller612 authored Dec 23, 2020
2 parents c5777c6 + feab2df commit 78311a9
Showing 2 changed files with 15 additions and 6 deletions.
setup.py: 4 changes (2 additions, 2 deletions)
@@ -2,13 +2,13 @@
 from setuptools import find_packages
 
 setup(name='bert-extractive-summarizer',
-      version='0.5.1',
+      version='0.5.2',
       description='Extractive Text Summarization with BERT',
       keywords = ['bert', 'pytorch', 'machine learning', 'deep learning', 'extractive summarization', 'summary'],
       long_description=open("README.md", "r", encoding='utf-8').read(),
       long_description_content_type="text/markdown",
       url='https://github.com/dmmiller612/bert-extractive-summarizer',
-      download_url='https://github.com/dmmiller612/bert-extractive-summarizer/archive/0.5.1.tar.gz',
+      download_url='https://github.com/dmmiller612/bert-extractive-summarizer/archive/0.5.2.tar.gz',
       author='Derek Miller',
       author_email='[email protected]',
       install_requires=['transformers', 'scikit-learn', 'spacy'],
summarizer/bert_parent.py: 17 changes (13 additions, 4 deletions)
@@ -75,7 +75,7 @@ def extract_embeddings(
         :param hidden: The hidden layer to use for a readout handler
         :param squeeze: If we should squeeze the outputs (required for some layers)
         :param reduce_option: How we should reduce the items.
-        :return: A numpy array.
+        :return: A torch vector.
         """
 
         tokens_tensor = self.tokenize_input(text)
@@ -84,13 +84,22 @@
         if -1 > hidden > -12:
 
             if reduce_option == 'max':
-                pooled = hidden_states[hidden].max(dim=1)[0]
+                pooled = hidden_states[hidden].max(dim=1)[0].squeeze()
 
             elif reduce_option == 'median':
-                pooled = hidden_states[hidden].median(dim=1)[0]
+                pooled = hidden_states[hidden].median(dim=1)[0].squeeze()
 
+            elif reduce_option == 'concat_last_4':
+                last_4 = [hidden_states[i] for i in (-1, -2, -3, -4)]
+                cat_hidden_states = torch.cat(tuple(last_4), dim=-1)
+                pooled = torch.mean(cat_hidden_states, dim=1).squeeze()
+
+            elif reduce_option == 'reduce_last_4':
+                last_4 = [hidden_states[i] for i in (-1, -2, -3, -4)]
+                pooled = torch.cat(tuple(last_4), dim=1).mean(axis=1).squeeze()
+
             else:
-                pooled = hidden_states[hidden].mean(dim=1)
+                pooled = hidden_states[hidden].mean(dim=1).squeeze()
 
         return pooled

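For context, here is a minimal standalone sketch (not part of the commit) of what the two new reduce_option values compute. The tensor shapes are assumptions for illustration (BERT-base style: 13 layers of shape batch x tokens x 768). 'concat_last_4' concatenates the last four hidden layers along the hidden dimension and then mean-pools over tokens, giving a 4x-wider vector, while 'reduce_last_4' concatenates them along the token dimension before mean-pooling, keeping the original hidden size.

import torch

# Hypothetical hidden states: 13 layers, each of shape (batch=1, tokens=8, hidden=768).
hidden_states = [torch.randn(1, 8, 768) for _ in range(13)]
last_4 = [hidden_states[i] for i in (-1, -2, -3, -4)]

# 'concat_last_4': concatenate along the hidden dimension, then mean over tokens.
pooled_concat = torch.mean(torch.cat(tuple(last_4), dim=-1), dim=1).squeeze()  # shape (3072,)

# 'reduce_last_4': concatenate along the token dimension, then mean over tokens.
pooled_reduce = torch.cat(tuple(last_4), dim=1).mean(dim=1).squeeze()  # shape (768,)

print(pooled_concat.shape, pooled_reduce.shape)  # torch.Size([3072]) torch.Size([768])

Per the diff, either behaviour is selected by passing reduce_option='concat_last_4' or reduce_option='reduce_last_4' to extract_embeddings, alongside the existing 'max', 'median', and default mean options.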
