BF16 support for Quant-LLM kernel #1147
Merged
Changes from all commits (30 commits)
All commits are by tobiasvanderwerff:

f50d8d7 Add FP6 benchmark option to use BF16
a714377 Change dequant bit-shifting logic for BF16
5af3b7e Modify dequant + tensor core ops for bf16
125f17c Template progress
b3c3be0 Modify fpx quant logic to include bf16
f828763 Add tests for FP6 BF16
ff2c6e8 Use type punning for large exponent multiplication
4304dcc Fix some TODOs
2d00a3a Remove option to add exponent bias directly to the exponent bits
ceaed34 Reformat
b532c51 Cleanup
e89274b Fix alignment
ac0fbe0 Remove templated input type whenever possible
c1dce42 Remove templated input type whenever possible 2
4546c8b Remove templated input type whenever possible 3
bba42cf Less hacky way to construct a float with a large exponent
e66395e rtol=1e-2 instead of 1e-3 for bfloat16 test
7e9350e Guards for SM75
401559f Remove redundant `__CUDA_ARCH` guards in host code
5d52e5b Fix consistency in checking for `CUDA_ARCH` versions
398da5b Update docs
d38490f Make float bias a constexpr
11ac84b Update docs more
7bd2833 Fix SM75 support
69e901d Compile guard for sm<75
8747d6d Check for CUDA synchronous errors after kernel launch
59f5eb7 Updated compile guard
c96cf18 Fix problematic usage of `__CUDA_ARCH__`
379bd5e Fix incorrect CUDA error handling
a6de35a Make the kernel fail for sm75 + bfloat16 inputs
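Several of the commits above ("Use type punning for large exponent multiplication", "Less hacky way to construct a float with a large exponent", "Make float bias a constexpr") revolve around one trick: after the FP6 payload is bit-shifted into the target floating-point layout, the exponent bias still has to be corrected by multiplying with a large power of two, which is most conveniently built from its IEEE-754 bit pattern. The sketch below is my own illustration of the general recipe, assuming the standard biases (3 for FP6 E3M2, 127 for BF16/FP32); the exact constants, helpers, and data path used in the kernel are not reproduced here.

```cpp
// Minimal sketch, not the kernel's actual code: build the exponent-bias
// correction factor for BF16 dequantization via type punning.
// Assumed biases: FP6 E3M2 uses bias 3, BF16/FP32 use bias 127,
// so the correction factor is 2^(127 - 3) = 2^124.
#include <cstdint>
#include <cstdio>
#include <cstring>

// IEEE-754 single-precision bit pattern of 2^e (sign and mantissa zero).
constexpr uint32_t pow2_bits(int e) {
    return static_cast<uint32_t>(e + 127) << 23;
}

int main() {
    constexpr uint32_t bias_bits = pow2_bits(127 - 3);        // bit pattern of 2^124
    float bias_scale;
    std::memcpy(&bias_scale, &bias_bits, sizeof bias_scale);  // type punning via memcpy
    std::printf("bias scale = %g\n", bias_scale);             // ~2.12676e+37 == 2^124
    return 0;
}
```

In device code the same punning could also be expressed with CUDA's `__uint_as_float` intrinsic or C++20 `std::bit_cast`; the point is simply to construct the large-exponent float from bits instead of manipulating the exponent field of each value directly.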
@@ -1,7 +1,7 @@
 # FP6-LLM kernel

-This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 and W is in FP6 (E3M2 without infinities and NaN).
+This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 or BF16 and W is in FP6 (E3M2 without infinities and NaN).

 On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion.

-See https://github.com/pytorch/ao/pull/223 for some benchmark results.
+See https://github.com/pytorch/ao/pull/223 and https://github.com/pytorch/ao/pull/1147 for some benchmark results.
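For readers unfamiliar with the E3M2 layout mentioned in the README above, the following reference sketch (my own illustration, not code from this PR) decodes a single 6-bit FP6 E3M2 value to a float, assuming the usual convention: 1 sign bit, 3 exponent bits with bias 3, 2 mantissa bits, and no infinities or NaNs, so the all-ones exponent encodes ordinary normal values.

```cpp
// Hedged reference sketch (not the kernel's fast path): decode one FP6 E3M2
// value, stored in the low 6 bits of a byte, to float.
#include <cmath>
#include <cstdint>
#include <cstdio>

float fp6_e3m2_to_float(uint8_t v) {
    int sign = (v >> 5) & 0x1;
    int exp  = (v >> 2) & 0x7;
    int man  = v & 0x3;
    float s  = sign ? -1.0f : 1.0f;
    if (exp == 0) {
        // Subnormal: (-1)^s * 2^(1 - bias) * (mantissa / 4), bias = 3.
        return s * std::ldexp(man / 4.0f, 1 - 3);
    }
    // Normal: (-1)^s * 2^(exp - bias) * (1 + mantissa / 4).
    return s * std::ldexp(1.0f + man / 4.0f, exp - 3);
}

int main() {
    std::printf("%g\n", fp6_e3m2_to_float(0b011111));  // largest value: 2^4 * 1.75 = 28
    std::printf("%g\n", fp6_e3m2_to_float(0b000001));  // smallest subnormal: 0.0625
    return 0;
}
```

The kernel itself of course does not decode values one at a time; per the commits above, it reconstructs FP16/BF16 outputs with bit shifts plus the bias-correction multiply sketched earlier, then feeds them to tensor core ops.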
Just curious. I've noticed that when BF16 is used, the tolerance is generally quite a bit higher than for FP16. From your experience working on this, do you suspect any part of the code might be responsible for this loss of precision? E.g. perhaps some parts are computed in BF16 instead of FP32. Or maybe it's just the way it is.
All I know is that BF16 has fewer bits for the fraction (mantissa) than FP16 (7 bits vs. 10 bits), so that leads to lower precision for BF16 compared to FP16. I can't think of any part of the FP6 kernel that would inherently lead to more loss of precision for BF16.
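A rough way to see why the test tolerance was relaxed (rtol=1e-2 for BF16 vs. 1e-3 for FP16, per the commit above): the relative spacing between adjacent representable values is about 2^-m for m explicit mantissa bits. This is a back-of-the-envelope illustration of that point, not part of the PR.

```cpp
// Back-of-the-envelope illustration (not from the PR): relative spacing of
// representable values for FP16 (10 mantissa bits) vs. BF16 (7 mantissa bits).
#include <cmath>
#include <cstdio>

int main() {
    double eps_fp16 = std::ldexp(1.0, -10);  // ~9.8e-4, in line with rtol=1e-3
    double eps_bf16 = std::ldexp(1.0, -7);   // ~7.8e-3, in line with rtol=1e-2
    std::printf("FP16 eps ~ %.2e, BF16 eps ~ %.2e\n", eps_fp16, eps_bf16);
    return 0;
}
```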