Attempt to unify perf kernel and AOTriton's kernel. #716
base: main_perf
Conversation
Note: the performance drops to 300 TFLOPS after these seemingly cosmetic changes.
Confirmed that the hipGraph support causes a huge performance degradation, even when the code path is guarded with ENABLE_DROPOUT.
Changed the UT in a4e6837. This is required by LLAMA 3.2, which has different values in different batches of the bias tensor. Changes in related projects:
Force-pushed from 539ae8b to 7c33dd0.
@jayfurmanek @vgokhale All UTs passed.
…AMIC According to ROCm/triton#716 this can restore the perf to ~400T
Can you post performance before and after these changes for both the model and default cmd line args?
@vgokhale here it is
## Major Changes
* [kernel] Backport the 2025/01/28 `main_perf` kernel
  + [kernel] Support empty L tensor pointer.
  + See ROCm/triton#716 for more details
  + [shim] Adjust the build rules accordingly
* [shim] Remove non-power-of-two (NPOT) head dim 72, which triggers compiler bugs on bf16
* [db] Remove the `attn_fwd` table from the tuning database, since the old entries are not valid anymore.
* [db] Set all entries to `num_stages=1`, since `num_stages=2` constantly triggers compiler bugs
* [test] Add new head dimensions, now categorized into three groups:
  + Power-of-two (POT) head dimensions
  + Optimized NPOT head dimensions
  + Prime-number head dimensions, to cover all gaps between neighboring POT/NPOT head dims.
* [shim] Add env var `AOTRITON_SKIP_LUT_CHECK` to skip the LUT sanity check on certain kernels
  + As of this PR, AOTriton must be built with `AOTRITON_SKIP_LUT_CHECK=flash.attn_fwd ninja install`

## Minor Changes
* [build] Bump the version number to 0.9.0. (Should be done at the beginning of 0.9 dev)
* [API] Move the bias tensor to the position immediately after the v tensor, matching the kernel argument order
* [shim] Add `TensorView<0>::get_null_tensor`
* [test] Change `AttentionExtraArgs` from namedtuple to dataclass for easier-to-read default values (a minimal sketch follows below).
* [mptune] Change the output JSON format to match the kernel argument changes.
* [test] Use the CPU reference when seqlen_k == 579 (used by `test_gqa` tests); the GPU reference triggers a segfault.
* [test] Change the default value_fudge_factor to 36.0 (should be 40.0 if considering GQA tests)
* [shim] Fix the code path when the tuning database is not available

## Known Problems
* The tuning database for the `flash.attn_fwd` kernel is cleared, and there is no plan to re-build it ATM due to immediate additional changes to the forward kernel.
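On the `AttentionExtraArgs` item, a minimal sketch of the namedtuple-to-dataclass move; the field names here are purely illustrative, not the actual fields:

```python
from dataclasses import dataclass

# Before (illustrative): defaults live in a separate tuple, far from the field names.
# AttentionExtraArgs = namedtuple("AttentionExtraArgs",
#                                 ["return_encoded_softmax", "autotune"],
#                                 defaults=(False, False))

@dataclass
class AttentionExtraArgs:
    # After: each default sits next to its field, which is the readability
    # gain cited for switching to a dataclass.
    return_encoded_softmax: bool = False
    autotune: bool = False
```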
It is difficult to get yapf to format those long functions in a readable way, so I disabled formatting for those functions.
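For reference, this is the standard yapf escape hatch; the function below is only a placeholder for the real kernels whose formatting was turned off:

```python
# yapf: disable
def attn_fwd_placeholder(q, k, v, bias, sm_scale,
                         stride_qz, stride_qh, stride_qm, stride_qk,
                         stride_kz, stride_kh, stride_kn, stride_kk):
    # Body elided; the point is that yapf leaves everything between the
    # disable/enable markers exactly as written.
    pass
# yapf: enable
```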
This stops us from using Triton's type annotations like '*fp16'.
Check-File-Changes complains
For some reason this causes performance problems.
In the inner loop, use PERSISTENT and PERSISTENT_DYNAMIC derived from PERSISTENT_TYPE. This seems to work better.
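A minimal sketch of what deriving the two flags from PERSISTENT_TYPE could look like, assuming the 0/1/2 encoding (No/Fixed/Dynamic) described later in this PR; the kernel here is a toy, not the real attn_fwd:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def persistent_flags_demo(out_ptr, PERSISTENT_TYPE: tl.constexpr):
    # PERSISTENT_TYPE: 0 = off, 1 = fixed persistent, 2 = dynamic persistent.
    # Since PERSISTENT_TYPE is a constexpr, both comparisons fold at compile
    # time, so the inner loop only ever branches on plain booleans.
    PERSISTENT = PERSISTENT_TYPE > 0
    PERSISTENT_DYNAMIC = PERSISTENT_TYPE == 2
    flag = 0
    if PERSISTENT:
        flag += 1
    if PERSISTENT_DYNAMIC:
        flag += 2
    tl.store(out_ptr, flag)

out = torch.zeros(1, dtype=torch.int32, device="cuda")
persistent_flags_demo[(1,)](out, PERSISTENT_TYPE=2)
print(out.item())  # 3: both flags derived from PERSISTENT_TYPE == 2
```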
See SWDEV-508774 for details.
Test passed with: `pytest flash-attention.py -k 'test_op_fwd_bias[dtype0-False-True-True-2-4-16384-16384-128]'`. If there are still concerns, we can add the type annotation tl.uint64 to the strides.
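If the annotation route is taken, a sketch of how it could look on a kernel signature (toy kernel with illustrative names; this assumes a Triton version that honors scalar type annotations such as tl.uint64 on arguments):

```python
import triton
import triton.language as tl

@triton.jit
def copy_rows(in_ptr, out_ptr, n_cols,
              # Annotating the strides keeps them 64-bit instead of letting the
              # JIT treat them as 32-bit, so row offsets do not overflow on
              # very large tensors such as the 16384x16384x128 case above.
              stride_in: tl.uint64, stride_out: tl.uint64,
              BLOCK: tl.constexpr):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < n_cols
    vals = tl.load(in_ptr + row * stride_in + offs, mask=mask)
    tl.store(out_ptr + row * stride_out + offs, vals, mask=mask)
```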
Force-pushed from fd4442b to 4d2dd42.
Major changes
* Introduce `Constant_in_jit_but_variable_in_aot`, with the `constexpr_or_i32`/`constexpr_or_f32`/`constexpr_or_bool` annotation for such types
* Add `mstore2d` to handle OOB writes of the encoded softmax
* Use `Max_seqlen_k` in calculating the Philox offsets, avoiding overlapped offsets for Varlen.
* Replace `PERSISTENT` and `PERSISTENT_DYNAMIC` with `PERSISTENT_TYPE = 0/1/2` for the No, Fixed and Dynamic options
* Add `Num_seqlens` to support padded varlen (still a Rank-4 tensor, but only part of the sequences are used)
  + `Num_seqlens == 0`: Classical SDPA
  + `Num_seqlens > 0`: Conventional Varlen
  + `Num_seqlens < 0`: Padded Varlen
* Replace `IS_CAUSAL` with `CAUSAL_TYPE` (see the sketch after this list):
  + `CAUSAL_TYPE == 0`: No causal
  + `CAUSAL_TYPE == 1`: Top-left alignment (the PyTorch ME backend's default)
  + `CAUSAL_TYPE == 2`: Bottom-right alignment (main_perf's previous setting)
* Apply `tl.uint64` to all strides except for alibi, in order to support large tensors.

Unit Test Changes
* The bias tensor now has different values in different batches instead of being broadcast with `torch.expand`.
* The `tl.uint64` type annotation on strides is used to overcome the 32-bit limit on large tensors.
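To make the CAUSAL_TYPE item above concrete, a small host-side sketch of the three masking conventions; this is an illustrative reference in plain PyTorch, not the kernel code:

```python
import torch

def causal_mask(seqlen_q, seqlen_k, causal_type):
    """Reference mask for CAUSAL_TYPE = 0/1/2 as described above.

    0: no causal masking
    1: top-left aligned (the PyTorch ME backend's default)
    2: bottom-right aligned (main_perf's previous setting)
    """
    if causal_type == 0:
        return torch.ones(seqlen_q, seqlen_k, dtype=torch.bool)
    m = torch.arange(seqlen_q)[:, None]   # query positions
    n = torch.arange(seqlen_k)[None, :]   # key positions
    # Bottom-right alignment shifts the diagonal so the last query row can
    # attend to the last key column even when seqlen_q != seqlen_k.
    shift = 0 if causal_type == 1 else seqlen_k - seqlen_q
    return n <= m + shift
```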