From 8e0d2fd7925c0de9703d666ea2cc004327f6e544 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 7 Apr 2022 19:29:39 -0700
Subject: [PATCH] fix mask for linear attention

---
 flash_pytorch/flash_pytorch.py | 3 ++-
 setup.py                       | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/flash_pytorch/flash_pytorch.py b/flash_pytorch/flash_pytorch.py
index f874e0c..954d9dd 100644
--- a/flash_pytorch/flash_pytorch.py
+++ b/flash_pytorch/flash_pytorch.py
@@ -284,7 +284,8 @@ def forward(
         # mask out linear attention keys

         if exists(mask):
-            lin_k = lin_k.masked_fill(~mask, 0.)
+            lin_mask = rearrange(mask, '... -> ... 1')
+            lin_k = lin_k.masked_fill(~lin_mask, 0.)

         # rotate queries and keys

diff --git a/setup.py b/setup.py
index 1c57a47..210182f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'FLASH-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.1',
+  version = '0.1.2',
   license='MIT',
   description = 'FLASH - Transformer Quality in Linear Time - Pytorch',
   author = 'Phil Wang',
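
Note (not part of the patch): the change gives the boolean mask a trailing singleton dimension so it can broadcast against the linear attention keys in `masked_fill`. The sketch below illustrates the shape issue under assumed shapes; the tensor names and dimensions are illustrative, not taken from flash_pytorch itself.

```python
# Minimal sketch of why the mask needs a trailing singleton dimension.
# Shapes and names here are assumptions for illustration only.
import torch
from einops import rearrange

batch, seq_len, dim = 2, 4, 8
lin_k = torch.randn(batch, seq_len, dim)          # linear attention keys
mask = torch.tensor([[True, True, False, False],
                     [True, True, True,  False]]) # (batch, seq_len) padding mask

# Before the fix: mask is (batch, seq_len) while lin_k is (batch, seq_len, dim),
# so masked_fill cannot broadcast the mask across the feature dimension.
# After the fix: append a singleton axis so the mask broadcasts over `dim`.
lin_mask = rearrange(mask, '... -> ... 1')        # (batch, seq_len, 1)
lin_k = lin_k.masked_fill(~lin_mask, 0.)          # zero keys at padded positions

assert torch.all(lin_k[0, 2:] == 0) and torch.all(lin_k[1, 3:] == 0)
```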