Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fp8 backward #119

Draft
wants to merge 14 commits into
base: main_perf
Choose a base branch
from
Draft

fp8 backward #119

wants to merge 14 commits into from

Conversation

micmelesse
Copy link
Collaborator

@micmelesse micmelesse commented Jan 24, 2025

add fp8 backward

@micmelesse micmelesse changed the title add backward test case fp8 backward Jan 24, 2025
@micmelesse micmelesse force-pushed the micmelesse/fp8_bwd branch 4 times, most recently from 6b691eb to 297742b Compare February 3, 2025 09:24
@micmelesse micmelesse marked this pull request as ready for review February 4, 2025 13:37
Copy link

@brunomazzottiamd brunomazzottiamd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm approving the PR because I can't see anything wrong with it. I just left some questions and cleanup suggestions.

flash_attn/flash_attn_triton_amd/README.md Show resolved Hide resolved
flash_attn/flash_attn_triton_amd/bwd_prefill.py Outdated Show resolved Hide resolved
flash_attn/flash_attn_triton_amd/bwd_prefill.py Outdated Show resolved Hide resolved
flash_attn/flash_attn_triton_amd/bwd_prefill.py Outdated Show resolved Hide resolved
flash_attn/flash_attn_triton_amd/test.py Show resolved Hide resolved
flash_attn/flash_attn_triton_amd/test.py Show resolved Hide resolved
flash_attn/flash_attn_triton_amd/test.py Outdated Show resolved Hide resolved
Enable BWD fp8 with per block scale factors for p
and ds

This is a combination of 9 commits.

Enable BWD fp8

This is a combination of 12 commits.

add backward test case

save clean up

disable ci

lse is good

dv matches

reduce diff

use do fp8 for dv

kinda working

group size is a constexpr

clean up a bit

everything except mqa/gqa works

skip mqa cases

20 cases have nan on dropout

save what you have

disable tests

failing

enable tests

per block descale_p and descale_ds

use max(abs(())

clean up tests a bit more

fix bug

disable ci for now

pass variables

add flags

add alternate path. Still need to load descale factors

dv working

dk works

save
@micmelesse micmelesse marked this pull request as draft February 7, 2025 19:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants