
Commit

fix an error thanks to @ShomyLiu at #2
lucidrains committed Jun 18, 2022
1 parent 8e0d2fd commit 0cb9473
Showing 2 changed files with 4 additions and 3 deletions.
4 changes: 2 additions & 2 deletions flash_pytorch/flash_pytorch.py
@@ -339,8 +339,8 @@ def forward(
 
             lin_out = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
         else:
-            lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
-            lin_out = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
+            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / n
+            lin_out = einsum('b g n d, b g d e -> b g n e', lin_q, lin_kv)
 
         # fold back groups into full sequence, and excise out padding
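For context, the corrected lines are the non-causal branch of the grouped linear attention: the fix keeps the group axis g in lin_kv ('-> b g d e' rather than '-> b d e'), so the key/value summary and the output computed from it stay per group, matching the shapes the causal branch above already uses. Below is a minimal standalone sketch of the corrected einsums; the tensor sizes and setup code are illustrative assumptions, not taken from the repository.

import torch
from torch import einsum

# hypothetical sizes: batch, groups, tokens per group, query/key dim, value dim
b, g, n, d, e = 2, 4, 8, 16, 16

lin_q = torch.randn(b, g, n, d)  # grouped linear-attention queries
lin_k = torch.randn(b, g, n, d)  # grouped linear-attention keys
v     = torch.randn(b, g, n, e)  # grouped values

# per-group key/value summary; the pre-fix version reduced this to 'b d e',
# pooling over the group axis as well
lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / n

# apply each group's summary to that group's queries
lin_out = einsum('b g n d, b g d e -> b g n e', lin_q, lin_kv)

assert lin_out.shape == (b, g, n, e)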
3 changes: 2 additions & 1 deletion setup.py
@@ -3,11 +3,12 @@
 setup(
   name = 'FLASH-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.2',
+  version = '0.1.4',
   license='MIT',
   description = 'FLASH - Transformer Quality in Linear Time - Pytorch',
   author = 'Phil Wang',
   author_email = '[email protected]',
+  long_description_content_type = 'text/markdown',
   url = 'https://github.com/lucidrains/FLASH-pytorch',
   keywords = [
     'artificial intelligence',
