Use userbuffers for MXFP8 wgrad all-gather overlap #1982

Merged · 10 commits · Aug 9, 2025
1 change: 1 addition & 0 deletions tests/pytorch/distributed/run_layer_with_overlap.py
@@ -519,6 +519,7 @@ def run_fwd_bwd(model, x):
if opts.use_cuda_graphs:
del test_graph

torch.cuda.synchronize()
te.module.base.destroy_ub()
dist_print("Destroying Userbuffers objects...", debug=True)

25 changes: 25 additions & 0 deletions transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
@@ -138,6 +138,11 @@ CommOverlapCore::~CommOverlapCore() {
cudaStreamDestroy(_stream_compute[i]);
}

auto error = cudaGetLastError();
if (error != cudaSuccess) {
NVTE_WARN("Error detected while destroying communicator: ", cudaGetErrorString(error));
}

if (_comm_created) {
try {
#ifdef NVTE_UB_WITH_MPI
@@ -289,6 +294,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType

CommOverlapBase::~CommOverlapBase() {
cudaEventDestroy(_start_d2dcopy);
cudaStreamSynchronize(_stream_comm);
cudaStreamDestroy(_stream_comm);
}

@@ -591,6 +597,25 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
} // CommOverlapBase::split_overlap_rs

void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
cudaStream_t stream_main) {
int comm_bytes = _ubuf.bytes();
int comm_bytes_per_rank = comm_bytes / _tp_size;

// We send and receive on the streams borrowed from the communicator that launched the
// overlapped GEMM, so these kernels cannot finish until the previous GEMM has been flushed.
userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm,
send_stream);
userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm,
recv_stream);

for (auto stream : {send_stream, recv_stream}) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, stream));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
// We sync with the comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _stop_comm, 0));
}
}
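For readers following the stream plumbing, below is a minimal PyTorch-only sketch, not the UB API and not part of the diff, of the record/wait pattern this function sets up: the communication kernels run on borrowed send/recv streams, and a single event makes both the main stream and the communicator's internal comm stream wait for them.

```python
import torch

# Stand-ins for the streams and event used by bulk_overlap_external_ag (illustrative only).
main_stream = torch.cuda.current_stream()
send_stream, recv_stream = torch.cuda.Stream(), torch.cuda.Stream()
comm_stream = torch.cuda.Stream()  # plays the role of _stream_comm
stop_comm = torch.cuda.Event()     # plays the role of _stop_comm

# ... send kernels would be enqueued on send_stream and recv kernels on recv_stream here ...

for stream in (send_stream, recv_stream):
    stop_comm.record(stream)           # mark the end of this stream's communication
    main_stream.wait_event(stop_comm)  # main stream must not consume the buffer early
    comm_stream.wait_event(stop_comm)  # the destructor's comm-stream sync then covers this work
```

Reusing one event for both streams is fine here because each wait is issued before the event is recorded again.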

/***************************************************************************************************
* Comm+GEMM Overlap P2P Base (Ring-Exchange)
**************************************************************************************************/
@@ -2535,6 +2535,30 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
}
}

void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream) {
for (int j = 1; j < tp_size; j++) {
int i = (tp_rank + j) % tp_size;
int send_offset = srcoffset + bytes_per_slice * tp_rank;
int recv_offset = dstoffset + bytes_per_slice * tp_rank;
userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i,
stream);
}
}

void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream) {
for (int j = tp_size - 1; j > 0; j--) {
int i = (tp_rank + j) % tp_size;
int send_offset = srcoffset + bytes_per_slice * i;
int recv_offset = dstoffset + bytes_per_slice * i;
userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i,
stream);
}
}
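As a side note, here is a small Python sketch, not part of the diff, of the slice schedule the two loops above produce: each rank pushes its own slice to every peer in ring order, and posts receives for the other ranks' slices in the reverse order.

```python
def all_gather_schedule(tp_rank: int, tp_size: int):
    """Mirror the index arithmetic in userbuffers_send_all / userbuffers_recv_all."""
    sends, recvs = [], []
    # userbuffers_send_all: send our own slice (index tp_rank) to each peer in turn.
    for j in range(1, tp_size):
        peer = (tp_rank + j) % tp_size
        sends.append((peer, tp_rank))   # (destination peer, slice index)
    # userbuffers_recv_all: receive slice i from peer i, walking the ring in reverse.
    for j in range(tp_size - 1, 0, -1):
        peer = (tp_rank + j) % tp_size
        recvs.append((peer, peer))      # (source peer, slice index)
    return sends, recvs

print(all_gather_schedule(tp_rank=1, tp_size=4))
# ([(2, 1), (3, 1), (0, 1)], [(0, 0), (3, 3), (2, 2)])
```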

// producer
static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) {
// Decrement atomic val to signal current output tile finish
@@ -304,4 +304,12 @@ void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inp

void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cudaStream_t stream);

void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream);

void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream);

#endif // TRANSFORMER_ENGINE_USERBUFFERS_H_
@@ -36,7 +36,8 @@ enum class CommOverlapAlgo {
SPLIT_PIPELINED_RS_P2P = 4,
ATOMIC_GEMM_RS = 5,
ATOMIC_GEMM_AG_P2P = 6,
ATOMIC_GEMM_RS_P2P = 7
ATOMIC_GEMM_RS_P2P = 7,
EXTERNAL_BULK_OVERLAP_AG = 8,
};

class CommOverlapCore {
@@ -133,6 +134,11 @@ class CommOverlapCore {
cudaStream_t stream_main) {
NVTE_ERROR("Operation is not implemented.");
}

virtual void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
cudaStream_t stream_main) {
NVTE_ERROR("Operation is not implemented.");
}
}; // CommOverlapCore

class CommOverlapBase : public CommOverlapCore {
@@ -198,6 +204,9 @@ class CommOverlapBase : public CommOverlapCore {
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main) override;

void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
cudaStream_t stream_main) override;
}; // CommOverlapBase

class CommOverlapP2PBase : public CommOverlapCore {
@@ -277,6 +286,15 @@ class CommOverlapP2PBase : public CommOverlapCore {
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main) override;

/*
** Overlaps the all-gather for this communicator with a GEMM already launched by another
** communicator (the one that provides send_stream and recv_stream).
*/
void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
cudaStream_t stream_main) override {
NVTE_ERROR("Operation not supported.");
}
}; // CommOverlapP2PBase

} // namespace transformer_engine
4 changes: 3 additions & 1 deletion transformer_engine/common/util/pybind_helper.h
@@ -94,7 +94,9 @@
transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \
.value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \
.value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \
.value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \
.value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P) \
.value("EXTERNAL_BULK_OVERLAP_AG", \
transformer_engine::CommOverlapAlgo::EXTERNAL_BULK_OVERLAP_AG); \
py::class_<transformer_engine::CommOverlapCore, \
std::shared_ptr<transformer_engine::CommOverlapCore>>(m, "CommOverlapCore", \
pybind11::module_local()) \
15 changes: 13 additions & 2 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -11,6 +11,10 @@

#include "common.h"

class CommOverlapHelper;
class CommOverlap;
class CommOverlapP2P;

namespace transformer_engine::pytorch {

/***************************************************************************************************
@@ -419,6 +423,13 @@ void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_k

void nvshmem_finalize();

/***************************************************************************************************
* Comm+GEMM Overlap Wrappers
**************************************************************************************************/

void bulk_overlap_ag_with_external_gemm(CommOverlap &allgather_communicator, at::Stream send_stream,
at::Stream recv_stream);

} // namespace transformer_engine::pytorch

/***************************************************************************************************
Expand Down Expand Up @@ -468,7 +479,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
at::Tensor get_buffer(bool local_chunk = false,
std::optional<std::vector<int64_t>> shape = std::nullopt);

at::Stream get_communication_stream();
std::pair<at::Stream, at::Stream> get_communication_stream();

}; // CommOverlap

@@ -489,7 +500,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
at::Tensor get_buffer(bool local_chunk = false,
std::optional<std::vector<int64_t>> shape = std::nullopt);

at::Stream get_communication_stream();
std::pair<at::Stream, at::Stream> get_communication_stream();

}; // CommOverlapP2P

18 changes: 14 additions & 4 deletions transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp
@@ -216,8 +216,10 @@ at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional<std::vector<i
return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
}

at::Stream CommOverlap::get_communication_stream() {
return at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device());
std::pair<at::Stream, at::Stream> CommOverlap::get_communication_stream() {
// Return the same stream for both send and recv
return {at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device()),
at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device())};
}

/***************************************************************************************************
@@ -305,6 +307,14 @@ at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vecto
return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
}

at::Stream CommOverlapP2P::get_communication_stream() {
return at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device());
std::pair<at::Stream, at::Stream> CommOverlapP2P::get_communication_stream() {
return {at::cuda::getStreamFromExternal(_stream_send[0], at::cuda::current_device()),
at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device())};
}

void transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm(
CommOverlap &allgather_communicator, at::Stream send_stream, at::Stream recv_stream) {
auto main_stream = at::cuda::getCurrentCUDAStream();
allgather_communicator.bulk_overlap_external_ag(at::cuda::CUDAStream(send_stream),
at::cuda::CUDAStream(recv_stream), main_stream);
}
7 changes: 7 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -374,6 +374,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&transformer_engine::pytorch::multi_tensor_compute_scale_and_scale_inv_cuda,
"Fused compute scale and scale_inv from amax", py::call_guard<py::gil_scoped_release>());

// Comm+GEMM Overlap
m.def("bulk_overlap_ag_with_external_gemm",
&transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm,
"Bulk overlap All-Gather with a GEMM operation launched by another communicator",
py::call_guard<py::gil_scoped_release>(), py::arg("allgather_communicator"),
py::arg("send_stream"), py::arg("recv_stream"));

// Data structures
py::class_<transformer_engine::pytorch::FP8TensorMeta>(m, "FP8TensorMeta")
.def(py::init<>())
29 changes: 25 additions & 4 deletions transformer_engine/pytorch/module/base.py
@@ -150,7 +150,7 @@ def initialize_ub(
```
for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
"proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad",
"fc2_fprop", "fc2_dgrad"]`.
"fc2_fprop", "fc2_wgrad"]`.
bootstrap_backend : str = None
`torch.distributed` communication backend for the all-gather, broadcast and
barrier collectives during Userbuffers initialization. Not all backends are
@@ -249,22 +249,31 @@
"qkv_fprop",
"qkv_dgrad",
"proj_dgrad",
"proj_wgrad",
"fc1_fprop",
"fc1_dgrad",
"fc2_dgrad",
"fc2_wgrad",
]
layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]
# Default overlap methods for layers
methods = {
"ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
"ring_exchange": [
"qkv_fprop",
"fc1_fprop",
"proj_dgrad",
"fc2_dgrad",
],
"pipeline": ["proj_fprop", "fc2_fprop"],
"bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
"external": ["proj_wgrad", "fc2_wgrad"],
}

# AG-RS overlap pairs of layers forming a tensor-parallel block
ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"}
rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()}
external_gemm_to_overlap = {"proj_wgrad": "proj_dgrad", "fc2_wgrad": "fc2_dgrad"}
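The two dictionaries above wire the new `external` method: `proj_wgrad` and `fc2_wgrad` get their own userbuffers all-gather, driven on the streams of the corresponding `ring_exchange` dgrad communicator. Below is a minimal, hypothetical override sketch, not part of the diff; the `ub_cfgs[name]["method"]` key follows the lookup later in this function, and the `initialize_ub` arguments are assumptions based on its docstring above.

```python
import torch
import transformer_engine.pytorch as te

# Hypothetical override: pin the MXFP8 wgrad all-gather layers to the new "external"
# method. This matches the defaults added in this PR, so it is illustrative only.
ub_cfgs = {
    "proj_wgrad": {"method": "external"},
    "fc2_wgrad": {"method": "external"},
}

# Assumed entry point and argument names; the shape values are placeholders.
te.module.base.initialize_ub(
    [4 * 1024, 8 * 1024],  # (sequence length * micro-batch size, hidden size)
    tp_size=8,
    use_fp8=True,
    dtype=torch.bfloat16,
    ub_cfgs=ub_cfgs,
)
```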
global layers_atomic_ring_exchange
layers_atomic_ring_exchange = []

@@ -318,7 +327,7 @@ def add_ub(
"Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
)
assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
if method == "bulk":
if method in ("bulk", "external"):
warnings.warn(
f"At {name}, atoimic GEMM not is supported for a bulk overlap."
"Defaulting to `atomic_gemm=False`."
@@ -347,6 +356,16 @@ def add_ub(
if atomic_gemm and method == "ring_exchange":
assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message

if name in external_gemm_to_overlap:
assert method == "external", (
f"At {name}, `external` overlap method is specified, but the selected method is"
f" {method}"
)
assert external_gemm_to_overlap[name] in methods["ring_exchange"], (
f"At {name}, `external` overlap method is specified, but the external gemm"
f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method"
)

buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
if method == "ring_exchange":
ub_obj = tex.CommOverlapP2P(
@@ -395,7 +414,9 @@ def add_ub(
new_method = ub_cfgs[name]["method"]
methods[new_method].append(name)

for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
for name in (
methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"]
):
ub_cfg = get_default_config(name)
if ub_cfgs is not None and name in ub_cfgs:
fp8_buf = (name in layers_all_gather_overlap) or (
33 changes: 21 additions & 12 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -755,27 +755,36 @@ def backward(
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# UB does not support pipelined overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM.
# overlapping the AG operation with the dgrad GEMM.
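To make the "can't convert row-scaled MXFP8 to column-scaled" point above concrete, here is a small sketch, not part of the diff and assuming the standard MXFP8 layout of one shared scale per 32 consecutive elements: the rowwise and columnwise quantizations carry scale tensors of different shapes, computed along different axes, so one cannot be reinterpreted as the other without requantizing from higher precision.

```python
# Shapes only; no Transformer Engine APIs involved.
rows, cols, block = 4096, 8192, 32  # MXFP8 shares one E8M0 scale per 32 elements

rowwise_scale_shape = (rows, cols // block)     # scales computed along each row
columnwise_scale_shape = (rows // block, cols)  # scales computed along each column

assert rowwise_scale_shape != columnwise_scale_shape
```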

# Get the communication stream from the dgrad GEMM to use for the AG
dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()

# This object is separate from the ub_obj_wgrad object which is passed to the GEMM
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")

ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)

# Get the communication stream from the dgrad GEMM and set it as the current torch stream
dgrad_comm_stream = ub_obj_dgrad.get_communication_stream()
with torch.cuda.stream(dgrad_comm_stream):
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather
# This ensures that we don't start until all communication for the dgrad GEMM is complete
grad_output, mxfp8_grad_output_work = gather_along_first_dim(
# We use the send stream to copy into the userbuffers buffer.
# This is the same stream that we will use to access the data in the AG,
# so we don't need to add any syncs yet.
with torch.cuda.stream(dgrad_send_stream):
grad_output, _ = fill_userbuffers_buffer_for_all_gather(
Review thread on this line:

Contributor Author: I would like to do the quantization here fused with the rowwise quantization above. I couldn't figure out a nice (quick) way of doing this though, so I might come back to it.
Simply quantizing grad_outputs[0] at the top of the function before using it in dgrad didn't quite work because it hit a bunch of asserts. This is hopefully something we can come back to.
As it is, the quantization and NCCL all-gather are also overlapped with the GEMM by virtue of running on dgrad_send_stream here.

Collaborator: This should be easily possible.
Toward the top of the backward pass, we set grad_output_quantizer.set_usage(rowwise=True, columnwise=True) before TransformerEngineBaseModule.grad_output_preprocess(), but then flip columnwise=False if we see that ub_overlap_ag is enabled.
This conditional ub_overlap_ag change to the quantizer usage needs to be shifted to after grad_output_preprocess() in order to ensure that the rowwise and columnwise quantization happen at the same time in the beginning of the backward pass.
I don't believe any further changes are necessary because _fill_userbuffers_buffer_for_all_gather() already avoids re-quantization if it sees that the tensor has been quantized ahead of time. It will simply copy the columnwise-quantized gradient into the communication buffer based on the quantizer usage at the time.

Contributor Author: Unfortunately the code in the preprocess doesn't support quantized tensors (the reshape() operation at the top doesn't exist), and fill_userbuffers_buffer_for_all_gather (called in the preprocess) doesn't support both usages being set.
These should be fixable, but I had some weird errors so I left it as it was for now.

Collaborator (@denera, Aug 6, 2025): What goes into preprocess is always plain Torch tensors, but what comes out is a QuantizedTensor. That's as expected. So you do not want to explicitly invoke quantization here. That's what preprocess is supposed to do for you based on the usage information you set into the quantizer object.
And on that note, if the grad_output_quantizer usage is set to rowwise=True and columnwise=True before the preprocess, and everything else is passed in as usual, the preprocess should produce a single QuantizedTensor gradient that has both rowwise and columnwise quantizations in it.
At that point, grad_output_quantizer usage can be updated again to rowwise-only for the DGRAD AG->GEMM and then columnwise-only right before invoking _fill_userbuffers_for_all_gather() to account for the fact that only one usage can be enabled at a time in either case. This usage information will allow both the GEMM and the UB fill to pluck out the correct quantized data and scale from the same QuantizedTensor.
I will test this later tonight if I get the chance, but I'd be surprised if this doesn't work out of the box. The code, in my reading, appears to account for this kind of use case already.

Collaborator: Fusing the casts is messy since it involves outputting into multiple UB buffers. Perhaps we could modify fill_userbuffers_buffer_for_all_gather to have separate UB comms for the row-wise and column-wise data. Alternatively, we could separate the logic for constructing the quantized tensor and doing the quantize:

grad_output = make_tensor_with_userbuffers_buffers(grad_outputs[0].size(), rowwise_comm=ub_obj_dgrad, columnwise_comm=ub_obj_wgrad)
grad_output.copy_(grad_outputs[0])
ub_obj_overlap_wgrad,
grad_outputs[0],
ctx.grad_output_quantizer,
ctx.tp_group,
async_op=True,
quantizer=ctx.grad_output_quantizer,
)
# Synchronize with the main stream
mxfp8_grad_output_work.wait()

# Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm
tex.bulk_overlap_ag_with_external_gemm(
ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream
)

# Prepare input tensor
# Note: Synchronize tensor-parallel communication and