Ubnext #2038

Open · wants to merge 10 commits into base: main

Changes from all commits
311 changes: 311 additions & 0 deletions tests/pytorch/distributed/test_linear_comms.py
@@ -0,0 +1,311 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import torch.distributed._symmetric_memory as symm_mem
import time
import argparse
import os
import uuid
import math


def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(
        description=(
            "Run a linear layer with Transformer Engine, CUDA Graphs, and Tensor Parallelism"
        )
    )
    parser.add_argument("--in_features", type=int, default=8192, help="Input feature size")
    parser.add_argument("--out_features", type=int, default=8192, help="Output feature size")
    parser.add_argument("--batch_size", type=int, default=2048, help="Batch size")
    parser.add_argument(
        "--cuda_graph", action="store_true", help="Use CUDA Graphs (pass this flag to enable)"
    )
    parser.add_argument("--validate", action="store_true", help="Validate allreduce ubnext")
    parser.add_argument("--comm_type", type=str, default="sym", help="Comm type: nccl, sym, ub")
    parser.add_argument(
        "--sym_type",
        type=str,
        default="multimem_all_reduce",
        help="Sym type: one_shot, two_shot, multimem_all_reduce, ub_custom",
    )
    parser.add_argument("--iterations", type=int, default=1000, help="Number of iterations")
    parser.add_argument(
        "--tp_size",
        type=int,
        default=None,
        help="Tensor parallelism size (defaults to number of GPUs)",
    )
    args = parser.parse_args()

    # Check CUDA availability and get device count
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. Test requires NVIDIA GPUs.")

    num_devices = torch.cuda.device_count()
    if num_devices == 0:
        raise RuntimeError("No CUDA devices found.")

    # Set tensor parallelism size
    tp_size = (
        args.tp_size if args.tp_size is not None else int(os.environ.get("WORLD_SIZE", num_devices))
    )

    # Read the distributed environment (set by torchrun) for this process
    myrank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_size = int(os.environ.get("LOCAL_WORLD_SIZE", str(torch.cuda.device_count())))
    num_nodes = world_size // local_size
    if num_nodes > 1:
        assert (
            "MASTER_ADDR" in os.environ
        ), "Multi-node run requires MASTER_ADDR to be set in the environment."
    # Set device
    device = torch.device(f"cuda:{local_rank}")
    # Initialize torch.distributed for tensor parallelism.
    # Only set defaults if not already set by torchrun.
    if "MASTER_ADDR" not in os.environ:
        os.environ["MASTER_ADDR"] = "localhost"
    if "MASTER_PORT" not in os.environ:
        os.environ["MASTER_PORT"] = "29500"
    torch.cuda.set_device(device)

    torch.distributed.init_process_group(
        backend="nccl", world_size=tp_size, rank=myrank % tp_size, device_id=device
    )
    torch.distributed.barrier(group=torch.distributed.group.WORLD)
    # Transformer Engine handles tensor parallelism internally once distributed
    # is initialized; expose the tensor-parallel size via environment variable.
    os.environ["NVTE_TP_SIZE"] = str(tp_size)

    # Userbuffers overlap config for the "proj" GEMM + reduce-scatter path
    # (used when --comm_type ub)
    ub_cfgs = {
        "proj_fprop": {
            "method": "pipeline",
            "num_splits": 1,
            "is_reduce_scatter": True,
            "num_sm": 32,
            "atomic_gemm": False,
            "aggregate": False,
            "cga_size": 4,
            "set_sm_margin": False,
            "fp8_buf": False,
            "use_ce": False,
        }
    }

    # Baseline model in BF16 precision: the local GEMM shard only, with no
    # tensor-parallel communication
    modelseq = te.Linear(
        in_features=int(args.in_features / tp_size),
        out_features=args.out_features,
        bias=False,
        device=device,
        params_dtype=torch.bfloat16,
    )

    if (
        args.comm_type == "sym" and os.environ.get("NVTE_USE_UB_FOR_UBNEXT")
    ) or args.comm_type == "ub":
        te.module.base.initialize_ub(
            [args.batch_size, args.out_features],
            tp_size,
            use_fp8=False,
            dtype=torch.bfloat16,
            bootstrap_backend="nccl",
            ub_cfgs=ub_cfgs,
        )

    modelpar = None

    if args.comm_type == "sym" or args.comm_type == "nccl":
        modelpar = te.Linear(
            in_features=args.in_features,
            out_features=args.out_features,
            bias=False,
            device=device,
            params_dtype=torch.bfloat16,
            tp_size=tp_size,
            parallel_mode="row",
            tp_group=torch.distributed.group.WORLD,
            symmetric_ar_type=args.sym_type if args.comm_type == "sym" else None,
        )

    if args.comm_type == "ub":
        modelpar = te.Linear(
            in_features=args.in_features,
            out_features=args.out_features,
            bias=False,
            device=device,
            params_dtype=torch.bfloat16,
            tp_size=tp_size,
            parallel_mode="row",
            tp_group=torch.distributed.group.WORLD,
            sequence_parallel=True,
            ub_overlap_rs=True,
            ub_name="proj",
        )

    # Create CUDA stream
    stream = torch.cuda.Stream()

    # Symmetric-memory pool for validation; the NVTE_UB_SYMM_POOL_SIZE
    # environment variable (in MiB) overrides the default 64 MiB pool size
    allocator = None
    if args.comm_type == "sym" and args.validate:
        pool_size = int(os.environ.get("NVTE_UB_SYMM_POOL_SIZE", 64)) * 1024 * 1024
        allocator = te.cpp_extensions.symm_allocator.SymmAllocator(
            pool_size, torch.device(device), torch.distributed.group.WORLD
        )

    # Run tensor comparison tests only for symmetric communication
    if args.comm_type == "sym" and args.validate:

        # Test different tensor sizes from 64 to 1024*1024 elements
        all_max_deltas = []
        all_num_different = []
        all_total_elements = []
        all_sizes = []

        size = 64
        while size <= 1024 * 1024:
            # Allocate tensors
            t = allocator.create_tensor((size,), dtype=torch.bfloat16)
            t.fill_(0)
            t += torch.randn_like(t)  # Add random noise to each element
            tmain = t.clone()  # Create a copy since allreduce operates in-place
            torch.distributed.all_reduce(tmain)
            tlamport = allocator.allreduce_lamport(t)

            # Compare the two tensors
            abs_diff = torch.abs(tlamport - tmain)
            max_delta = torch.max(abs_diff).item()
            num_different = torch.sum(tlamport != tmain).item()

            # Store statistics
            all_max_deltas.append(max_delta)
            all_num_different.append(num_different)
            all_total_elements.append(tlamport.numel())
            all_sizes.append(size)

            # Free tensors (memory returned to pool)
            del t, tlamport, tmain, abs_diff

            # Double the size for the next iteration
            size *= 2

        # Print summary statistics
        if myrank == 0:
            print("\n=== Tensor Comparison Summary ===")
            total_elements_tested = sum(all_total_elements)
            total_different_elements = sum(all_num_different)
            overall_max_delta = max(all_max_deltas)

            print(
                f"Tested sizes: {len(all_sizes)} different tensor sizes from {all_sizes[0]} to"
                f" {all_sizes[-1]} elements"
            )
            print(f"Total elements tested: {total_elements_tested}")
            print(f"Total different elements: {total_different_elements}")
            print(
                "Overall error rate:"
                f" {(total_different_elements / total_elements_tested) * 100:.6f}%"
            )
            print(f"Maximum delta across all tests: {overall_max_delta}")

            if total_different_elements > 0 or overall_max_delta > 0:
                print("\nPer-size breakdown:")
                for i, size in enumerate(all_sizes):
                    error_rate = (all_num_different[i] / all_total_elements[i]) * 100
                    print(
                        f"  Size {size:7d}:"
                        f" {all_num_different[i]:6d}/{all_total_elements[i]:7d} different"
                        f" ({error_rate:6.3f}%), max_delta: {all_max_deltas[i]:.6f}"
                    )
            print("================================\n")

    torch.distributed.barrier(group=torch.distributed.group.WORLD)
    torch.cuda.synchronize()

    for logbatch in range(int(math.log2(args.batch_size)) + 1):
        batch = 2**logbatch
        if args.comm_type == "ub" and batch < tp_size:
            batch = tp_size
        # Create input tensor
        inp = torch.randn(
            batch, int(args.in_features / tp_size), device=device, dtype=torch.bfloat16
        )
        # Warm-up run
        modelseq(inp)
        modelpar(inp)
        torch.cuda.synchronize()
        if args.cuda_graph:
            with torch.cuda.stream(stream):
                # Capture CUDA Graphs for both models
                gseq = torch.cuda.CUDAGraph()
                gpar = torch.cuda.CUDAGraph()
                with torch.cuda.graph(gseq):
                    output = modelseq(inp)
                with torch.cuda.graph(gpar):
                    output = modelpar(inp)
                # Warm up the graphs
                for _ in range(5):
                    gseq.replay()
                    gpar.replay()
            torch.cuda.synchronize()

        torch.distributed.barrier(group=torch.distributed.group.WORLD)
        torch.distributed.barrier(group=torch.distributed.group.WORLD)
        torch.cuda.synchronize()

        # Measure time for the baseline (no-communication) forward passes
        start_time = time.time()
        with torch.cuda.stream(stream):
            for _ in range(args.iterations):
                if args.cuda_graph:
                    gseq.replay()
                else:
                    modelseq(inp)

        torch.cuda.synchronize()
        end_time = time.time()
        seq_elapsed = end_time - start_time

        torch.distributed.barrier(group=torch.distributed.group.WORLD)
        torch.distributed.barrier(group=torch.distributed.group.WORLD)
        torch.cuda.synchronize()

        # Measure time for the tensor-parallel forward passes
        start_time = time.time()
        with torch.cuda.stream(stream):
            for _ in range(args.iterations):
                if args.cuda_graph:
                    gpar.replay()
                else:
                    modelpar(inp)

        torch.cuda.synchronize()
        end_time = time.time()
        par_elapsed = end_time - start_time
        # The parallel/baseline difference approximates the allreduce overhead
        nccl_elapsed = par_elapsed - seq_elapsed
        # Print per-iteration times (baseline, parallel, comm) on rank 0 only
        if myrank == 0:
            print(
                f"Batch{batch},{(seq_elapsed / args.iterations) * 1e6:.4f}us,"
                f"{(par_elapsed / args.iterations) * 1e6:.4f}us,"
                f"{(nccl_elapsed / args.iterations) * 1e6:.4f}us"
            )
        if args.cuda_graph:
            # Free the graphs here, otherwise NCCL hangs at shutdown
            del gseq, gpar

    # Cleanup
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    # Generate a unique run ID for distributed initialization
    os.environ["RUN_ID"] = str(uuid.uuid4())
    main()
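The script reads RANK, LOCAL_RANK, and WORLD_SIZE from the environment, so torchrun is the natural launcher. A typical single-node invocation, as a sketch (flags taken from the argparse definitions above; the GPU count of 8 is an assumption):

torchrun --nproc_per_node=8 tests/pytorch/distributed/test_linear_comms.py \
    --comm_type sym --sym_type multimem_all_reduce --validate --cuda_graph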
3 changes: 2 additions & 1 deletion transformer_engine/common/CMakeLists.txt
@@ -109,7 +109,8 @@ list(APPEND transformer_engine_SOURCES
     comm_gemm_overlap/userbuffers/ipcsocket.cc
     comm_gemm_overlap/userbuffers/userbuffers-host.cpp
     comm_gemm_overlap/userbuffers/userbuffers.cu
-    comm_gemm_overlap/comm_gemm_overlap.cpp)
+    comm_gemm_overlap/comm_gemm_overlap.cpp
+    ubnext.cu)
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC
                           "${CMAKE_CURRENT_SOURCE_DIR}/include")
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
@@ -462,7 +462,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
  _ub_comm->cga_size = _cga_size;
  size_t m = transa ? A.size(0) : A.size(1);
  size_t k = transa ? A.size(1) : A.size(0);
-  size_t n = _ubuf.size(0);
+  size_t n = B.size(0);
  size_t m_chunk = m / _num_splits;
  const std::vector<size_t> input_a_chunk_shape =
      (transa ? std::vector<size_t>{m_chunk, k} : std::vector<size_t>{k, m_chunk});
@@ -591,6 +591,15 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
}  // CommOverlapBase::split_overlap_rs

// Initialize the ubnext buffer: zero it, copy the table of nvsize per-peer
// pointers for this registration (staged at a fixed offset in buffer 0) to its
// start, and return the multicast pointer consumed by the allreduce kernels.
uintptr_t CommOverlapBase::init_ubnext() {
  NVTE_CHECK_CUDA(cudaMemset(_ub_comm->mem_ptr[_ub_reg], 0, _ub_comm->mem_size[_ub_reg]));
  NVTE_CHECK_CUDA(cudaMemcpy(_ub_comm->mem_ptr[_ub_reg],
                             (reinterpret_cast<char *>(_ub_comm->mem_ptr[0])) +
                                 (_ub_reg * _ub_comm->nvsize * sizeof(void *)),
                             _ub_comm->nvsize * sizeof(void *), cudaMemcpyDeviceToDevice));
  return (uintptr_t)(_ub_comm->mc_ptr[_ub_reg]);
}

/***************************************************************************************************
 * Comm+GEMM Overlap P2P Base (Ring-Exchange)
 **************************************************************************************************/
transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h
@@ -198,6 +198,8 @@ class CommOverlapBase : public CommOverlapCore {
                        TensorWrapper &workspace, bool grad, bool accumulate,
                        bool use_split_accumulator, TensorWrapper &rs_output,
                        cudaStream_t stream_main) override;
  // initialize ubnext buffer and return multicast pointer for allreduce
  uintptr_t init_ubnext();
};  // CommOverlapBase

class CommOverlapP2PBase : public CommOverlapCore {
31 changes: 31 additions & 0 deletions transformer_engine/common/include/transformer_engine/ubnext.h
@@ -0,0 +1,31 @@
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_UBNEXT_H_
#define TRANSFORMER_ENGINE_UBNEXT_H_

#include "transformer_engine.h"

namespace transformer_engine {

#ifdef __cplusplus
extern "C" {
#endif

void allreduce_2shot_mc(int ranks, int myrank, void* uc0ptr, void* mc0ptr, void* mcptr_in,
                        void* mcptr_out, size_t bytes, cudaStream_t stream);
void allreduce_2shot_mc_lamport(int ranks, int myrank, void* uc0ptr, void* mc0ptr, void* ucptr_out,
                                void* mcptr_in, void* mcptr_out, void* clear_ptr, size_t bytes,
                                bool poisoned, cudaStream_t stream);
void allreduce_2shot_uc(int ranks, int myrank, void* uc0ptr, void* ucptr_in, void* ucptr_out,
                        size_t bytes, cudaStream_t stream);

#ifdef __cplusplus
}
#endif
} // namespace transformer_engine

#endif
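The three entry points share a signature style: communicator geometry (ranks, myrank), unicast/multicast views of the symmetric buffer, a byte count, and a stream. For orientation, here is a minimal sketch of the two-shot allreduce pattern (reduce-scatter, then all-gather) that these kernels fuse into a single launch. It is written against plain torch.distributed collectives, so the function name and structure are illustrative only, not this API:

import torch
import torch.distributed as dist


def two_shot_allreduce(t: torch.Tensor) -> torch.Tensor:
    # Illustrative only: the ubnext kernels do both shots in one launch over
    # multicast memory; ordinary collectives are used here to show the shape.
    world = dist.get_world_size()
    rank = dist.get_rank()
    assert t.numel() % world == 0, "sketch assumes evenly divisible tensors"
    shards = list(t.chunk(world))
    # Shot 1: reduce-scatter -- this rank ends up with the sum of shard `rank`.
    mine = torch.empty_like(shards[rank])
    dist.reduce_scatter(mine, shards)
    # Shot 2: all-gather -- collect every reduced shard back into a full tensor.
    out = [torch.empty_like(mine) for _ in range(world)]
    dist.all_gather(out, mine)
    return torch.cat(out).view_as(t)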
3 changes: 2 additions & 1 deletion transformer_engine/common/libtransformer_engine.version
@@ -19,7 +19,8 @@
    *transformer_engine::CommOverlapP2PBase*;
    *transformer_engine::CommOverlapCore*;
    *nvshmem_wait_on_stream*;
-    *nvshmemi_init_thread*
+    *nvshmemi_init_thread*;
+    allreduce_*;
  };
  local: *;
};
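Exporting allreduce_* here, together with the extern "C" declarations in ubnext.h, makes the ubnext launchers resolvable from the shared library at runtime. As a hedged illustration (the library filename and its presence on the loader path are assumptions), the exports can be checked from Python with ctypes; a real call would additionally need device pointers and a CUDA stream:

import ctypes

# Assumed library name/path; adjust for your install.
lib = ctypes.CDLL("libtransformer_engine.so")
for sym in ("allreduce_2shot_mc", "allreduce_2shot_mc_lamport", "allreduce_2shot_uc"):
    assert hasattr(lib, sym), f"{sym} is not exported"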