Use userbuffers for MXFP8 wgrad all-gather overlap #1982
Merged

Commits (10):

- 276721c (djns99): fix: Add stream synchronization before destroying MPI communicator (#…
- 64edfeb (djns99): feat: Implement column-wise userbuffer overlap for comm+GEMM operations
- 83fc378 (djns99): fix: Working userbuffer overlapping API
- a252178 (djns99): fix: Fix overwriting bulk overlap UB object for layernormLinear
- b88f9c9 (djns99): fix: Update external overlap to use tp size instead of nvsize to dete…
- 59408fb (djns99): fix: Fix linter error
- 11b38f0 (djns99): fix: Explanatory comments of overlap logic
- 6e9f473 (djns99): fix: Fix the UB fused ops tests
- 42777ef (djns99): fix: Fix linter errors
- 683a42c (timmoon10): Merge branch 'main' into djns99/column_wise_ub_overlap
I would like to do the quantization here fused with the rowwise quantization above. I couldn't figure out a nice (quick) way of doing this though, so I might come back to it.

Simply quantizing `grad_outputs[0]` at the top of the function before using it in dgrad didn't quite work because it hit a bunch of asserts. This is hopefully something we can come back to. As it is, the quantization and NCCL all-gather are also overlapped with the GEMM by virtue of running on the `dgrad_send_stream` here.
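For readers following the thread, a minimal sketch of that stream overlap. The stream name matches the comment above, but the quantize function, the GEMM, and the process-group plumbing are hypothetical stand-ins, not the module's actual code:

```python
import torch
import torch.distributed as dist

# Side stream for the wgrad input path, as named in the comment above.
dgrad_send_stream = torch.cuda.Stream()

def dgrad_with_overlapped_gather(grad_output, weight, quantize_columnwise, group):
    """Run the columnwise cast and its all-gather on a side stream so
    they overlap with the dgrad GEMM issued on the main stream."""
    main_stream = torch.cuda.current_stream()
    dgrad_send_stream.wait_stream(main_stream)    # grad_output is ready
    grad_output.record_stream(dgrad_send_stream)  # it is also read over there
    with torch.cuda.stream(dgrad_send_stream):
        grad_q = quantize_columnwise(grad_output)  # stand-in for the MXFP8 cast
        gathered = torch.empty(
            (dist.get_world_size(group), *grad_q.shape),
            dtype=grad_q.dtype,
            device=grad_q.device,
        )
        dist.all_gather_into_tensor(gathered, grad_q, group=group)
    dgrad = grad_output @ weight  # dgrad GEMM on the main stream (shapes illustrative)
    main_stream.wait_stream(dgrad_send_stream)  # wgrad consumes `gathered`
    return dgrad, gathered
```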
This should be easily possible.

Toward the top of the backward pass, we set `grad_output_quantizer.set_usage(rowwise=True, columnwise=True)` before `TransformerEngineBaseModule.grad_output_preprocess()`, but then flip `columnwise=False` if we see that `ub_overlap_ag` is enabled. This conditional `ub_overlap_ag` change to the quantizer usage needs to be shifted to after `grad_output_preprocess()` in order to ensure that the rowwise and columnwise quantization happen at the same time at the beginning of the backward pass.

I don't believe any further changes are necessary, because `_fill_userbuffers_buffer_for_all_gather()` already avoids re-quantization if it sees that the tensor has been quantized ahead of time. It will simply copy the columnwise-quantized gradient into the communication buffer based on the quantizer usage at the time.
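A sketch of that reordering. The preprocess call shape is paraphrased from this thread rather than copied from the module, so treat the exact arguments and return unpacking as assumptions:

```python
# Produce both quantizations in one fused cast at the start of backward.
grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
grad_output, *rest = TransformerEngineBaseModule.grad_output_preprocess(
    ctx, grad_outputs[0], row_parallel_mode, grad_output_quantizer
)  # yields a QuantizedTensor holding rowwise AND columnwise data

# Moved here from before the preprocess: restrict the usage only after
# the fused cast has already happened.
if ctx.ub_overlap_ag:
    grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
```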
Unfortunately the code in the preprocess doesn't support quantized tensors (the `reshape()` operation at the top doesn't exist) and `fill_userbuffers_buffer_for_all_gather` (called in the preprocess) doesn't support both usages being set.

These should be fixable, but I had some weird errors, so I left it as it was for now.
What goes into the preprocess is always plain Torch tensors, but what comes out is a QuantizedTensor. That's as expected. So you do not want to explicitly invoke quantization here; that's what the preprocess is supposed to do for you, based on the usage information you set into the quantizer object.

And on that note, if the `grad_output_quantizer` usage is set to `rowwise=True` and `columnwise=True` before the preprocess, and everything else is passed in as usual, the preprocess should produce a single QuantizedTensor gradient that has both rowwise and columnwise quantizations in it.

At that point, `grad_output_quantizer` usage can be updated again to rowwise-only for the DGRAD AG->GEMM, and then columnwise-only right before invoking `_fill_userbuffers_for_all_gather()`, to account for the fact that only one usage can be enabled at a time in either case. This usage information will allow both the GEMM and the UB fill to pluck out the correct quantized data and scale from the same QuantizedTensor.

I will test this later tonight if I get the chance, but I'd be surprised if this doesn't work out of the box. The code, in my reading, appears to account for this kind of use case already.
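Concretely, the toggling described above might look like this. The function names come from this discussion, while the GEMM wrapper and exact call shapes are illustrative:

```python
# Both usages were produced in one pass during the preprocess; each
# downstream consumer then reads exactly one of them.

grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
dgrad = dgrad_gemm(weight, grad_output)  # hypothetical DGRAD AG->GEMM wrapper;
                                         # reads the rowwise data

grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
# The fill sees an already-quantized tensor, skips re-quantization, and
# copies the columnwise data into the communication buffer.
grad_output_ub = _fill_userbuffers_for_all_gather(
    ub_obj, grad_output, grad_output_quantizer, ctx.tp_group
)
```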
Fusing the casts is messy since it involves outputting into multiple UB buffers. Perhaps we could modify `fill_userbuffers_buffer_for_all_gather` to have separate UB comms for the row-wise and column-wise data. Alternatively, we could separate the logic for constructing the quantized tensor from the logic doing the quantize, along the lines of the sketch below:
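Something like this, perhaps; it assumes the quantizer exposes a make-empty/update split, and the helper, buffer getters, and attribute names below are all hypothetical:

```python
def make_empty_mxfp8_in_ub(quantizer, shape, ub_rowwise, ub_columnwise):
    """Construct the quantized tensor without casting, with its rowwise
    and columnwise data views living in two separate UB buffers."""
    out = quantizer.make_empty(shape)  # metadata only, no cast (assumed API)
    out._rowwise_data = ub_rowwise.get_buffer().view(shape)        # hypothetical
    out._columnwise_data = ub_columnwise.get_buffer().view(shape)  # hypothetical
    return out

# A single fused cast then fills both UB buffers in one pass:
grad_output_q = make_empty_mxfp8_in_ub(
    grad_output_quantizer, grad_outputs[0].shape, ub_obj_rowwise, ub_obj_columnwise
)
grad_output_quantizer.update_quantized(grad_outputs[0], grad_output_q)
```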