-
Notifications
You must be signed in to change notification settings - Fork 316
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make float32_qk_product and float32_logits apply during inference #1225
Merged
copybara-service
merged 3 commits into
AI-Hypercomputer:main
from
Essential-AI:fp32_attention
Feb 11, 2025
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you have to set precision as well for float32 to actually take effect https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision - we have this option in maxtext
maxtext/MaxText/configs/base.yml
Line 79 in d33821f
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good point which I hadn't considered. I suppose there are two moving parts here: are the flops performed in bf16 or fp32, and are the results accumulated in bf16 or fp32. I think these are generally controlled with the
precision
andpreferred_element_type
arguments, respectively.It appears that on default precision the flops happen in bf16 even when one or more of the inputs are in fp32. However, the accumulation can still happen in fp32, and that seems to have been enough to solve our particular problem. In particular, the compiler seems to recognize that even though the python says to upcast to fp32, it can elide that because it's going to do the computation. However, it still outputs fp32.
This is the qk product with
![Screenshot 2025-01-31 at 4 53 15 PM](https://private-user-images.githubusercontent.com/169196560/408763699-37f16531-6b8e-4ebd-9fc6-6463110fa900.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MzkzOTcxMTgsIm5iZiI6MTczOTM5NjgxOCwicGF0aCI6Ii8xNjkxOTY1NjAvNDA4NzYzNjk5LTM3ZjE2NTMxLTZiOGUtNGViZC05ZmM2LTY0NjMxMTBmYTkwMC5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjEyJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIxMlQyMTQ2NThaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT0wZmQyZjNlNWM5ZGNjY2UzYTRjNjEyMmFhODlmZmU4MzRkOTE3YjI4NjVkYzAzNzg5MzZjYTQ1OGEyZTU3ODZkJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.iVfb6LPbaGLH3VKNksmE1PR3IFht6pN7LdH8JVh-gfc)
float32_qk_product=False
And this is with
![Screenshot 2025-01-31 at 3 14 11 PM](https://private-user-images.githubusercontent.com/169196560/408763725-a2eafa4e-1880-4ed4-879e-78c4ce3b0908.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MzkzOTcxMTgsIm5iZiI6MTczOTM5NjgxOCwicGF0aCI6Ii8xNjkxOTY1NjAvNDA4NzYzNzI1LWEyZWFmYTRlLTE4ODAtNGVkNC04NzllLTc4YzRjZTNiMDkwOC5wbmc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjUwMjEyJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI1MDIxMlQyMTQ2NThaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT0zOGM1ZjIyOWViM2FhNDk4ZDQxY2E2MzMzNWM1YjUyODdiNzIxMmExM2E5NzFiYmNmOTc1ZWQ1YjY3ZDcxZjA1JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCJ9.7EUedM3QuOr9dIQdG1Ov1jKSKWuaaJMTLBsy363qWyo)
float32_qk_product=True
(note the output type is now f32)I'm not 100% confident in my interpretation of those graphs, but this would explain why it takes longer even without changing the precision parameter.
Separately, it looks like
matmul_precision
consistently gets routed intoDenseGeneral
usages, but not into the raw einsums used inqk_product
andwv_product
. When I changematmul_precision
in the config it does not affect the runtime of those operations, but if I add it explicitly to the einsums then thewv_product
does take longer, which makes sense. Is that something we should fix just by adding those arguments to the einsums?