Use userbuffers for MXFP8 wgrad all-gather overlap #1982


Merged: 10 commits from djns99/column_wise_ub_overlap into NVIDIA:main on Aug 9, 2025

Conversation

@djns99 (Contributor) commented on Jul 22, 2025

Description

Updates the wgrad comm overlap logic to use userbuffers when overlapping the all-gather communication with the GEMM.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Adds a new UB manager object for the proj/fc2 wgrad all-gather
  • Updates the code to use the new object for the copies

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@djns99 djns99 force-pushed the djns99/column_wise_ub_overlap branch from c3cf103 to 7d7ce6b Compare August 1, 2025 04:43
@ptrendx ptrendx requested a review from denera August 1, 2025 23:24
@djns99 djns99 force-pushed the djns99/column_wise_ub_overlap branch from 7d7ce6b to 21cbd36 Compare August 5, 2025 05:04

# This is the same stream that we will use to access the data in the AG,
# so we don't need to add any syncs yet.
with torch.cuda.stream(dgrad_send_stream):
    grad_output, _ = fill_userbuffers_buffer_for_all_gather(
@djns99 (Contributor, PR author):

I would like to fuse the quantization here with the rowwise quantization above. I couldn't figure out a nice (quick) way of doing this, though, so I might come back to it.
Simply quantizing grad_outputs[0] at the top of the function before using it in dgrad didn't quite work, because it hit a bunch of asserts. This is hopefully something we can come back to.
As it is, the quantization and NCCL all-gather are also overlapped with the GEMM by virtue of running on dgrad_send_stream here.
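
To illustrate the overlap pattern being described, here is a minimal sketch in plain PyTorch (not the PR's code): a cast, standing in for the MXFP8 quantize, and an NCCL all-gather run on a side stream while the dgrad GEMM runs on the main stream. The function and tensor names are hypothetical, and an initialized tensor-parallel process group is assumed.

import torch
import torch.distributed as dist

def dgrad_with_overlapped_gather(grad_output, weight, tp_group, side_stream):
    main_stream = torch.cuda.current_stream()
    # The side stream must wait until grad_output is ready on the main stream.
    side_stream.wait_stream(main_stream)

    with torch.cuda.stream(side_stream):
        # Cast (stand-in for the quantize) and gather for the wgrad GEMM,
        # overlapping with the dgrad GEMM launched below on the main stream.
        grad_cast = grad_output.to(torch.bfloat16)
        world = dist.get_world_size(tp_group)
        gathered = torch.empty(
            (world * grad_cast.shape[0], *grad_cast.shape[1:]),
            dtype=grad_cast.dtype, device=grad_cast.device,
        )
        dist.all_gather_into_tensor(gathered, grad_cast, group=tp_group)

    # dgrad GEMM on the main stream overlaps with the work queued above.
    dgrad = grad_output @ weight

    # Whatever consumes `gathered` (the wgrad GEMM) must wait for the side stream.
    main_stream.wait_stream(side_stream)
    return dgrad, gathered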

@denera (Collaborator):

This should be easily possible.

Toward the top of the backward pass, we set grad_output_quantizer.set_usage(rowwise=True, columnwise=True) before TransformerEngineBaseModule.grad_output_preprocess(), but then flip columnwise=False if we see that ub_overlap_ag is enabled.

This conditional ub_overlap_ag change to the quantizer usage needs to be shifted to after grad_output_preprocess() in order to ensure that the rowwise and columnwise quantization happen at the same time at the beginning of the backward pass.

I don't believe any further changes are necessary, because fill_userbuffers_buffer_for_all_gather() already avoids re-quantization if it sees that the tensor has been quantized ahead of time. It will simply copy the columnwise-quantized gradient into the communication buffer based on the quantizer usage at the time.
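
A schematic of the ordering being suggested, as a sketch only: the preprocess is passed in as a plain callable because the real TransformerEngineBaseModule.grad_output_preprocess() takes more arguments than shown, and the ctx attribute and function names here are assumptions, not TE's exact API.

def prepare_grad_output(ctx, grad_outputs, preprocess):
    quantizer = ctx.grad_output_quantizer

    # Request both usages up front so the preprocess performs a single fused cast
    # that yields a QuantizedTensor carrying rowwise *and* columnwise data.
    quantizer.set_usage(rowwise=True, columnwise=True)
    grad_output = preprocess(ctx, grad_outputs[0], quantizer)

    # Only after the preprocess do we narrow the usage for the UB overlap path,
    # instead of flipping columnwise=False before the preprocess as the code does now.
    if ctx.ub_overlap_ag:
        quantizer.set_usage(rowwise=True, columnwise=False)
    return grad_output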

@djns99 (Contributor, PR author):

Unfortunately the code in the preprocess doesn't support quantized tensors (the reshape() operation at the top doesn't exist), and fill_userbuffers_buffer_for_all_gather (called in the preprocess) doesn't support having both usages set.

These should be fixable, but I had some weird errors, so I left it as it was for now.

@denera (Collaborator) commented on Aug 6, 2025:

What goes into preprocess is always plain Torch tensors, but what comes out is a QuantizedTensor. That's as expected. So you do not want to explicitly invoke quantization here. That's what preprocess is supposed to do for you based on the usage information you set into the quantizer object.

And on that note, if the grad_output_quantizer usage is set to rowwise=True and columnwise=True before the preprocess, and everything else is passed in as usual, the preprocess should produce a single QuantizedTensor gradient that has both rowwise and columnwise quantizations in it.

At that point, grad_output_quantizer usage can be updated again to rowwise-only for the DGRAD AG->GEMM, and then to columnwise-only right before invoking fill_userbuffers_buffer_for_all_gather(), to account for the fact that only one usage can be enabled at a time in either case. This usage information allows both the GEMM and the UB fill to pluck out the correct quantized data and scale from the same QuantizedTensor.

I will test this later tonight if I get the chance but I'd be surprised if this doesn't work out of the box. The code, in my reading, appears to account for this kind of use case already.
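
A sketch of the usage flips described here, schematic only: dgrad_gemm is a hypothetical stand-in for the actual GEMM call, the argument order of fill_userbuffers_buffer_for_all_gather is assumed rather than taken from TE, and its import is omitted.

def backward_with_shared_quantized_grad(ctx, grad_output, weight, ub_obj_wgrad):
    quantizer = ctx.grad_output_quantizer

    # DGRAD AG->GEMM consumes the rowwise data of the shared QuantizedTensor.
    quantizer.set_usage(rowwise=True, columnwise=False)
    dgrad = dgrad_gemm(grad_output, weight)  # hypothetical helper

    # The UB fill for the wgrad all-gather consumes the columnwise data. Because the
    # tensor is already quantized, the fill only copies it into the UB buffer.
    quantizer.set_usage(rowwise=False, columnwise=True)
    grad_output_ag, _ = fill_userbuffers_buffer_for_all_gather(
        ub_obj_wgrad, grad_output, quantizer, ctx.tp_group  # argument order assumed
    )
    return dgrad, grad_output_ag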

@timmoon10 (Collaborator):

Fusing the casts is messy since it involves outputting into multiple UB buffers. Perhaps we could modify fill_userbuffers_buffer_for_all_gather to take separate UB communicators for the row-wise and column-wise data. Alternatively, we could separate constructing the quantized tensor from performing the quantization:

grad_output = make_tensor_with_userbuffers_buffers(
    grad_outputs[0].size(),
    rowwise_comm=ub_obj_dgrad,
    columnwise_comm=ub_obj_wgrad,
)
grad_output.copy_(grad_outputs[0])

@djns99 djns99 marked this pull request as ready for review August 6, 2025 00:13
@djns99 djns99 force-pushed the djns99/column_wise_ub_overlap branch from c42f6bd to 80af14c Compare August 6, 2025 03:47
@denera (Collaborator) left a review:

LGTM overall, pending the suggestion about the quantizer usage changes to ensure rowwise and columnwise quantization happens at the same time at the top of the backward pass.


timmoon10 previously approved these changes Aug 6, 2025
@timmoon10 (Collaborator) left a review:

LGTM, barring some linter warnings.


@timmoon10 (Collaborator):

/te-ci pytorch L1

djns99 added 8 commits August 7, 2025 14:52
Add support for overlapping column-wise allgather communication with GEMM
operations to improve training performance:

* **Core infrastructure changes:**
  - Update bulk_overlap_columnwise_ag() to accept explicit stream parameter
  - Modify userbuffers send/recv loops to use rank-ordered iteration
  - Add userbuffers_send_all/recv_all function declarations

* **Python integration:**
  - Add bulk_overlap_ag_with_external_gemm() C++ extension function
  - Expose new overlap function via pybind11 bindings
  - Update overlap method configurations to include more ring_exchange ops

* **LayerNorm MLP optimization:**
  - Enable column-wise quantization for FC2 gradient output
  - Implement overlap of allgather communication with FC2 DGRAD GEMM
  - Use fill_userbuffers_buffer_for_all_gather for efficient buffering

This optimization allows overlapping communication and computation phases
more effectively, reducing training wall-clock time by hiding allgather
latency behind GEMM execution.

Signed-off-by: djns99 <[email protected]>
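
As an aside, the ring_exchange / rank-ordered send-recv iteration mentioned in the commit message above can be pictured with a toy ring all-gather written against plain torch.distributed. This is only an analogue of the idea, not the userbuffers kernel code, and it assumes an already-initialized process group.

import torch
import torch.distributed as dist

def ring_all_gather(local_chunk: torch.Tensor) -> torch.Tensor:
    # Each step, every rank sends the chunk it most recently received (starting with
    # its own) to the next rank and receives from the previous rank, so after
    # world_size - 1 steps every rank holds all chunks.
    world = dist.get_world_size()
    rank = dist.get_rank()
    chunks = [torch.empty_like(local_chunk) for _ in range(world)]
    chunks[rank].copy_(local_chunk)

    send_buf = local_chunk.clone()
    for step in range(world - 1):
        recv_buf = torch.empty_like(local_chunk)
        # Post the send and receive together so neighbouring ranks cannot deadlock.
        reqs = dist.batch_isend_irecv([
            dist.P2POp(dist.isend, send_buf, (rank + 1) % world),
            dist.P2POp(dist.irecv, recv_buf, (rank - 1) % world),
        ])
        for req in reqs:
            req.wait()
        # The chunk received at this step originated from rank - step - 1.
        chunks[(rank - step - 1) % world].copy_(recv_buf)
        send_buf = recv_buf
    return torch.cat(chunks, dim=0)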
@djns99 djns99 force-pushed the djns99/column_wise_ub_overlap branch from 80af14c to 6e9f473 Compare August 7, 2025 02:52
@timmoon10 (Collaborator):

/te-ci pytorch L1

@timmoon10 (Collaborator) left a review:

LGTM. The L40 failure is spurious and the B200 failure also shows up in the nightly build.

@timmoon10 timmoon10 merged commit 077e26c into NVIDIA:main Aug 9, 2025
26 of 29 checks passed