Attempt to unify perf kernel and AOTriton's kernel. #716

Open: wants to merge 17 commits into base main_perf

Conversation


xinyazhang (Author) commented Jan 29, 2025

Major changes

  • Introduce a new naming convention: Constant_in_jit_but_variable_in_aot
    • Please use the constexpr_or_i32/constexpr_or_f32/constexpr_or_bool annotations for such types
  • Add mstore2d to handle OOB writes of encoded softmax
  • Use Max_seqlen_k when calculating the Philox offsets, avoiding overlapping offsets for Varlen.
  • Replace PERSISTENT and PERSISTENT_DYNAMIC with PERSISTENT_TYPE = 0/1/2 for No, Fixed and Dynamic options
  • Add Num_seqlens to support padded varlen (Still Rank-4 Tensor, but only part of sequences are used)
    • Num_seqlens == 0: Classical SDPA
    • Num_seqlens > 0: Conventional Varlen
    • Num_seqlens < 0: Padded Varlen
  • Replace IS_CAUSAL with CAUSAL_TYPE (see the sketch after this list):
    • CAUSAL_TYPE == 0: No causal
    • CAUSAL_TYPE == 1: Top-left alignment (PyTorch ME backend's default)
    • CAUSAL_TYPE == 2: Bottom-right alignment (main_perf's previous settings)
  • Fix NaN when sm_scale=0 on certain bias patterns (See: [Documentation]: Forward kernel returns NaN when inputs are irregular (including causal) and sm_scale is 0, aotriton#47)
  • Rearrange the order of arguments. Please add only one argument per line, with two exceptions:
    • Strides of the same tensor should be put on a single line.
    • Basic SDPA arguments and all tensors can be compactly packed together.
  • Add hipGraph support to PRNG
  • Add type annotation tl.uint64 to all strides except for alibi, in order to support large tensors.
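
For readers new to the flag encodings above, here is a small plain-Python sketch of the documented semantics of CAUSAL_TYPE and Num_seqlens. The constant names come from this PR; the helper function is hypothetical and only restates the documented meaning, it is not part of the kernel.

```python
# Illustrative only: maps the documented flag values to their meaning.
def describe_kernel_mode(CAUSAL_TYPE: int, Num_seqlens: int) -> str:
    causal = {
        0: "no causal masking",
        1: "top-left aligned causal mask (PyTorch ME backend default)",
        2: "bottom-right aligned causal mask (previous main_perf behavior)",
    }[CAUSAL_TYPE]

    if Num_seqlens == 0:
        varlen = "classical SDPA"
    elif Num_seqlens > 0:
        varlen = "conventional Varlen"
    else:  # Num_seqlens < 0
        varlen = "padded Varlen (rank-4 tensor, only part of the sequences used)"

    return f"{causal}; {varlen}"

print(describe_kernel_mode(1, 0))  # top-left aligned causal mask ...; classical SDPA
```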

Unit Test Changes

  • Change dimension 0 of the bias tensor from 1 to the actual batch size through torch.expand (see the sketch after this list).
  • Add a test for the bias tensor with real batches, i.e., allocate actual memory and fill different values in different batches
    • This is required by LLAMA 3.2
  • Remove the assertion on the bias tensor size
    • The tl.uint64 type annotation on strides is used to overcome this limit
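
As a rough illustration of the bias-tensor test changes (the shapes below are made up for the example; only the use of torch.expand and the real per-batch allocation follow the description above):

```python
import torch

batch, heads, seqlen_q, seqlen_k = 2, 16, 128, 128

# Bias allocated with dimension 0 == 1, then expanded to the batch size via
# torch.expand; every batch views the same memory (stride 0 along dim 0).
bias_expanded = torch.randn(1, heads, seqlen_q, seqlen_k).expand(batch, -1, -1, -1)
assert bias_expanded.stride(0) == 0

# Bias with real batches: actual memory allocated per batch, with different
# values in different batches (the pattern LLAMA 3.2 requires).
bias_real = torch.randn(batch, heads, seqlen_q, seqlen_k)
assert not torch.equal(bias_real[0], bias_real[1])
```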

xinyazhang (Author) commented:

Note: the performance drops to 300 TFLOPS after these seemingly cosmetic changes.
Will dig into the reason.


xinyazhang commented Jan 29, 2025

Confirmed: the hipGraph support causes a huge performance degradation, even when the code path is guarded with ENABLE_DROPOUT.

xinyazhang marked this pull request as ready for review January 30, 2025 22:54

xinyazhang commented Jan 31, 2025

Changed the UT in a4e6837

This is required by LLAMA 3.2, which has different values in different batches of the bias tensor.
Ignoring the batch index causes SWDEV-508774.

Changes in related projects:

xinyazhang force-pushed the xinyazhang/union-fa-naming_convention branch 2 times, most recently from 539ae8b to 7c33dd0 on January 31, 2025 16:57
xinyazhang (Author) commented:

@jayfurmanek @vgokhale All UTs passed

xinyazhang added a commit to ROCm/aotriton that referenced this pull request Feb 3, 2025
vgokhale (Collaborator) commented Feb 4, 2025

Can you post performance before and after these changes for both the model and default cmd line args?


xinyazhang commented Feb 4, 2025

og-flash-attention.py is the original FA file I checked out to make the comparison easier.

(py_3.10) xinyazha@1949bd361800:~/triton-rocm/python/perf-kernels$ sha256sum flash-attention.py og-flash-attention.py 
4754219a53db54c2eea70fc5cb94bf5fce393c286aa730e862734e2ad33ee240  flash-attention.py
244718ba8927a15aa11ea14f5b34af0c1fb0a923a8965a87c0650092bed37397  og-flash-attention.py
(py_3.10) xinyazha@1949bd361800:~/triton-rocm/python/perf-kernels$ python flash-attention.py 
fused-attention-fwd-d128-layoutbhsd:
    BATCH    HQ    HK  N_CTX_Q  N_CTX_K      triton       torch
0    16.0  16.0  16.0   1024.0   1024.0  290.451385  209.159602
1     8.0  16.0  16.0   2048.0   2048.0  348.721926  262.840479
2     4.0  16.0  16.0   4096.0   4096.0  370.112906  285.095079
3     2.0  16.0  16.0   8192.0   8192.0  381.123075  293.967613
4     1.0  16.0  16.0  16384.0  16384.0  378.565994  281.094768
5     2.0  48.0  48.0   1024.0   1024.0  253.586024  199.107892
6     2.0  48.0  48.0   2048.0   1024.0  308.630244  228.837054
7     2.0  48.0  48.0   4096.0   8192.0  362.719665  286.586611
8     2.0  48.0  48.0   8192.0   4096.0  378.872578  301.687999
9     2.0  48.0  48.0  16384.0   8192.0  390.812812  316.858139
10    8.0  16.0  16.0   1989.0  15344.0  290.809432  262.368374
11    4.0  16.0  16.0   4097.0    163.0  133.408037   91.389727
12    2.0  16.0  16.0   8122.0   2159.0  275.348037  264.067166
13    1.0  16.0  16.0  16281.0      7.0    6.050914    6.945960
14    2.0  48.0  48.0   1021.0   1020.0  210.570050  192.480003
15    2.0  48.0  48.0   2001.0   2048.0  323.025914  254.007358
16    2.0  48.0  48.0   3996.0   9639.0  292.943909  274.532452
17    2.0  48.0  48.0   8181.0   1021.0  256.953097  259.333202
(py_3.10) xinyazha@1949bd361800:~/triton-rocm/python/perf-kernels$ python og-flash-attention.py 
fused-attention-fwd-d128-layoutbhsd:
    BATCH    HQ    HK  N_CTX_Q  N_CTX_K      triton       torch
0    16.0  16.0  16.0   1024.0   1024.0  289.643802  209.040048
1     8.0  16.0  16.0   2048.0   2048.0  348.803770  262.284504
2     4.0  16.0  16.0   4096.0   4096.0  371.181808  284.123233
3     2.0  16.0  16.0   8192.0   8192.0  381.649394  294.678136
4     1.0  16.0  16.0  16384.0  16384.0  380.010986  281.142287
5     2.0  48.0  48.0   1024.0   1024.0  252.326097  198.872377
6     2.0  48.0  48.0   2048.0   1024.0  306.027914  228.772620
7     2.0  48.0  48.0   4096.0   8192.0  364.630607  286.723525
8     2.0  48.0  48.0   8192.0   4096.0  379.244673  301.998962
9     2.0  48.0  48.0  16384.0   8192.0  392.781918  316.020913
10    8.0  16.0  16.0   1989.0  15344.0  281.440980  261.661028
11    4.0  16.0  16.0   4097.0    163.0  136.364422   91.629028
12    2.0  16.0  16.0   8122.0   2159.0  264.899686  264.636067
13    1.0  16.0  16.0  16281.0      7.0    6.896461    6.926344
14    2.0  48.0  48.0   1021.0   1020.0  210.527696  191.724849
15    2.0  48.0  48.0   2001.0   2048.0  322.386144  254.093665
16    2.0  48.0  48.0   3996.0   9639.0  282.014099  275.197678
17    2.0  48.0  48.0   8181.0   1021.0  256.838505  260.254701
(py_3.10) xinyazha@1949bd361800:~/triton-rocm/python/perf-kernels$ python flash-attention.py -model all
fused-attention-fwd-layoutbhsd:
         model  BATCH   HQ  HK  N_CTX_Q  N_CTX_K  D_HEAD      triton       torch
0    llama3-8B      1   32   8     4096     4096     128  363.091576  285.330544
1   llama3-70B      1   64   8     4096     4096     128  385.504352  285.391074
2  llama3-405B      1  128   8     4096     4096     128  392.309230  306.246575
(py_3.10) xinyazha@1949bd361800:~/triton-rocm/python/perf-kernels$ python og-flash-attention.py -model all
fused-attention-fwd-layoutbhsd:
         model  BATCH   HQ  HK  N_CTX_Q  N_CTX_K  D_HEAD      triton       torch
0    llama3-8B      1   32   8     4096     4096     128  361.328059  284.975282
1   llama3-70B      1   64   8     4096     4096     128  385.783594  285.467320
2  llama3-405B      1  128   8     4096     4096     128  390.250257  305.214373

@vgokhale here it is

xinyazhang added a commit to ROCm/aotriton that referenced this pull request Feb 4, 2025
## Major Changes

* [kernel] Backport the 2025/01/28 `main_perf` kernel
    + [kernel] Support empty L tensor pointer.
    + See ROCm/triton#716 for more details  
    + [shim] Adjust the build rules accordingly 
* [shim] Remove non-power-of-two (NPOT) head dim 72, which triggers compiler bugs on bf16
* [db] Remove `attn_fwd` table from tuning database, since the old entries are not valid anymore.
* [db] Set all entries to `num_stages=1` since `num_stages=2` constantly trigger compiler bugs
* [test] Add new head dimensions. Now categorized into three groups:
    + Power-of-two head dimensions
    + Optimized NPOT head dimensions
    + Prime-number head dimensions to cover all gaps between neighboring POT/NPOT head dims.
* [shim] Add env var `AOTRITON_SKIP_LUT_CHECK` to skip LUT sanity check on certain kernels
    + As of this PR, AOTriton must be built with `AOTRITON_SKIP_LUT_CHECK=flash.attn_fwd ninja install`

## Minor Changes

* [build] Bump the version number to 0.9.0. (Should be done at the beginning of 0.9 dev)
* [API] In the API, move bias tensor to the position immediately after v tensor, matching the kernel argument order
* [shim] Add `TensorView<0>::get_null_tensor`
* [test] Change `AttentionExtraArgs` from namedtuple to dataclass for easier-to-read default values (a sketch follows this commit message).
* [mptune] Change output json format to match kernel argument changes.
* [test] Use cpu reference when seqlen_k == 579 (used by `test_gqa` tests). GPU reference triggers segfault.
* [test] Change default value_fudge_factor to 36.0 (Should be 40.0 if considering GQA tests)
* [shim] Fix the code path when the tuning database is not available

## Known Problems

* Tuning database for `flash.attn_fwd` kernel is cleared and no plan to re-build it ATM due to immediate additional changes to the forward kernel.
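
For context on the `AttentionExtraArgs` item under Minor Changes, a minimal sketch of what the namedtuple-to-dataclass migration looks like; the field names here are hypothetical and not taken from AOTriton's test code:

```python
from dataclasses import dataclass

# With a namedtuple, defaults live in a separate `defaults=` list and are easy
# to misread; with a dataclass, each default sits next to its field.
@dataclass
class AttentionExtraArgs:
    return_encoded_softmax: bool = False  # hypothetical field
    autotune: bool = False                # hypothetical field
    return_autotune: bool = False         # hypothetical field
```
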
xinyazhang requested a review from jtang10 February 4, 2025 21:33
In the inner loop, use PERSISTENT and PERSISTENT_DYNAMIC derived from
PERSISTENT_TYPE (see the sketch below).

This seems to work better.
Test passed with:
pytest flash-attention.py -k 'test_op_fwd_bias[dtype0-False-True-True-2-4-16384-16384-128]'

If there are still concerns, we can add type annotation tl.uint64 to
strides.
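
A plain-Python sketch of that derivation, using the PERSISTENT_TYPE = 0/1/2 (No/Fixed/Dynamic) encoding from the PR description; in the actual kernel these would be compile-time constexpr values rather than runtime Python:

```python
PERSISTENT_TYPE = 2  # 0 = not persistent, 1 = fixed persistent, 2 = dynamic persistent

# Flags derived for use in the inner loop.
PERSISTENT = PERSISTENT_TYPE > 0
PERSISTENT_DYNAMIC = PERSISTENT_TYPE == 2
```
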
xinyazhang force-pushed the xinyazhang/union-fa-naming_convention branch from fd4442b to 4d2dd42 on February 5, 2025 19:29