diff --git a/flash_pytorch/flash_pytorch.py b/flash_pytorch/flash_pytorch.py
index f874e0c..954d9dd 100644
--- a/flash_pytorch/flash_pytorch.py
+++ b/flash_pytorch/flash_pytorch.py
@@ -284,7 +284,8 @@ def forward(
         # mask out linear attention keys

         if exists(mask):
-            lin_k = lin_k.masked_fill(~mask, 0.)
+            lin_mask = rearrange(mask, '... -> ... 1')
+            lin_k = lin_k.masked_fill(~lin_mask, 0.)

         # rotate queries and keys

diff --git a/setup.py b/setup.py
index 1c57a47..210182f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'FLASH-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.1',
+  version = '0.1.2',
   license='MIT',
   description = 'FLASH - Transformer Quality in Linear Time - Pytorch',
   author = 'Phil Wang',
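Note (not part of the diff): the first hunk widens the padding mask before `masked_fill`. A minimal sketch of why, assuming `lin_k` has shape (batch, seq, dim) while `mask` has shape (batch, seq) — the concrete shapes and tensor values below are illustrative, not taken from the repo:

```python
import torch
from einops import rearrange

# Illustrative shapes (assumption): lin_k is (batch, seq, dim), mask is (batch, seq)
batch, seq, dim = 2, 4, 8
lin_k = torch.randn(batch, seq, dim)
mask = torch.tensor([[True, True, True, False],
                     [True, True, False, False]])

# Before the fix: a (batch, seq) mask is not broadcastable against a
# (batch, seq, dim) tensor, since broadcasting aligns trailing dimensions
# (seq vs dim), so masked_fill raises a shape error.
# lin_k.masked_fill(~mask, 0.)  # RuntimeError

# After the fix: append a unit axis so the mask broadcasts over dim,
# zeroing the linear attention keys at padded positions.
lin_mask = rearrange(mask, '... -> ... 1')  # (batch, seq, 1)
lin_k = lin_k.masked_fill(~lin_mask, 0.)
```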