Use userbuffers for MXFP8 wgrad all-gather overlap #1982
Conversation
Force-pushed from c3cf103 to 7d7ce6b
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp (review comment, outdated and resolved)
Force-pushed from 7d7ce6b to 21cbd36
    # This is the same stream that we will use to access the data in the AG,
    # so we don't need to add any syncs yet.
    with torch.cuda.stream(dgrad_send_stream):
        grad_output, _ = fill_userbuffers_buffer_for_all_gather(
I would like to do the quantization here fused with the rowwise quantization above. I couldn't figure out a nice (quick) way of doing this though, so I might come back to it.
Simply quantizing `grad_outputs[0]` at the top of the function, before using it in dgrad, didn't quite work because it hit a bunch of asserts. This is hopefully something we can come back to.
As it is, the quantization and NCCL all-gather are also overlapped with the GEMM by virtue of running on the `dgrad_send_stream` here.
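For reference, a minimal sketch of the side-stream overlap being described, using plain PyTorch stream mechanics; the quantize/all-gather line is a stand-in for the TE calls, and all tensor names are illustrative:

```python
import torch

main_stream = torch.cuda.current_stream()
dgrad_send_stream = torch.cuda.Stream()

grad_out = torch.randn(4096, 4096, device="cuda")
weight = torch.randn(4096, 4096, device="cuda")

# The side stream must observe grad_out before reading it.
dgrad_send_stream.wait_stream(main_stream)

with torch.cuda.stream(dgrad_send_stream):
    # Stand-in for: columnwise quantize + NCCL all-gather of grad_out.
    gathered = grad_out * 1.0
    # Tell the caching allocator grad_out is in use on this stream.
    grad_out.record_stream(dgrad_send_stream)

# The DGRAD GEMM launched on the main stream overlaps with the work above.
dgrad = grad_out @ weight

# Anything that later consumes `gathered` on the main stream must sync first.
main_stream.wait_stream(dgrad_send_stream)
```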
This should be easily possible.
Toward the top of the backward pass, we set `grad_output_quantizer.set_usage(rowwise=True, columnwise=True)` before `TransformerEngineBaseModule.grad_output_preprocess()`, but then flip `columnwise=False` if we see that `ub_overlap_ag` is enabled.
This conditional `ub_overlap_ag` change to the quantizer usage needs to be shifted to after `grad_output_preprocess()` in order to ensure that the rowwise and columnwise quantization happen at the same time at the beginning of the backward pass.
I don't believe any further changes are necessary, because `_fill_userbuffers_buffer_for_all_gather()` already avoids re-quantization if it sees that the tensor has been quantized ahead of time. It will simply copy the columnwise-quantized gradient into the communication buffer based on the quantizer usage at the time.
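A minimal sketch of that reordering (`set_usage` is real TE quantizer API; the preprocess signature and the `ctx` fields shown here are assumptions for illustration):

```python
# Quantize both usages in one pass at the top of the backward pass:
grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
grad_output = TransformerEngineBaseModule.grad_output_preprocess(
    ctx, grad_outputs[0], row_parallel_mode, grad_output_quantizer
)  # signature assumed

# Only after preprocess do we restrict the usage for the overlap path:
if ctx.ub_overlap_ag:
    grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
```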
Unfortunately the code in the preprocess doesn't support quantized tensors (the `reshape()` operation at the top doesn't exist on them), and `fill_userbuffers_buffer_for_all_gather` (called in the preprocess) doesn't support both usages being set.
These should be fixable, but I hit some weird errors, so I left it as it was for now.
What goes into preprocess is always plain Torch tensors, but what comes out is a QuantizedTensor. That's as expected. So you do not want to explicitly invoke quantization here; that's what preprocess is supposed to do for you, based on the usage information you set on the quantizer object.
And on that note, if the `grad_output_quantizer` usage is set to `rowwise=True` and `columnwise=True` before the preprocess, and everything else is passed in as usual, the preprocess should produce a single QuantizedTensor gradient that has both rowwise and columnwise quantizations in it.
At that point, `grad_output_quantizer` usage can be updated again to rowwise-only for the DGRAD AG->GEMM, and then columnwise-only right before invoking `_fill_userbuffers_for_all_gather()`, to account for the fact that only one usage can be enabled at a time in either case. This usage information will allow both the GEMM and the UB fill to pluck out the correct quantized data and scale from the same QuantizedTensor.
I will test this later tonight if I get the chance, but I'd be surprised if this doesn't work out of the box. The code, in my reading, appears to account for this kind of use case already.
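In code, the toggling being described would look roughly like this (the `set_usage` calls are real TE API; the GEMM and fill call shapes are assumptions for illustration):

```python
# DGRAD AG->GEMM consumes only the rowwise copy:
grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
dgrad = general_gemm(weight, grad_output, workspace)  # args assumed

# The UB fill then copies the pre-quantized columnwise data, without
# re-quantizing, into the communication buffer:
grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output_total, _ = fill_userbuffers_buffer_for_all_gather(
    ub_obj_wgrad, grad_output, grad_output_quantizer, tp_group
)  # args assumed
```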
Fusing the casts is messy since it involves outputting into multiple UB buffers. Perhaps we could modify `fill_userbuffers_buffer_for_all_gather` to have separate UB comms for the row-wise and column-wise data. Alternatively, we could separate the logic for constructing the quantized tensor from doing the quantize:

    grad_output = make_tensor_with_userbuffers_buffers(
        grad_outputs[0].size(), rowwise_comm=ub_obj_dgrad, columnwise_comm=ub_obj_wgrad
    )
    grad_output.copy_(grad_outputs[0])
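To make the alternative concrete, a purely hypothetical sketch of such a helper (`make_tensor_with_userbuffers_buffers` does not exist in TE; the `get_buffer` accessor and the MXFP8 constructor arguments are assumptions):

```python
def make_tensor_with_userbuffers_buffers(shape, rowwise_comm, columnwise_comm):
    """Hypothetical: build an unfilled MXFP8 tensor whose rowwise storage
    lives in the DGRAD UB buffer and whose columnwise storage lives in the
    WGRAD UB buffer, so a single copy_() quantizes straight into both
    communication buffers."""
    rowwise_data = rowwise_comm.get_buffer()        # assumed accessor
    columnwise_data = columnwise_comm.get_buffer()  # assumed accessor
    return MXFP8Tensor(  # constructor arguments assumed
        shape=shape,
        rowwise_data=rowwise_data,
        columnwise_data=columnwise_data,
    )
```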
Force-pushed from c42f6bd to 80af14c
LGTM overall, pending the suggestion about the quantizer usage changes to ensure rowwise and columnwise quantization happen at the same time at the top of the backward pass.
LGTM, barring some linter warnings.
/te-ci pytorch L1
…VIDIA#1979) Signed-off-by: djns99 <[email protected]>
Add support for overlapping column-wise allgather communication with GEMM operations to improve training performance:
* **Core infrastructure changes:**
  - Update bulk_overlap_columnwise_ag() to accept explicit stream parameter
  - Modify userbuffers send/recv loops to use rank-ordered iteration
  - Add userbuffers_send_all/recv_all function declarations
* **Python integration:**
  - Add bulk_overlap_ag_with_external_gemm() C++ extension function
  - Expose new overlap function via pybind11 bindings
  - Update overlap method configurations to include more ring_exchange ops
* **LayerNorm MLP optimization:**
  - Enable column-wise quantization for FC2 gradient output
  - Implement overlap of allgather communication with FC2 DGRAD GEMM
  - Use fill_userbuffers_buffer_for_all_gather for efficient buffering

This optimization allows overlapping communication and computation phases more effectively, reducing training wall-clock time by hiding allgather latency behind GEMM execution.

Signed-off-by: djns99 <[email protected]>
Signed-off-by: djns99 <[email protected]>
Signed-off-by: djns99 <[email protected]>
…rmine number of copies Signed-off-by: djns99 <[email protected]>
Signed-off-by: djns99 <[email protected]>
Signed-off-by: djns99 <[email protected]>
Signed-off-by: djns99 <[email protected]>
Force-pushed from 80af14c to 6e9f473
Signed-off-by: djns99 <[email protected]>
/te-ci pytorch L1
LGTM. The L40 failure is spurious and the B200 failure also shows up in the nightly build.
Description
Updates the wgrad comm overlap logic to use userbuffers when overlapping the all-gather communication with the GEMM.