From 2b1fad217b3e287cb9369f90f033d87543d27fe7 Mon Sep 17 00:00:00 2001 From: Anton Korzh Date: Wed, 6 Aug 2025 13:47:50 -0700 Subject: [PATCH 1/9] UBNext Allreduce integration Signed-off-by: Anton Korzh --- .../pytorch/distributed/test_linear_comms.py | 265 ++++++++++++ transformer_engine/common/CMakeLists.txt | 3 +- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 11 +- .../transformer_engine/comm_gemm_overlap.h | 2 + .../include/transformer_engine/ubnext.h | 28 ++ .../common/libtransformer_engine.version | 3 +- transformer_engine/common/ubnext.cu | 398 ++++++++++++++++++ .../common/util/pybind_helper.h | 34 +- .../pytorch/cpp_extensions/__init__.py | 1 + .../pytorch/cpp_extensions/symm_allocator.py | 289 +++++++++++++ transformer_engine/pytorch/module/base.py | 9 +- transformer_engine/pytorch/module/linear.py | 16 +- 12 files changed, 1050 insertions(+), 9 deletions(-) create mode 100644 tests/pytorch/distributed/test_linear_comms.py create mode 100644 transformer_engine/common/include/transformer_engine/ubnext.h create mode 100644 transformer_engine/common/ubnext.cu create mode 100644 transformer_engine/pytorch/cpp_extensions/symm_allocator.py diff --git a/tests/pytorch/distributed/test_linear_comms.py b/tests/pytorch/distributed/test_linear_comms.py new file mode 100644 index 0000000000..414d6c200b --- /dev/null +++ b/tests/pytorch/distributed/test_linear_comms.py @@ -0,0 +1,265 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common import recipe +import torch.distributed._symmetric_memory as symm_mem +import time +import argparse +import os +import uuid +import math + +def main(): + # Parse command-line arguments + parser = argparse.ArgumentParser(description="Run a linear layer with Transformer Engine, CUDA Graphs, and Tensor Parallelism") + parser.add_argument('--in_features', type=int, default=8192, help='Input feature size') + parser.add_argument('--out_features', type=int, default=8192, help='Output feature size') + parser.add_argument('--batch_size', type=int, default=2048, help='Batch size') + parser.add_argument('--cuda_graph', action='store_true', help='Use CUDA Graphs (pass this flag to enable)') + parser.add_argument('--validate', action='store_true', help='Validate allreduce ubnext') + parser.add_argument('--comm_type', type=str, default="sym", help='Comm type: nccl,sym,ub') + parser.add_argument('--sym_type', type=str, default="multimem_all_reduce", help='sym type: one_shot, two_shot, multimem_all_reduce, ub_custom') + parser.add_argument('--iterations', type=int, default=1000, help='Number of iterations') + parser.add_argument('--tp_size', type=int, default=None, help='Tensor parallelism size (defaults to number of GPUs)') + args = parser.parse_args() + + # Check CUDA availability and get device count + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. 
Test requires NVIDIA GPUs.") + + num_devices = torch.cuda.device_count() + if num_devices == 0: + raise RuntimeError("No CUDA devices found.") + + # Set tensor parallelism size + tp_size = args.tp_size if args.tp_size is not None else int(os.environ.get('WORLD_SIZE', num_devices)) + + # Initialize distributed environment for each GPU + myrank = int(os.environ.get('RANK', 0)) + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + world_size = int(os.environ.get('WORLD_SIZE', 1)) + local_size = int(os.environ.get('LOCAL_WORLD_SIZE', str(torch.cuda.device_count()))) + num_nodes = world_size // local_size + if num_nodes > 1: + assert ("MASTER_ADDR" in os.environ + ), "Multi-node run requires MASTER_ADDR to be set in the environment." + # Set device + device = torch.device(f"cuda:{local_rank}") + # Initialize torch.distributed for tensor parallelism + # Only set defaults if not already set by torchrun + if 'MASTER_ADDR' not in os.environ: + os.environ['MASTER_ADDR'] = 'localhost' + if 'MASTER_PORT' not in os.environ: + os.environ['MASTER_PORT'] = '29500' + torch.cuda.set_device(device) + + torch.distributed.init_process_group( + backend='nccl', + world_size=tp_size, + rank=myrank % tp_size, + device_id=device + ) + torch.distributed.barrier(group=torch.distributed.group.WORLD) + # Transformer Engine handles tensor parallelism internally when distributed is initialized + # Set environment variable for tensor parallelism size + os.environ['NVTE_TP_SIZE'] = str(tp_size) + + ub_cfgs = { + "proj_fprop": {"method": "pipeline","num_splits":1,"is_reduce_scatter":True,"num_sm":32,"atomic_gemm":False,"aggregate":False, + "cga_size":4,"set_sm_margin":False,"fp8_buf":False,"use_ce":False} + } + + # Initialize model with BF16 precision + + modelseq = te.Linear( + in_features=int(args.in_features/tp_size), + out_features=args.out_features, + bias=False, + device=device, + params_dtype=torch.bfloat16 + ) + + if (args.comm_type=='sym' and os.environ.get("NVTE_USE_UB_FOR_UBNEXT")) or args.comm_type=='ub': + te.module.base.initialize_ub( + [args.batch_size,args.out_features], + tp_size, + use_fp8=False, + dtype=torch.bfloat16, + bootstrap_backend="nccl", + ub_cfgs=ub_cfgs + ) + + modelpar = None + + if args.comm_type=='sym' or args.comm_type=='nccl' : + modelpar = te.Linear( + in_features=args.in_features, + out_features=args.out_features, + bias=False, + device=device, + params_dtype=torch.bfloat16, + tp_size = tp_size, + parallel_mode="row", + tp_group=torch.distributed.group.WORLD, + symmetric_ar_type=args.sym_type if args.comm_type=='sym' else None + ) + + if(args.comm_type=='ub'): + modelpar = te.Linear( + in_features=args.in_features, + out_features=args.out_features, + bias=False, + device=device, + params_dtype=torch.bfloat16, + tp_size = tp_size, + parallel_mode="row", + tp_group=torch.distributed.group.WORLD, + sequence_parallel=True, + ub_overlap_rs=True, + ub_name="proj" + ) + + # Create CUDA stream + stream = torch.cuda.Stream() + # Check for environment variable to override pool size + + allocator = None + if args.comm_type == "sym" and args.validate: + pool_size = int(os.environ.get('NVTE_UB_SYMM_POOL_SIZE', 64)) * 1024 * 1024 + allocator = te.cpp_extensions.symm_allocator.SymmAllocator(pool_size, torch.device(device),torch.distributed.group.WORLD) + + # Run tensor comparison tests only for symmetric communication + if args.comm_type == "sym" and args.validate: + + # Test different tensor sizes from 64 to 1024*1024 elements + all_max_deltas = [] + all_num_different = [] + all_total_elements = [] + 
all_sizes = [] + + size = 64 + while size <= 1024 * 1024: + # Allocate tensors + t = allocator.create_tensor((size,), dtype=torch.bfloat16) + t.fill_(0) + t += torch.randn_like(t) # Add random noise to each element + tmain = t.clone() # Create a copy since allreduce operates in-place + torch.distributed.all_reduce(tmain) + tlamport = allocator.allreduce_lamport(t) + + # Compare the two tensors + abs_diff = torch.abs(tlamport - tmain) + max_delta = torch.max(abs_diff).item() + num_different = torch.sum(tlamport != tmain).item() + + # Store statistics + all_max_deltas.append(max_delta) + all_num_different.append(num_different) + all_total_elements.append(tlamport.numel()) + all_sizes.append(size) + + # Free tensor (memory returned to pool) + del t, tlamport, tmain, abs_diff + + # Double the size for next iteration + size *= 2 + + # Print summary statistics + if myrank == 0: + print("\n=== Tensor Comparison Summary ===") + total_elements_tested = sum(all_total_elements) + total_different_elements = sum(all_num_different) + overall_max_delta = max(all_max_deltas) + + print(f"Tested sizes: {len(all_sizes)} different tensor sizes from {all_sizes[0]} to {all_sizes[-1]} elements") + print(f"Total elements tested: {total_elements_tested}") + print(f"Total different elements: {total_different_elements}") + print(f"Overall error rate: {(total_different_elements / total_elements_tested) * 100:.6f}%") + print(f"Maximum delta across all tests: {overall_max_delta}") + + if total_different_elements > 0 or overall_max_delta > 0: + print("\nPer-size breakdown:") + for i, size in enumerate(all_sizes): + error_rate = (all_num_different[i] / all_total_elements[i]) * 100 + print(f" Size {size:7d}: {all_num_different[i]:6d}/{all_total_elements[i]:7d} different ({error_rate:6.3f}%), max_delta: {all_max_deltas[i]:.6f}") + print("================================\n") + + + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.cuda.synchronize() + + for logbatch in range(int(math.log2(args.batch_size)) + 1): + batch = 2**logbatch + if args.comm_type=='ub' and batch < tp_size: batch = tp_size + # Create input tensor + inp = torch.randn(batch, int(args.in_features/tp_size), device=device, dtype=torch.bfloat16) + # Warm-up run + modelseq(inp) + modelpar(inp) + torch.cuda.synchronize() + if args.cuda_graph: + with torch.cuda.stream(stream): + # Create CUDA Graph + gseq = torch.cuda.CUDAGraph() + gpar = torch.cuda.CUDAGraph() + with torch.cuda.graph(gseq): + output = modelseq(inp) + with torch.cuda.graph(gpar): + output = modelpar(inp) + # Warm-up the graph + for _ in range(5): + gseq.replay() + gpar.replay() + torch.cuda.synchronize() + + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.cuda.synchronize() + + # Measure time for forward passes + start_time = time.time() + with torch.cuda.stream(stream): + for _ in range(args.iterations): + if args.cuda_graph: + gseq.replay() + else: + modelseq(inp) + + torch.cuda.synchronize() + end_time = time.time() + seq_elapsed = end_time - start_time + + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.distributed.barrier(group=torch.distributed.group.WORLD) + torch.cuda.synchronize() + + # Measure time for forward passes + start_time = time.time() + with torch.cuda.stream(stream): + for _ in range(args.iterations): + if args.cuda_graph: + gpar.replay() + else: + modelpar(inp) + + torch.cuda.synchronize() + end_time = time.time() + par_elapsed = end_time - 
start_time + nccl_elapsed = (par_elapsed-seq_elapsed) + # Calculate and print elapsed time (only on rank 0) + if myrank == 0: + print(f"Batch{batch},{(seq_elapsed/ args.iterations) * 1e6:.4f}us,{(par_elapsed/ args.iterations) * 1e6:.4f} us,{(nccl_elapsed/ args.iterations) * 1e6:.4f}") + if args.cuda_graph: + # needed or NCCL would hang + del gseq, gpar + + # Cleanup + torch.distributed.destroy_process_group() + +if __name__ == "__main__": + # Generate a unique run ID for distributed initialization + os.environ['RUN_ID'] = str(uuid.uuid4()) + main() diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index b51e61929b..6d4538d4eb 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -109,7 +109,8 @@ list(APPEND transformer_engine_SOURCES comm_gemm_overlap/userbuffers/ipcsocket.cc comm_gemm_overlap/userbuffers/userbuffers-host.cpp comm_gemm_overlap/userbuffers/userbuffers.cu - comm_gemm_overlap/comm_gemm_overlap.cpp) + comm_gemm_overlap/comm_gemm_overlap.cpp + ubnext.cu) add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 38a6e3e61d..2c054201a6 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -462,7 +462,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons _ub_comm->cga_size = _cga_size; size_t m = transa ? A.size(0) : A.size(1); size_t k = transa ? A.size(1) : A.size(0); - size_t n = _ubuf.size(0); + size_t n = B.size(0); size_t m_chunk = m / _num_splits; const std::vector input_a_chunk_shape = (transa ? 
std::vector{m_chunk, k} : std::vector{k, m_chunk}); @@ -591,6 +591,15 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); } // CommOverlapBase::split_overlap_rs +uintptr_t CommOverlapBase::init_ubnext() { + NVTE_CHECK_CUDA(cudaMemset(_ub_comm->mem_ptr[_ub_reg], 0, _ub_comm->mem_size[_ub_reg])); + NVTE_CHECK_CUDA(cudaMemcpy(_ub_comm->mem_ptr[_ub_reg], + (reinterpret_cast(_ub_comm->mem_ptr[0])) + + (_ub_reg * _ub_comm->nvsize * sizeof(void *)), + _ub_comm->nvsize * sizeof(void *), cudaMemcpyDeviceToDevice)); + return (uintptr_t)(_ub_comm->mc_ptr[_ub_reg]); +} + /*************************************************************************************************** * Comm+GEMM Overlap P2P Base (Ring-Exchange) **************************************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 293c57526d..e1120e8b10 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -198,6 +198,8 @@ class CommOverlapBase : public CommOverlapCore { TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) override; + // initialize ubnext buffer and return multicast pointer for allreduce + uintptr_t init_ubnext(); }; // CommOverlapBase class CommOverlapP2PBase : public CommOverlapCore { diff --git a/transformer_engine/common/include/transformer_engine/ubnext.h b/transformer_engine/common/include/transformer_engine/ubnext.h new file mode 100644 index 0000000000..d0d2b8c2ac --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/ubnext.h @@ -0,0 +1,28 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + + #ifndef TRANSFORMER_ENGINE_UBNEXT_H_ + #define TRANSFORMER_ENGINE_UBNEXT_H_ + + #include "transformer_engine.h" + + namespace transformer_engine { + + #ifdef __cplusplus + extern "C" { + #endif + + void allreduce_2shot_mc(int ranks, int myrank, void* uc0ptr,void* mc0ptr, void* mcptr_in,void* mcptr_out, size_t bytes,cudaStream_t stream); + void allreduce_2shot_mc_lamport(int ranks, int myrank,void* uc0ptr,void* mc0ptr,void* ucptr_out, void* mcptr_in, + void* mcptr_out,void* clear_ptr,size_t bytes,bool poisoned,cudaStream_t stream); + void allreduce_2shot_uc(int ranks, int myrank,void* uc0ptr,void* ucptr_in, void* ucptr_out,size_t bytes,cudaStream_t stream); + + #ifdef __cplusplus + } + #endif +} + +#endif \ No newline at end of file diff --git a/transformer_engine/common/libtransformer_engine.version b/transformer_engine/common/libtransformer_engine.version index 706c237ccc..68e5f8aef8 100644 --- a/transformer_engine/common/libtransformer_engine.version +++ b/transformer_engine/common/libtransformer_engine.version @@ -19,7 +19,8 @@ *transformer_engine::CommOverlapP2PBase*; *transformer_engine::CommOverlapCore*; *nvshmem_wait_on_stream*; - *nvshmemi_init_thread* + *nvshmemi_init_thread*; + allreduce_*; }; local: *; }; diff --git a/transformer_engine/common/ubnext.cu b/transformer_engine/common/ubnext.cu new file mode 100644 index 0000000000..15b80f51e4 --- /dev/null +++ b/transformer_engine/common/ubnext.cu @@ -0,0 +1,398 @@ +#include +#include +#include + +#include "./common.h" + +#define TIMEOUT 2000000000ull +//#define UB_TIMEOUT_ENABLED 1 + +#define NVTE_UB_MAXTHREADS 1024 +#define NVTE_UB_MAX_SMS 128 +#define NVTE_UB_LAMPORT_INT 0xFFFAFFFA + +//REG0 flags in use +#define NVTE_UB_FLAG_NVLS2_LAMPORT_ID 0 +#define NVTE_UB_FLAG_NVLS2_LAMPORT_SM_SYNC 1 +#define NVTE_UB_FLAG_NVLS2_LAMPORT_RS_BAR 2 +#define NVTE_UB_FLAG_NVLS2_ID 3 +#define NVTE_UB_FLAG_NVLS2_SM_SYNC 4 +#define NVTE_UB_FLAG_NVLS2_RS_BAR 5 +#define NVTE_UB_FLAG_NVLS2_AG_BAR 6 + +#define xhalf __nv_bfloat16 + +#define ATOMIC_MCINC(ptr) \ + asm volatile("multimem.red.add.u32 [%0], %1;" ::"l"(ptr), "r"(1) \ + : "memor" \ + "y"); +#define ATOMIC_UCINC(ptr) \ + asm volatile("red.global.add.u32 [%0], %1;" ::"l"(ptr), "r"(1) \ + : "memor" \ + "y"); +#define MULTIMEM_ST(val, ptr) \ + asm volatile("multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), \ + "r"(val.y), "r"(val.z), "r"(val.w) \ + : "memory"); + +#define MULTIMEM_LD(val, ptr) \ + asm("multimem.ld_reduce.global.add.v4.bf16x2.acc::f32 {%0,%1,%2,%3}, [%4];" \ + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) \ + : "l"(ptr) \ + : "memory"); + +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +// Return true if producer > consumer, otherwise false while preventing integer overflow +// If we expect that producer will be 2B+ messages behind consumer +#define CHECK_IDS(producer, consumer) (((unsigned)(producer) - (unsigned)(consumer)) & (~INT_MAX)) + +__global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) + userbuffers_fp16_sum_inplace_gpu_mc(const int RANKS, const int myrank, const int numlines, + int *uc_flagptr, int *mc_flagptr, float4 *mc_ptr_in, + float4 *mc_ptr_out) { + // flags[3,4,5,6]: reduce_id, sm_sync-local, flag-barrier-1,flag-barrier-2 + int reduce_id; + + if (threadIdx.x == 0) { + 
cudaGridDependencySynchronize(); + if (blockIdx.x == 0) ATOMIC_MCINC(mc_flagptr + NVTE_UB_FLAG_NVLS2_RS_BAR); + + reduce_id = uc_flagptr[NVTE_UB_FLAG_NVLS2_ID] + 1; + + volatile int *flag = (volatile int *)&(uc_flagptr[NVTE_UB_FLAG_NVLS2_RS_BAR]); + + const int expected = reduce_id * RANKS; + +#ifdef UB_TIMEOUT_ENABLED + clock_t s = clock64(); +#endif + while (CHECK_IDS(*flag, expected)) { +#ifdef UB_TIMEOUT_ENABLED + if (clock64() - s > TIMEOUT) { + printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, expected, + *flag); + break; + } +#endif + } + } + + __syncthreads(); +#define UNROLL_MC 4 + const int loop_step0 = blockDim.x * gridDim.x * RANKS; + const int loop_step = loop_step0 * UNROLL_MC; + const int start_elem = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); + const int end_elem = max(start_elem, numlines); + const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step; + const int end_aligned = start_elem + aligned_elem; + + for (int line = start_elem; line < end_aligned; line += loop_step) { + uint4 val[UNROLL_MC]; +#pragma unroll + for (int i = 0; i < UNROLL_MC; i++) MULTIMEM_LD(val[i], mc_ptr_in + (line + i * loop_step0)) +#pragma unroll + for (int i = 0; i < UNROLL_MC; i++) MULTIMEM_ST(val[i], mc_ptr_out + (line + i * loop_step0)) + } + for (int line = end_aligned; line < end_elem; line += loop_step0) { + uint4 val; + MULTIMEM_LD(val, mc_ptr_in + (line)) + MULTIMEM_ST(val, mc_ptr_out + (line)) + } + + __syncthreads(); + if (threadIdx.x != 0) return; + + __threadfence(); + const int value_to_add = blockIdx.x == 0 ? NVTE_UB_MAX_SMS - gridDim.x + 1 : 1; + const int old_val_sm_sync = atomicAdd(uc_flagptr + NVTE_UB_FLAG_NVLS2_SM_SYNC, value_to_add); + + const int lastSM = + (gridDim.x == 1 || old_val_sm_sync + value_to_add == reduce_id * NVTE_UB_MAX_SMS); + if (!lastSM) return; + __threadfence_system(); + ATOMIC_MCINC(mc_flagptr + NVTE_UB_FLAG_NVLS2_AG_BAR); + uc_flagptr[NVTE_UB_FLAG_NVLS2_ID] = reduce_id; + cudaTriggerProgrammaticLaunchCompletion(); + volatile int *flag = (volatile int *)&(uc_flagptr[NVTE_UB_FLAG_NVLS2_AG_BAR]); + const int expected = reduce_id * RANKS; + +#ifdef UB_TIMEOUT_ENABLED + clock_t s = clock64(); +#endif + while (CHECK_IDS(*flag, expected)) { +#ifdef UB_TIMEOUT_ENABLED + if (clock64() - s > TIMEOUT) { + printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, expected, + *flag); + break; + } +#endif + } +} // fp16 inplace reduce kernel (Hopper) MC + +template +__global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) + userbuffers_fp16_sum_inplace_gpu_uc(const int myrank, const int numlines, + const int lineoffset_in, const int lineoffset_out, + int *uc_flagptr, void **commbuff) { + // flags[3,4,5,6]: reduce_id, sm_sync-local, flag-barrier-1,flag-barrier-2 + //NB! 
uc_flagptr is shifted by ranks*8 for easier flag offsets + // while lineoffset is relative to start of reg0 + __shared__ int4 *userptr[RANKS]; + __shared__ int lastSM; + int reduce_id; + + if (threadIdx.x < RANKS) { + int *rem_flagptr = (reinterpret_cast(commbuff[threadIdx.x])); + cudaGridDependencySynchronize(); + if (blockIdx.x == 0) ATOMIC_UCINC(rem_flagptr + NVTE_UB_FLAG_NVLS2_RS_BAR + RANKS * 2); + + reduce_id = uc_flagptr[NVTE_UB_FLAG_NVLS2_ID] + 1; + + userptr[threadIdx.x] = (int4 *)rem_flagptr; + } + + if (threadIdx.x == 0) { + volatile int *flag = uc_flagptr + NVTE_UB_FLAG_NVLS2_RS_BAR; + lastSM = 0; + const int expected = reduce_id * RANKS; +#ifdef UB_TIMEOUT_ENABLED + clock_t s = clock64(); +#endif + while (CHECK_IDS(*flag, expected)) { +#ifdef UB_TIMEOUT_ENABLED + if (clock64() - s > TIMEOUT) { + printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, expected, + *flag); + break; + } +#endif + } + } + + __syncthreads(); + + int warp = blockIdx.x + (threadIdx.x >> 5); + int dest[RANKS]; +#pragma unroll + for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1); + + __syncthreads(); + for (int line = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); line < numlines; + line += blockDim.x * gridDim.x * RANKS) { + int4 val[RANKS]; + +#pragma unroll + for (int i = 0; i < RANKS; i++) { + // int dest = (i+myrank+warp)&(RANKS-1); + val[i] = userptr[dest[i]][lineoffset_in + line]; + } + + int4 sum = val[0]; + xhalf *s = reinterpret_cast(&sum); + +#pragma unroll + for (int i = 1; i < RANKS; i++) { + xhalf *x = reinterpret_cast(&val[i]); +#pragma unroll + for (int j = 0; j < 8; j++) s[j] += x[j]; + } +#pragma unroll + for (int i = 0; i < RANKS; i++) { + // int dest = (i+myrank+warp)&(RANKS-1); + userptr[dest[i]][lineoffset_out + line] = sum; + } + } + + __syncthreads(); + + if (threadIdx.x == 0) { + __threadfence(); + const int value_to_add = blockIdx.x == 0 ? 
NVTE_UB_MAX_SMS - gridDim.x + 1 : 1; + const int old_val_sm_sync = atomicAdd(uc_flagptr + NVTE_UB_FLAG_NVLS2_SM_SYNC, value_to_add); + lastSM = (gridDim.x == 1 || old_val_sm_sync + value_to_add == reduce_id * NVTE_UB_MAX_SMS); + if (lastSM) uc_flagptr[NVTE_UB_FLAG_NVLS2_ID] = reduce_id; + cudaTriggerProgrammaticLaunchCompletion(); + } + if (threadIdx.x >= RANKS) return; + __syncthreads(); + if (!lastSM) return; + if (threadIdx.x == 0) __threadfence_system(); + __syncthreads(); + ATOMIC_UCINC((int *)(userptr[threadIdx.x]) + NVTE_UB_FLAG_NVLS2_AG_BAR + RANKS * 2); + if (threadIdx.x != 0) return; + volatile int *flag = uc_flagptr + NVTE_UB_FLAG_NVLS2_AG_BAR; + const int expected = reduce_id * RANKS; +#ifdef UB_TIMEOUT_ENABLED + clock_t s = clock64(); +#endif + while (CHECK_IDS(*flag, expected)) { +#ifdef UB_TIMEOUT_ENABLED + if (clock64() - s > TIMEOUT) { + printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, expected, + *flag); + break; + } +#endif + } +} // UC 2shot kernel (non-lamport) + +__global__ void memset_int(uint32_t *data, int n, uint32_t val) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + data[idx] = val; + } +} + +__global__ void __launch_bounds__(NVTE_UB_MAXTHREADS) userbuffers_fp16_sum_inplace_gpu_mc_lamport( + const int RANKS, const int myrank, const int numlines, int *uc_flagptr, int *mc_flagptr, + float4 *mc_ptr_in, float4 *mc_ptr_out, uint4 *uc_ptr_out, uint4 *clear_ptr) { + // flags[0,1,2]: reduce_id, sm_sync-local, flag-barrier + // those go right after rank UC pointers, but its the CPU caller who should account for it + int reduce_id; + + if (threadIdx.x == 0) { + cudaGridDependencySynchronize(); + if (blockIdx.x == 0) ATOMIC_MCINC(mc_flagptr + NVTE_UB_FLAG_NVLS2_LAMPORT_RS_BAR); + reduce_id = uc_flagptr[NVTE_UB_FLAG_NVLS2_LAMPORT_ID]; + const int value_to_add = blockIdx.x == 0 ? 
NVTE_UB_MAX_SMS - gridDim.x + 1 : 1; + const int old_val_sm_sync = + atomicAdd(uc_flagptr + NVTE_UB_FLAG_NVLS2_LAMPORT_SM_SYNC, value_to_add); + volatile int *flag = (volatile int *)&(uc_flagptr[NVTE_UB_FLAG_NVLS2_LAMPORT_RS_BAR]); + reduce_id++; + const int lastSM = + (gridDim.x == 1 || old_val_sm_sync + value_to_add == reduce_id * NVTE_UB_MAX_SMS); + + if (lastSM) uc_flagptr[NVTE_UB_FLAG_NVLS2_LAMPORT_ID] = reduce_id; + cudaTriggerProgrammaticLaunchCompletion(); + + const int expected = reduce_id * RANKS; + +#ifdef UB_TIMEOUT_ENABLED + clock_t s = clock64(); +#endif + while (CHECK_IDS(*flag, expected)) { +#ifdef UB_TIMEOUT_ENABLED + if (clock64() - s > TIMEOUT) { + printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, expected, + *flag); + break; + } +#endif + } + } + __syncthreads(); + + const int loop_step0 = blockDim.x * gridDim.x * RANKS; + const int start_elem = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); + + for (int line = start_elem; line < numlines; line += loop_step0) { + uint4 val; + MULTIMEM_LD(val, mc_ptr_in + (line)) + MULTIMEM_ST(val, mc_ptr_out + (line)) + } + + for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < numlines; + line += blockDim.x * gridDim.x) { +#ifdef UB_TIMEOUT_ENABLED + clock_t s = clock64(); +#endif + while (true) { + uint4 result; + + asm volatile("ld.volatile.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(result.x), "=r"(result.y), "=r"(result.z), "=r"(result.w) + : "l"(&uc_ptr_out[line]) + : "memory"); + if (result.w != NVTE_UB_LAMPORT_INT) { + if (clear_ptr) clear_ptr[line].w = NVTE_UB_LAMPORT_INT; + break; + } +#ifdef UB_TIMEOUT_ENABLED + if (clock64() - s > TIMEOUT) { + printf("Lamport POLL:SM %d [%d]:expecting %d got (%d,%d,%d) %d\n", blockIdx.x, threadIdx.x, + NVTE_UB_LAMPORT_INT, result.x, result.y, result.z, result.w); + break; + } +#endif + } + } + +} // two-shot NVLS + lamport sync instead of last membar + +#define SETUP_LAUNCH_CONFIG(sms, threads, stream, cga_size, pdl_launch) \ + cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ + cudaLaunchAttribute attribute_ub[3]; \ + attribute_ub[2].id = cudaLaunchAttributeClusterDimension; \ + attribute_ub[2].val.clusterDim.x = sms % cga_size == 0 ? 
cga_size : 1; \ + attribute_ub[2].val.clusterDim.y = 1; \ + attribute_ub[2].val.clusterDim.z = 1; \ + attribute_ub[1].id = cudaLaunchAttributeCooperative; \ + attribute_ub[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; \ + attribute_ub[0].val.programmaticStreamSerializationAllowed = pdl_launch; \ + cfg.attrs = attribute_ub; \ + cfg.numAttrs = 3; + +namespace transformer_engine { + +extern "C" void allreduce_2shot_mc(int ranks, int myrank, void *uc0ptr, void *mc0ptr, + void *mcptr_in, void *mcptr_out, size_t bytes, + cudaStream_t stream) { + SETUP_LAUNCH_CONFIG(32, 1024, stream, 4, 1); + + int arg1 = ranks, arg2 = myrank, arg3 = bytes / 16; + void *arg4 = uc0ptr + (ranks * 8), *arg5 = mc0ptr + (ranks * 8), *arg6 = mcptr_in, + *arg7 = mcptr_out; + void *kernelArgs[] = {(void *)&arg1, (void *)&arg2, (void *)&arg3, (void *)&arg4, + (void *)&arg5, (void *)&arg6, (void *)&arg7}; + CUDACHECK(cudaLaunchKernelExC(&cfg, (void *)(userbuffers_fp16_sum_inplace_gpu_mc), kernelArgs)); +} + +extern "C" void allreduce_2shot_uc(int ranks, int myrank, void *uc0ptr, void *ucptr_in, + void *ucptr_out, size_t bytes, cudaStream_t stream) { + SETUP_LAUNCH_CONFIG(64, 1024, stream, 4, 1); + + int arg1 = myrank, arg2 = bytes / 16, arg3 = (int4 *)ucptr_in - (int4 *)uc0ptr, + arg4 = (int4 *)ucptr_out - (int4 *)uc0ptr; + void *arg5 = uc0ptr + (ranks * 8), **arg6 = (void **)uc0ptr; + void *kernelArgs[] = {(void *)&arg1, (void *)&arg2, (void *)&arg3, + (void *)&arg4, (void *)&arg5, (void *)&arg6}; +#define call_uc_kernel(x) \ + if (x == ranks) \ + CUDACHECK( \ + cudaLaunchKernelExC(&cfg, (void *)(userbuffers_fp16_sum_inplace_gpu_uc), kernelArgs)); + call_uc_kernel(2); + call_uc_kernel(4); + call_uc_kernel(8); +} + +extern "C" void allreduce_2shot_mc_lamport(int ranks, int myrank, void *uc0ptr, void *mc0ptr, + void *ucptr_out, void *mcptr_in, void *mcptr_out, + void *clear_ptr, size_t bytes, bool poisoned, + cudaStream_t stream) { + if (!poisoned) { + //user tells us destination was not pre-poisoned, so we need to do it before calling allreduce + int threadsPerBlock = 512; + int blocks = (bytes / 4 + threadsPerBlock - 1) / threadsPerBlock; + memset_int<<>>((uint32_t *)ucptr_out, bytes / 4, + NVTE_UB_LAMPORT_INT); + } + SETUP_LAUNCH_CONFIG(64, 1024, stream, 4, 1); + + int arg1 = ranks, arg2 = myrank, arg3 = bytes / 16; + void *arg4 = uc0ptr + (ranks * 8), *arg5 = mc0ptr + (ranks * 8), *arg6 = mcptr_in, + *arg7 = mcptr_out, *arg8 = ucptr_out, *arg9 = clear_ptr; + void *kernelArgs[] = {(void *)&arg1, (void *)&arg2, (void *)&arg3, (void *)&arg4, (void *)&arg5, + (void *)&arg6, (void *)&arg7, (void *)&arg8, (void *)&arg9}; + CUDACHECK( + cudaLaunchKernelExC(&cfg, (void *)(userbuffers_fp16_sum_inplace_gpu_mc_lamport), kernelArgs)); +} +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index a1cd85ba2a..efcca73222 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -11,7 +11,7 @@ #include #include #include - +#include #include "cuda_runtime.h" #define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ @@ -110,6 +110,8 @@ std::shared_ptr, \ transformer_engine::CommOverlapCore>(m, "CommOverlapBase", pybind11::module_local()) \ .def(py::init([]() { return new transformer_engine::CommOverlapBase(); }), \ + py::call_guard()) \ + .def("init_ubnext", &transformer_engine::CommOverlapBase::init_ubnext, \ py::call_guard()); \ py::class_, \ @@ -128,6 
+130,34 @@ }, \ py::call_guard(), py::arg("device_id") = -1); \ m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ - py::call_guard()); + py::call_guard()); \ + m.def( \ + "allreduce_2shot_mc", \ + [](int ranks, int myrank, void* uc0ptr, void* mc0ptr, void* mcptr_in, void* mcptr_out, \ + size_t bytes) { \ + transformer_engine::allreduce_2shot_mc(ranks, myrank, uc0ptr, mc0ptr, mcptr_in, mcptr_out, \ + bytes, at::cuda::getCurrentCUDAStream()); \ + }, \ + py::arg("ranks"), py::arg("myrank"), py::arg("uc0ptr"), py::arg("mc0ptr"), \ + py::arg("mcptr_in"), py::arg("mcptr_out"), py::arg("bytes")); \ + m.def( \ + "allreduce_2shot_uc", \ + [](int ranks, int myrank, void* uc0ptr, void* ucptr_in, void* ucptr_out, size_t bytes) { \ + transformer_engine::allreduce_2shot_uc(ranks, myrank, uc0ptr, ucptr_in, ucptr_out, bytes, \ + at::cuda::getCurrentCUDAStream()); \ + }, \ + py::arg("ranks"), py::arg("myrank"), py::arg("uc0ptr"), py::arg("ucptr_in"), \ + py::arg("ucptr_out"), py::arg("bytes")); \ + m.def( \ + "allreduce_2shot_mc_lamport", \ + [](int ranks, int myrank, void* uc0ptr, void* mc0ptr, void* ucptr_out, void* mcptr_in, \ + void* mcptr_out, void* clear_ptr, size_t bytes, bool poisoned) { \ + transformer_engine::allreduce_2shot_mc_lamport( \ + ranks, myrank, uc0ptr, mc0ptr, ucptr_out, mcptr_in, mcptr_out, clear_ptr, bytes, \ + poisoned, at::cuda::getCurrentCUDAStream()); \ + }, \ + py::arg("ranks"), py::arg("myrank"), py::arg("uc0ptr"), py::arg("mc0ptr"), \ + py::arg("ucptr_out"), py::arg("mcptr_in"), py::arg("mcptr_out"), py::arg("clear_ptr"), \ + py::arg("bytes"), py::arg("poisoned")); #endif diff --git a/transformer_engine/pytorch/cpp_extensions/__init__.py b/transformer_engine/pytorch/cpp_extensions/__init__.py index 944d1849bf..075eea0158 100644 --- a/transformer_engine/pytorch/cpp_extensions/__init__.py +++ b/transformer_engine/pytorch/cpp_extensions/__init__.py @@ -7,3 +7,4 @@ from .fused_attn import * from .gemm import * +from .symm_allocator import * \ No newline at end of file diff --git a/transformer_engine/pytorch/cpp_extensions/symm_allocator.py b/transformer_engine/pytorch/cpp_extensions/symm_allocator.py new file mode 100644 index 0000000000..06062b3453 --- /dev/null +++ b/transformer_engine/pytorch/cpp_extensions/symm_allocator.py @@ -0,0 +1,289 @@ +import torch +import os +import gc +import weakref +from typing import List, Tuple, Optional, Dict +from threading import Lock +import torch.distributed._symmetric_memory as symm_mem +from ctypes import pythonapi, c_void_p, py_object + +def to_capsule(ptr): + # Set the return type to py_object to get a Python object (PyCapsule) + pythonapi.PyCapsule_New.restype = py_object + pythonapi.PyCapsule_New.argtypes = [c_void_p, c_void_p, c_void_p] + # Create capsule with a name (optional, can be None) and no destructor + capsule = pythonapi.PyCapsule_New(ptr, None, None) + return capsule + +class SymmTensor(torch.Tensor): + """Custom tensor subclass that uses custom memory""" + @staticmethod + def __new__(cls, pool: torch.Tensor, offset: int, shape: torch.Size, dtype: torch.dtype, allocator: 'SymmAllocator'): + # Calculate number of elements and bytes + num_elements = torch.Size(shape).numel() + element_size = torch.tensor(0, dtype=dtype).element_size() + nbytes = element_size * num_elements + + # Validate pool + assert pool.dtype == torch.uint8, f"Expected uint8 pool, got {pool.dtype}" + assert pool.numel() >= offset + nbytes, f"Pool too small: {pool.numel()} bytes, need {offset + nbytes}" + + # Slice the pool to get the 
required bytes + byte_slice = pool[offset:offset + nbytes] + + # Reinterpret the uint8 bytes as the target dtype + tensor = byte_slice.view(dtype=dtype) + tensor = tensor.view(*shape) + + # Initialize as a subclass of torch.Tensor + self = torch.Tensor._make_subclass(cls, tensor) + if not isinstance(allocator, SymmAllocator): + raise TypeError(f"Expected SymmAllocator, got {type(allocator)}") + self._allocator = allocator + self._ptr = tensor.data_ptr() + self._offset = offset + self._size = nbytes + return self + + def __del__(self): + """Custom deallocator to return memory to the pool.""" + if hasattr(self, '_allocator') and hasattr(self, '_ptr'): + self._allocator.free(self._ptr) + +class SymmAllocator: + def __init__(self, size_bytes: int, device: torch.device, dist_group: torch.distributed.group): + """Initialize the allocator with a preallocated memory pool.""" + # Preallocate the memory pool using torch.empty + self.reg0_size = 1024 # NVL72*8 plus up to 112 flags + self.device = device + self.world_size = torch.distributed.get_world_size(dist_group) + self.myrank = torch.distributed.get_rank(dist_group) + self.dist_group = dist_group + + from ..module.base import get_ub + if os.environ.get("NVTE_USE_UB_FOR_UBNEXT"): + self.ub_obj = get_ub("ubnext") + self.internal_pool = self.ub_obj.get_buffer(False).reshape(-1) + self.mc0_ptr = self.ub_obj.init_ubnext() + self.pool_size = self.internal_pool.numel() + else: + alignment = 2 * 1024 * 1024 # memory is allocated in 2MB pages anyways + self.pool_size = int((size_bytes + alignment - 1) / alignment) * alignment + self.internal_pool = symm_mem.empty(self.pool_size, dtype=torch.uint8, device=device) + self.hdl0 = symm_mem.rendezvous(self.internal_pool, dist_group) + self.mc0_ptr = self.hdl0.multicast_ptr + self.internal_pool.fill_(0) + self.internal_pool.view(torch.int64)[:self.world_size].copy_(torch.tensor(self.hdl0.buffer_ptrs).view(torch.int64)) + #self.hdl0.barrier(channel=0) + # Synchronize all processes before proceeding + torch.distributed.barrier(group=dist_group) + + # Track the raw pointer to the pool + self.pool_ptr = self.internal_pool.data_ptr() + # Track allocated segments: (offset, size) + self.allocated: List[Tuple[int, int]] = [] + # Track free segments: (offset, size) + self.freelist: List[Tuple[int, int]] = [(self.reg0_size, self.pool_size-self.reg0_size)] + self.nextpoisoned = None + self.tensors = weakref.WeakSet() + self.lock = Lock() + + + def allocate(self, nbytes: int) -> Tuple[Optional[int], Optional[torch.Tensor]]: + """Allocate nbytes from the pool, returning a pointer and pool reference.""" + with self.lock: + for i, (offset, size) in enumerate(self.freelist): + if size >= nbytes: + self.freelist.pop(i) + self.allocated.append((offset, nbytes)) + if size > nbytes: + self.freelist.append((offset + nbytes, size - nbytes)) + return self.pool_ptr + offset, self.internal_pool + return None,None + + + # No suitable free segment found + raise MemoryError(f"Preallocated pool exhausted: requested {nbytes} bytes, " + f"available segments: {self.freelist}") + + def free(self, ptr: int): + """Free the memory at ptr, returning it to the pool.""" + with self.lock: + offset = ptr - self.pool_ptr + for i, (alloc_offset, size) in enumerate(self.allocated): + if alloc_offset == offset: + self.allocated.pop(i) + self.freelist.append((offset, size)) + self.freelist.sort(key=lambda x: x[0]) + self._merge_free_segments() + return + # Ignore invalid pointers silently + pass + + raise ValueError(f"Invalid pointer {ptr} not found in 
allocated segments") + + def _merge_free_segments(self): + """Merge adjacent free segments to reduce fragmentation.""" + if not self.freelist: + return + merged = [] + current_offset, current_size = self.freelist[0] + for offset, size in self.freelist[1:]: + if current_offset + current_size == offset: + # Adjacent segments, merge them + current_size += size + else: + # Non-adjacent, keep current and start new + merged.append((current_offset, current_size)) + current_offset, current_size = offset, size + merged.append((current_offset, current_size)) + self.freelist = merged + + def create_tensor(self, shape: Tuple[int, ...], dtype: torch.dtype = torch.float32) -> Optional[torch.Tensor]: + """Create a PooledTensor using memory from the pool.""" + nbytes = torch.tensor(0, dtype=dtype).element_size() * torch.Size(shape).numel() + ptr, pool = self.allocate(nbytes) + if ptr is None: return None + offset = ptr - self.pool_ptr + tensor = SymmTensor(pool, offset, torch.Size(shape), dtype, self) + self.tensors.add(tensor) + return tensor + + def allreduce_uc(self, tensor_in: torch.Tensor) -> torch.Tensor: + """Performs in-place allreduce on the given SymmTensor using best algo""" + assert tensor_in.device == self.device, "Tensor device mismatch with allocator device" + + #tensor_out = self.create_tensor(tensor_in.shape, tensor_in.dtype) + + ucptr_in = tensor_in.data_ptr() + #mcptr_out = tensor_out.data_ptr() + nbytes = tensor_in.numel() * tensor_in.element_size() + + # Import your pybind module if not imported + from transformer_engine_torch import allreduce_2shot_uc + + allreduce_2shot_uc( + self.world_size, + self.myrank, + to_capsule(self.internal_pool.data_ptr()), + to_capsule(ucptr_in), + to_capsule(ucptr_in),#out + nbytes + ) + return tensor_in + + def allreduce(self, tensor_in: torch.Tensor) -> torch.Tensor: + """Performs in-place allreduce on the given SymmTensor using best algo""" + assert tensor_in.device == self.device, "Tensor device mismatch with allocator device" + + #tensor_out = self.create_tensor(tensor_in.shape, tensor_in.dtype) + + mcptr_in = self.mc0_ptr + (tensor_in.data_ptr() - self.internal_pool.data_ptr()) + #mcptr_out = self.hdl.multicast_ptr + (tensor_out.data_ptr() - self.internal_pool.data_ptr()) + nbytes = tensor_in.numel() * tensor_in.element_size() + + # Import your pybind module if not imported + from transformer_engine_torch import allreduce_2shot_mc + + allreduce_2shot_mc( + self.world_size, + self.myrank, + to_capsule(self.internal_pool.data_ptr()), + to_capsule(self.mc0_ptr), + to_capsule(mcptr_in), + to_capsule(mcptr_in),#out + nbytes + ) + return tensor_in + + def allreduce_lamport(self, tensor_in: torch.Tensor) -> torch.Tensor: + """ + Performs allreduce using 2-shot multicast Lamport variant: + - Takes `tensor_in` as input (SymmTensor). + - Allocates `tensor_out` of same shape and dtype. + - Runs `allreduce_2shot_mc_lamport` over them. + - Returns `tensor_out`. 
+ """ + assert tensor_in.device == self.device, "Tensor device mismatch with allocator device" + if self.mc0_ptr is None or self.mc0_ptr == 0: return self.allreduce_uc(tensor_in) + from transformer_engine_torch import allreduce_2shot_mc_lamport + + # Allocate output tensor of same shape/dtype + tensor_out = self.nextpoisoned + poisonedout = True + + if self.nextpoisoned is None or self.nextpoisoned.shape!=tensor_in.shape: + if self.nextpoisoned is not None: + del self.nextpoisoned + self.nextpoisoned = None + tensor_out = self.create_tensor(tensor_in.shape, tensor_in.dtype) + poisonedout = False + if tensor_out is None: return self.allreduce(tensor_in) + + # alllcate potential output for next allreduce (speculative) and poison it now + self.nextpoisoned = self.create_tensor(tensor_in.shape, tensor_in.dtype) + + # Calculate mcptr_in and mcptr_out with offset relative to internal_pool + offset = tensor_in.data_ptr() - self.internal_pool.data_ptr() + mcptr_in = self.mc0_ptr + offset + mcptr_out = self.mc0_ptr + (tensor_out.data_ptr() - self.internal_pool.data_ptr()) + + # Use clear_ptr to clear output memory before reduction; here we use tensor_out + #clear_ptr = self.nextpoisoned.data_ptr() if self.nextpoisoned is not None else 0 + + nbytes = tensor_in.numel() * tensor_in.element_size() + + # Call your pybind lamport allreduce + allreduce_2shot_mc_lamport( + self.world_size, + self.myrank, + to_capsule(self.internal_pool.data_ptr()), + to_capsule(self.mc0_ptr), + to_capsule(tensor_out.data_ptr()), + to_capsule(mcptr_in), + to_capsule(mcptr_out), + to_capsule(self.nextpoisoned.data_ptr()) if self.nextpoisoned is not None else None, + nbytes, + poisonedout + ) + + return tensor_out + +_allocator_map: Dict[torch.distributed.group, Tuple[int, 'SymmAllocator']] = {} + +def ubsymm_request_allocator(dist_group: torch.distributed.group, shape:Optional[Tuple[int, ...]] = None, dtype: torch.dtype = torch.bfloat16) -> None: + if shape is not None: + num_elements = torch.Size(shape).numel() + element_size = torch.tensor(0, dtype=dtype).element_size() + tensor_size = num_elements * element_size + else: + tensor_size = 0 + + if dist_group not in _allocator_map: + if os.environ.get("NVTE_USE_UB_FOR_UBNEXT"): + assert _allocator_map.is_empty(), "Current UBNEXT-UB bypass supports only one process group." 
+ _allocator_map[dist_group] = (tensor_size, None) + else: + old_size, allocator = _allocator_map[dist_group] + assert allocator is None, "Second element of tuple must be None" + max_size = max(old_size, tensor_size) + _allocator_map[dist_group] = (max_size, None) + +def ubsymm_get_sym_tensor(shape: Tuple[int, ...], dtype: torch.dtype, dist_group: torch.distributed.group) -> torch.Tensor: + if dtype != torch.bfloat16: + return None # Unsupported dtype, do fallback to nccl + if dist_group not in _allocator_map: return None # No allocator requested earlier, do fallback to nccl + (max_size, allocator) = _allocator_map[dist_group] + if allocator is None: + new_max_size = int(os.environ.get('NVTE_UB_SYMM_POOL_SIZE', ((6*max_size+1048575)/1024/1024) ) ) + allocator = SymmAllocator( + new_max_size * 1024 * 1024, + torch.device(f'cuda:{torch.cuda.current_device()}'), + dist_group + ) + _allocator_map[dist_group] = (new_max_size, allocator) + return allocator.create_tensor(shape, dtype) + +def ubsymm_allreduce(tensor_in: SymmTensor) -> SymmTensor: + return tensor_in._allocator.allreduce_lamport(tensor_in) + \ No newline at end of file diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index b0da6e5fca..976afac5e0 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -261,7 +261,10 @@ def initialize_ub( "pipeline": ["proj_fprop", "fc2_fprop"], "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], } - + # Add "ubnext" to bulk methods if environment variable is set + if os.environ.get("NVTE_USE_UB_FOR_UBNEXT"): + methods["bulk"].append("ubnext") + # AG-RS overlap pairs of layers forming a tensor-parallel block ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"} rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()} @@ -367,8 +370,8 @@ def add_ub( ) else: ub_obj = tex.CommOverlap( - shape, # Communication buffer shape - buffer_dtype, # Communication buffer data type + shape if name != "ubnext" else (int(os.environ.get('NVTE_UB_SYMM_POOL_SIZE', 64)), 1024*1024), #Communication buffer shape + buffer_dtype if name != "ubnext" else torch.uint8, # Communication buffer data type helper, # Helper for torch.distributed callbacks during bootstrapping tp_size, # Tensor-parallel group size (may be different than local_size) num_splits=num_splits, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f2a6871a8a..4131c4e61a 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -7,6 +7,7 @@ from functools import reduce from operator import mul as multiply_op import warnings +import os import torch @@ -52,6 +53,9 @@ ) from ..cpp_extensions import ( general_gemm, + ubsymm_request_allocator, + ubsymm_get_sym_tensor, + ubsymm_allreduce ) from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo @@ -290,6 +294,9 @@ def forward( out_shape[-1] = out_features reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device) + symm_out = None + if symmetric_ar_type == 'ub_custom': + symm_out = ubsymm_get_sym_tensor( (list(inp.shape)[0], out_features,), activation_dtype, tp_group) # ------------------------------------------------------ # Forward GEMM # Note: y = x * w^T @@ -306,6 +313,7 @@ def forward( ub=ub_obj, ub_type=ub_type, extra_output=reduce_scatter_out, + out=symm_out, ) nvtx_range_pop(f"{nvtx_label}.gemm") # 
------------------------------------------------------ @@ -326,7 +334,11 @@ def forward( out, _ = reduce_scatter_along_first_dim(out, tp_group) elif tensor_parallel: if symmetric_ar_type is not None: - out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) + if symm_out is not None: + out = ubsymm_allreduce(symm_out) + else: + fallback_symmetric = "multimem_all_reduce" if symmetric_ar_type == "ub_custom" else symmetric_ar_type + out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=fallback_symmetric) else: out, _ = allreduce(out, tp_group) nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") @@ -1151,6 +1163,8 @@ def __init__( 7, 0, ), "Torch version must be at least 2.7 to use symmetric memory" + if self.symmetric_ar_type == 'ub_custom': + ubsymm_request_allocator(self.tp_group, (int(os.environ.get('NVTE_UB_MAXBATCH',64)), self.out_features,), params_dtype) # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() From 115508056842a62e0a339ad3b980a1b9e2c13d8e Mon Sep 17 00:00:00 2001 From: Anton Korzh Date: Wed, 6 Aug 2025 14:22:13 -0700 Subject: [PATCH 2/9] layernorm_linear using ubnext Signed-off-by: Anton Korzh --- .../pytorch/module/layernorm_linear.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 5e45b5c255..14f5bf1c3e 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -8,6 +8,7 @@ from typing import Callable, Dict, Optional, Tuple, Union, List from functools import reduce from operator import mul as multiply_op +import os import torch from torch.nn import init @@ -73,6 +74,9 @@ from ..cpp_extensions import ( general_gemm, + ubsymm_request_allocator, + ubsymm_get_sym_tensor, + ubsymm_allreduce ) __all__ = ["LayerNormLinear"] @@ -325,7 +329,9 @@ def forward( out_shape[0] //= tp_world_size out_shape[-1] = out_features reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device) - + symm_out = None + if symmetric_ar_type == 'ub_custom': + symm_out = ubsymm_get_sym_tensor( (list(inp.shape)[0], out_features,), activation_dtype, tp_group) # ------------------------------------------------------ # Forward GEMM # Note: y = x * w^T @@ -342,6 +348,7 @@ def forward( ub=ub_obj, ub_type=ub_type, extra_output=reduce_scatter_out, + out=symm_out, ) nvtx_range_pop(f"{nvtx_label}.gemm") # ------------------------------------------------------ @@ -367,7 +374,11 @@ def forward( out, _ = reduce_scatter_along_first_dim(out, tp_group) elif tensor_parallel: if symmetric_ar_type is not None: - out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) + if symm_out is not None: + out = ubsymm_allreduce(symm_out) + else: + fallback_symmetric = "multimem_all_reduce" if symmetric_ar_type == "ub_custom" else symmetric_ar_type + out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=fallback_symmetric) else: out, _ = allreduce(out, tp_group) nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") @@ -1236,7 +1247,8 @@ def __init__( 7, 0, ), "Torch version must be at least 2.7 to use symmetric memory" - + if self.symmetric_ar_type == 'ub_custom': + ubsymm_request_allocator(self.tp_group, (int(os.environ.get('NVTE_UB_MAXBATCH',64)), self.out_features,), params_dtype) self.eps = eps layer_norm_weight = torch.nn.Parameter( torch.empty(self.in_features, device=device, 
dtype=params_dtype) From 8e5ba45885abdfb4b0613b73621fe3f9a0582872 Mon Sep 17 00:00:00 2001 From: Anton Korzh Date: Wed, 6 Aug 2025 14:09:10 -0700 Subject: [PATCH 3/9] minor fix Signed-off-by: Anton Korzh --- transformer_engine/pytorch/cpp_extensions/symm_allocator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/cpp_extensions/symm_allocator.py b/transformer_engine/pytorch/cpp_extensions/symm_allocator.py index 06062b3453..8c10b9218a 100644 --- a/transformer_engine/pytorch/cpp_extensions/symm_allocator.py +++ b/transformer_engine/pytorch/cpp_extensions/symm_allocator.py @@ -261,7 +261,7 @@ def ubsymm_request_allocator(dist_group: torch.distributed.group, shape:Optional if dist_group not in _allocator_map: if os.environ.get("NVTE_USE_UB_FOR_UBNEXT"): - assert _allocator_map.is_empty(), "Current UBNEXT-UB bypass supports only one process group." + assert not _allocator_map, "Current UBNEXT-UB bypass supports only one process group." _allocator_map[dist_group] = (tensor_size, None) else: old_size, allocator = _allocator_map[dist_group] From 476efac9ff17ebf47c44cbc86783cf91b1a35c7c Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 7 Aug 2025 02:39:22 +0800 Subject: [PATCH 4/9] [PyTorch] Multi-tensor swizzle scaling factors for MXFP8 and fuse padding zeros (#2019) * for loop Signed-off-by: Xin Yao * bulk alloc Signed-off-by: Xin Yao * multi-tensor swizzle Signed-off-by: Xin Yao * pad zeros in swizzle kernels Signed-off-by: Xin Yao * unify single- and multi-tensor swizzle Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix empty tensor list Signed-off-by: Xin Yao * fix bug for col swizzle Signed-off-by: Xin Yao * check context & fix signifiers Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Anton Korzh --- benchmarks/linear/benchmark_grouped_linear.py | 2 +- transformer_engine/common/common.cu | 1 + .../include/transformer_engine/swizzle.h | 14 + transformer_engine/common/swizzle/swizzle.cu | 439 ++++++++++++++++-- transformer_engine/common/util/padding.cu | 1 + .../pytorch/csrc/extensions/cast.cpp | 10 +- .../pytorch/csrc/extensions/gemm.cpp | 19 +- transformer_engine/pytorch/csrc/quantizer.cpp | 8 +- transformer_engine/pytorch/csrc/util.cpp | 95 ++++ transformer_engine/pytorch/csrc/util.h | 11 +- 10 files changed, 533 insertions(+), 67 deletions(-) diff --git a/benchmarks/linear/benchmark_grouped_linear.py b/benchmarks/linear/benchmark_grouped_linear.py index 0dbee212d6..44f1c89673 100644 --- a/benchmarks/linear/benchmark_grouped_linear.py +++ b/benchmarks/linear/benchmark_grouped_linear.py @@ -247,7 +247,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4): num_gemms_list = [8] if args.profile: - mkns = [(4096, 4096, 4096)] + mkns = [(4096 * 8, 4096, 4096)] # in profile mode, only run one recipe specified in args.recipe assert args.recipe != "all", ( "In profile mode, only one recipe can be specified, please specify the recipe as" diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 619bf6ca00..9831bbb24d 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -138,6 +138,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, const uint32_t 
stride_elems, const uint32_t offset_elems, const size_t type_num_bits) { + cuda_driver::ensure_context_exists(); // Get a function pointer to the cuTensorMapEncodeTiled driver API // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13 static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h index de5a11eb73..079feb4a7d 100644 --- a/transformer_engine/common/include/transformer_engine/swizzle.h +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -30,6 +30,20 @@ extern "C" { */ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Swizzling scaling factors into the required interleaved layout for GEMM + * + * \param[in] inputs Input tensors with non-swizzled scale_inv. + * \param[in,out] outputs Output tensors which hosts swizzled scale_inv. + * \param[in] stream CUDA stream used for the operation. + * + * Requirements: + * - scale_inv is stored in row-major. + * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale. + * - data is quantitized along K-dimension, i.e. 1D-scaling block lies along the K-dimension. + */ +void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, + const size_t num_tensors, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index cea0e5080b..37d7491d96 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -15,15 +15,17 @@ #include "../util/logging.h" #include "transformer_engine/transformer_engine.h" +namespace transformer_engine { namespace { -constexpr int TB_DIM = 32; -constexpr int NEW_SF_TILE_DIM_K = 16; -constexpr int N_SF_PER_TD_PER_TILE = 4; +constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32; +constexpr __device__ __host__ int TB_DIM = 32; +constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16; +constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4; // output is in ~K-major interleaved blocks -constexpr int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4; -constexpr int NEW_SF_TILE_DIM_M_I32 = 32; +constexpr __device__ __host__ int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4; +constexpr __device__ __host__ int NEW_SF_TILE_DIM_M_I32 = 32; template __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) { @@ -51,8 +53,11 @@ __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) { } template -__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M, - const int K) { +__device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K, const int bid_x, + const int bid_y, const int grid_dim_x, + const int grid_dim_y) { constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; @@ -66,21 +71,24 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons int m_tiles_in_tb = N_TILE_PER_TD; int k_tiles_in_tb = TB_DIM; - if (blockIdx.x == gridDim.x - 1) { + if (bid_x == grid_dim_x - 1) { k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1; } - if (blockIdx.y == gridDim.y 
- 1) { + if (bid_y == grid_dim_y - 1) { m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1; } - const int32_t* input_i32 = reinterpret_cast(input) + - blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + - blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32; + bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); + bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K); + + const int input_offset = + bid_x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + bid_y * N_TILE_PER_TD * SF_TILE_DIM_M_I32; + const int32_t* input_i32 = reinterpret_cast(input) + input_offset; int32_t* output_i32[N_TILE_PER_TD]; #pragma unroll for (int i = 0; i < m_tiles_in_tb; i++) { - output_i32[i] = reinterpret_cast(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 + - (blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32; + output_i32[i] = reinterpret_cast(output) + bid_x * TB_DIM * SF_TILE_SIZE_I32 + + (bid_y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32; } extern __shared__ int slm[]; @@ -90,8 +98,18 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons threadIdx.y < k_tiles_in_tb) { #pragma unroll for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { - regs_vec[i] = __ldg(reinterpret_cast( - input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD)); + const int thread_offset = + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD; + regs_vec[i] = __ldg(reinterpret_cast(input_i32 + thread_offset)); + // Pad zeros + if (padding_m || padding_k) { + for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { + const int index = (input_offset + thread_offset) * sizeof(int) + j; + if (index / M >= original_K || index % M >= original_M) { + reinterpret_cast(regs_vec + i)[j] = 0; + } + } + } } // local shuffle @@ -126,6 +144,14 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons } } +template +__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K) { + swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); +} + template __device__ inline void regs_shuffle(LType* regs_vec) { constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); @@ -143,8 +169,11 @@ __device__ inline void regs_shuffle(LType* regs_vec) { } template -__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M, - const int K) { +__device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K, const int bid_x, + const int bid_y, const int grid_dim_x, + const int grid_dim_y) { constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD; @@ -154,14 +183,17 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons int n_tiles_in_tb = N_TILES_IN_TB; const int K_i32 = K / 4; - if (blockIdx.x == gridDim.x - 1) { + if (bid_x == grid_dim_x - 1) { n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1; } - const int* input_i32 = reinterpret_cast(input) + - blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB; - int* output_i32 = reinterpret_cast(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + - blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32; + bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); + bool padding_k = (bid_x == grid_dim_x - 1) && (original_K 
< K); + + const int input_offset = bid_y * SF_TILE_DIM_M_I32 * K_i32 + bid_x * N_TILES_IN_TB; + const int* input_i32 = reinterpret_cast(input) + input_offset; + int* output_i32 = reinterpret_cast(output) + bid_y * SF_TILE_DIM_M_I32 * K_i32 + + bid_x * N_TILES_IN_TB * SF_TILE_SIZE_I32; extern __shared__ int4 slm_v4i[]; @@ -170,8 +202,17 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) { #pragma unroll for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { - regs_vec[i] = __ldg(reinterpret_cast( - input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD)); + const int thread_offset = (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD; + regs_vec[i] = __ldg(reinterpret_cast(input_i32 + thread_offset)); + if (padding_m || padding_k) { + // Pad zeros + for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { + const int index = (input_offset + thread_offset) * sizeof(int) + j; + if (index / K >= original_M || index % K >= original_K) { + reinterpret_cast(regs_vec + i)[j] = 0; + } + } + } } // shuffle regs @@ -196,9 +237,99 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons } } -} // namespace +template +__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K) { + swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); +} -namespace transformer_engine { +constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB +struct MultiSwizzleArgs { + // (input) Data buffers for input scaling factors + void* input_list[kMaxTensorsPerKernel]; + // (output) Data buffers for swizzled scaling factors + void* output_list[kMaxTensorsPerKernel]; + // Input scaling factor m + int m_list[kMaxTensorsPerKernel]; + // Input scaling factor k + int k_list[kMaxTensorsPerKernel]; + // Input scaling factor m before padding + int original_m_list[kMaxTensorsPerKernel]; + // Input scaling factor k before padding + int original_k_list[kMaxTensorsPerKernel]; + // Prefix sum (with leading zero) of CUDA blocks needed for each + // tensor + int block_range[kMaxTensorsPerKernel + 1]; + // Number of tensors being processed by kernel + int num_tensors; +}; + +template +__global__ void multi_tensor_swizzle_row_scaling_kernel(MultiSwizzleArgs kernel_args) { + // Find tensor corresponding to block + const int bid = blockIdx.x; + int tensor_id = 0; + while (kernel_args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + // Get args corresponding to block + const void* input = kernel_args.input_list[tensor_id]; + void* output = kernel_args.output_list[tensor_id]; + const int M = kernel_args.m_list[tensor_id]; + const int K = kernel_args.k_list[tensor_id]; + const int original_M = kernel_args.original_m_list[tensor_id]; + const int original_K = kernel_args.original_k_list[tensor_id]; + + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD; + + // Get block index in grid. Emulate 2D grid. 
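As a minimal sketch of the indexing scheme used by these multi-tensor kernels (hypothetical Python; `grids` stands in for each tensor's per-tensor grid dimensions, which are not stored explicitly in MultiSwizzleArgs): the flat blockIdx.x is first mapped to a tensor through the block_range prefix sum, then to emulated 2D coordinates by division and modulo with that tensor's grid_dim_y.

def build_block_range(grids):
    # grids: list of (grid_dim_x, grid_dim_y) per tensor; prefix sum with a leading zero
    block_range = [0]
    for gx, gy in grids:
        block_range.append(block_range[-1] + gx * gy)
    return block_range

def locate_block(bid, block_range, grids):
    # Find the tensor whose block range contains this flat block id
    tensor_id = 0
    while block_range[tensor_id + 1] <= bid:
        tensor_id += 1
    local = bid - block_range[tensor_id]
    gy = grids[tensor_id][1]
    return tensor_id, local // gy, local % gy  # (tensor_id, bid_x, bid_y)

For example, with grids [(2, 3), (1, 4)], block_range is [0, 6, 10], and flat block 7 resolves to tensor 1 with bid_x = 0 and bid_y = 1, matching the division/modulo performed below.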
+ const int num_tiles_k = K / SF_TILE_DIM_K; + const int num_tiles_m = M / SF_TILE_DIM_M; + const int grid_dim_x = DIVUP(num_tiles_k, N_TILES_IN_TB); + const int grid_dim_y = num_tiles_m; + const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y; + const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y; + + swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); +} + +template +__global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_args) { + // Find tensor corresponding to block + const int bid = blockIdx.x; + int tensor_id = 0; + while (kernel_args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + // Get args corresponding to block + const void* input = kernel_args.input_list[tensor_id]; + void* output = kernel_args.output_list[tensor_id]; + const int M = kernel_args.m_list[tensor_id]; + const int K = kernel_args.k_list[tensor_id]; + const int original_M = kernel_args.original_m_list[tensor_id]; + const int original_K = kernel_args.original_k_list[tensor_id]; + + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + + // Get block index in grid. Emulate 2D grid. + const int num_tiles_k = K / SF_TILE_DIM_K; + const int num_tiles_m = M / SF_TILE_DIM_M; + const int grid_dim_x = DIVUP(num_tiles_k, TB_DIM); + const int grid_dim_y = DIVUP(num_tiles_m, N_TILE_PER_TD); + const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y; + const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y; + + swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); +} + +} // namespace void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) { @@ -252,27 +383,29 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s int n_tiles_in_tb = TB_DIM * vec_load_size; dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + const int original_M = input->flat_first_dim(); + const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE; switch (vec_load_size) { case 4: cudaFuncSetAttribute(swizzle_row_scaling_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); swizzle_row_scaling_kernel - <<>>(input->scale_inv.dptr, - output->scale_inv.dptr, m, k); + <<>>( + input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); break; case 2: cudaFuncSetAttribute(swizzle_row_scaling_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); swizzle_row_scaling_kernel - <<>>(input->scale_inv.dptr, - output->scale_inv.dptr, m, k); + <<>>( + input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); break; case 1: cudaFuncSetAttribute(swizzle_row_scaling_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); swizzle_row_scaling_kernel - <<>>(input->scale_inv.dptr, - output->scale_inv.dptr, m, k); + <<>>( + input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); break; default: NVTE_ERROR("Not valid vec_load_size."); @@ -285,27 +418,32 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s int n_tiles_in_tb = TB_DIM * 
vec_load_size; dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + const int original_M = input->flat_last_dim(); + const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; switch (vec_load_size) { case 4: cudaFuncSetAttribute(swizzle_col_scaling_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); swizzle_col_scaling_kernel - <<>>( - input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, + k, original_M, original_K); break; case 2: cudaFuncSetAttribute(swizzle_col_scaling_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); swizzle_col_scaling_kernel - <<>>( - input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, + k, original_M, original_K); break; case 1: cudaFuncSetAttribute(swizzle_col_scaling_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); swizzle_col_scaling_kernel - <<>>( - input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, + k, original_M, original_K); break; default: NVTE_ERROR("Not valid vec_load_size."); @@ -317,10 +455,212 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s } else { NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans."); } - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) { - printf("CUDA Error: %s\n", cudaGetErrorString(err)); - exit(-1); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +template +void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, + const int vec_load_size, const bool is_rowwise, + cudaStream_t stream) { + int n_tiles_in_tb = TB_DIM * vec_load_size; + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + /* Calculate number of CUDA blocks needed for each tensor. + * We have to do it here because we have to iterate over all tensors in this batch to + * get the minimum vec_load_size. 
+ */ + for (size_t j = 0; j < kernel_args.num_tensors; j++) { + const int m = kernel_args.m_list[j]; + const int k = kernel_args.k_list[j]; + int num_tiles_m = m / SF_TILE_DIM_M; + int num_tiles_k = k / SF_TILE_DIM_K; + if (is_rowwise) { + kernel_args.block_range[j + 1] = + kernel_args.block_range[j] + DIVUP(num_tiles_k, n_tiles_in_tb) * num_tiles_m; + } else { + kernel_args.block_range[j + 1] = + kernel_args.block_range[j] + + DIVUP(num_tiles_k, TB_DIM) * DIVUP(num_tiles_m, vec_load_size); + } + } + // Launch kernel + const int num_blocks = kernel_args.block_range[kernel_args.num_tensors]; + dim3 block_size(TB_DIM, TB_DIM); + if (is_rowwise) { + switch (vec_load_size) { + case 4: + cudaFuncSetAttribute( + multi_tensor_swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + multi_tensor_swizzle_row_scaling_kernel + <<>>(kernel_args); + break; + case 2: + cudaFuncSetAttribute( + multi_tensor_swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + multi_tensor_swizzle_row_scaling_kernel + <<>>(kernel_args); + break; + case 1: + cudaFuncSetAttribute( + multi_tensor_swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + multi_tensor_swizzle_row_scaling_kernel + <<>>(kernel_args); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } else { + switch (vec_load_size) { + case 4: + cudaFuncSetAttribute( + multi_tensor_swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + multi_tensor_swizzle_col_scaling_kernel + <<>>(kernel_args); + break; + case 2: + cudaFuncSetAttribute( + multi_tensor_swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + multi_tensor_swizzle_col_scaling_kernel + <<>>(kernel_args); + break; + case 1: + cudaFuncSetAttribute( + multi_tensor_swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + multi_tensor_swizzle_col_scaling_kernel + <<>>(kernel_args); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } + NVTE_CHECK_CUDA(cudaGetLastError()); +} +void multi_tensor_swizzle_scaling_factors(const std::vector& input, + std::vector& output, cudaStream_t stream) { + auto num_tensors = input.size(); + bool all_has_data = true; + bool all_has_columnwise_data = true; + for (size_t i = 0; i < num_tensors; i++) { + if (!is_fp8_dtype(input[i]->dtype()) || !is_mxfp_scaling(input[i]->scaling_mode)) { + NVTE_ERROR("Not implemented caling mode " + to_string(input[i]->scaling_mode) + "."); + } + // We don't allow empty tensors. They should be filtered out before calling this function. 
+ if (input[i]->data.numel() == 0) { + NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty."); + } + CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]"); + CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]"); + all_has_data &= input[i]->has_data(); + all_has_columnwise_data &= input[i]->has_columnwise_data(); + } + NVTE_CHECK(all_has_data || all_has_columnwise_data, + "All tensors should have data or columnwise data."); + + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; + if (all_has_data) { + MultiSwizzleArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.block_range[0] = 0; + int vec_load_size = 4; + for (size_t i = 0; i < num_tensors; i++) { + //Launch kernel if argument struct is full + if (kernel_args.num_tensors == kMaxTensorsPerKernel) { + // There is no int3 and misaligned if using int4/int2. + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_swizzle_scaling_factors( + kernel_args, vec_load_size, true, stream); + // Reset the argument struct and vec_load_size + kernel_args.num_tensors = 0; + vec_load_size = 4; + } + const int m = input[i]->scale_inv.shape[0]; + const int k = input[i]->scale_inv.shape[1]; + + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + NVTE_CHECK( + m * k == std::accumulate(output[i]->scale_inv.shape.begin(), + output[i]->scale_inv.shape.end(), 1, std::multiplies()), + "Input.scale_inv size is not equal to Output.scale_inv size!"); + + int num_tiles_k = k / SF_TILE_DIM_K; + int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; + // We use the minimum vec_load_size across all tensors. + vec_load_size = std::min(vec_load_size, vec_load_size_i); + + const int pos = kernel_args.num_tensors; + kernel_args.input_list[pos] = const_cast(input[i]->scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->scale_inv.dptr; + kernel_args.m_list[pos] = m; + kernel_args.k_list[pos] = k; + kernel_args.original_m_list[pos] = input[i]->flat_first_dim(); + kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / MXFP8_BLOCK_SIZE; + kernel_args.num_tensors++; + } + // Launch the remaining tensors + // There is no int3 and misaligned if using int4/int2. + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_swizzle_scaling_factors( + kernel_args, vec_load_size, true, stream); + } + + if (all_has_columnwise_data) { + MultiSwizzleArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.block_range[0] = 0; + int vec_load_size = 4; + for (size_t i = 0; i < num_tensors; i++) { + //Launch kernel if argument struct is full + if (kernel_args.num_tensors == kMaxTensorsPerKernel) { + // There is no int3 and misaligned if using int4/int2. 
+ if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_swizzle_scaling_factors( + kernel_args, vec_load_size, false, stream); + // Reset the argument struct and vec_load_size + kernel_args.num_tensors = 0; + vec_load_size = 4; + } + const int m = input[i]->columnwise_scale_inv.shape[1]; + const int k = input[i]->columnwise_scale_inv.shape[0]; + + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(), + output[i]->columnwise_scale_inv.shape.end(), 1, + std::multiplies()), + "Input.columnwise_scale_inv size is not equal to " + "Output.columnwise_scale_inv size!"); + + int num_tiles_k = k / SF_TILE_DIM_K; + int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; + // We use the minimum vec_load_size across all tensors. + vec_load_size = std::min(vec_load_size, vec_load_size_i); + + const int pos = kernel_args.num_tensors; + kernel_args.input_list[pos] = const_cast(input[i]->columnwise_scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr; + kernel_args.m_list[pos] = m; + kernel_args.k_list[pos] = k; + kernel_args.original_m_list[pos] = input[i]->flat_last_dim(); + kernel_args.original_k_list[pos] = input[i]->flat_first_dim() / MXFP8_BLOCK_SIZE; + kernel_args.num_tensors++; + } + // Launch the remaining tensors + // There is no int3 and misaligned if using int4/int2. + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_swizzle_scaling_factors( + kernel_args, vec_load_size, false, stream); } } } // namespace transformer_engine @@ -335,3 +675,16 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud using namespace transformer_engine; swizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream); } + +void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, + const size_t num_tensors, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_swizzle_scaling_factors); + using namespace transformer_engine; + NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0."); + std::vector input_list, output_list; + for (size_t i = 0; i < num_tensors; i++) { + input_list.push_back(convertNVTETensorCheck(inputs[i])); + output_list.push_back(convertNVTETensorCheck(outputs[i])); + } + multi_tensor_swizzle_scaling_factors(input_list, output_list, stream); +} diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index a1899d5b10..ad6cf2a2ee 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -35,6 +35,7 @@ struct MultiPaddingArgs { int padded_num_rows_list[kMaxTensorsPerKernel]; // Input matrix widths int row_length_list[kMaxTensorsPerKernel]; + // Prefix sum (with leading zero) of CUDA blocks needed for each // tensor int block_range[kMaxTensorsPerKernel + 1]; // Number of tensors being processed by kernel diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 5408cf1a6b..fe7aecbc22 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -398,11 +398,8 @@ std::tuple, std::vector> bulk_allocate_mx } // Allocate full buffer - // TODO(zhongbo): use torch.empty if zero 
padding is added to the swizzle kernel auto buffer = std::make_shared( - at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - // auto buffer = std::make_shared( - // at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -441,11 +438,8 @@ std::tuple, std::vector> bulk_allocate_mx } // Allocate full buffer - // TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel auto buffer = std::make_shared( - at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - // auto buffer = std::make_shared( - // at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 99bb4e69fd..4f1ab3e561 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -326,10 +326,8 @@ std::optional> te_general_grouped_gemm( size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) { std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, te_pre_gelu_out_vector, te_workspace_vector; - std::vector wrappers; + std::vector te_A_wrappers, te_B_wrappers, wrappers; std::vector D_vectors; - // Keep the swizzled scaling factor tensors alive during the GEMMs. - std::vector> swizzled_scale_inverses_list; auto none = py::none(); @@ -396,10 +394,6 @@ std::optional> te_general_grouped_gemm( continue; } - // Optionally swizzle the scaling factors - swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_A, transa))); - swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_B, !transb))); - auto te_D = makeTransformerEngineTensor(out_tensor); auto te_bias = makeTransformerEngineTensor(bias[i]); auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); @@ -419,18 +413,25 @@ std::optional> te_general_grouped_gemm( te_bias_vector.emplace_back(te_bias.data()); te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data()); - wrappers.emplace_back(std::move(te_A)); - wrappers.emplace_back(std::move(te_B)); + te_A_wrappers.emplace_back(std::move(te_A)); + te_B_wrappers.emplace_back(std::move(te_B)); wrappers.emplace_back(std::move(te_D)); wrappers.emplace_back(std::move(te_bias)); wrappers.emplace_back(std::move(te_pre_gelu_out)); } + + // Optionally swizzle the scaling factors + // Keep the swizzled scaling factor tensors alive during the GEMMs. + auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa); + auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb); + for (size_t i = 0; i < workspace.size(); i++) { auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), std::vector{workspaceSize}, DType::kByte); te_workspace_vector.emplace_back(wsp.data()); wrappers.emplace_back(std::move(wsp)); } + // For now, we only have multi-stream cublas backend. 
NVTE_SCOPED_GIL_RELEASE({ nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index f0e0aba00d..fc5f99dcb9 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -841,13 +841,13 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), rowwise_scale_inv_shape.end()); rowwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); - rowwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts); + rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts); } if (columnwise_usage) { const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), columnwise_scale_inv_shape.end()); columnwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); - columnwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts); + columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts); } // Convert tensors to Python @@ -939,7 +939,7 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), scale_inv_shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - rowwise_scale_inv = at::zeros(scale_inv_shape_int64, opts); + rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts); tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; } } else { // rowwise_usage == false @@ -966,7 +966,7 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), scale_inv_shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - columnwise_scale_inv = at::zeros(scale_inv_shape_int64, opts); + columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts); tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; } } else { // columnwise_usage == false diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index a878345ffc..92f2d3a500 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -75,3 +75,98 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap return swizzled_scale_inv; } + +std::optional multi_tensor_swizzle_scaling_factors( + std::vector& tensors, bool rowwise) { + using namespace transformer_engine::pytorch; + + if (tensors.empty()) { + return std::nullopt; + } + + bool all_same_scaling_mode = std::all_of( + tensors.cbegin(), tensors.cend(), [&tensors](const transformer_engine::TensorWrapper& val) { + return val.scaling_mode() == tensors.front().scaling_mode(); + }); + NVTE_CHECK(all_same_scaling_mode, "Scaling mode of the input tensors must be the same."); + + if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) { + NVTE_ERROR("Invalid scaling mode for swizzle."); + } else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING) { + return std::nullopt; + } + + std::vector wrappers; + std::vector input_tensors, output_tensors; + + // Collect scale_inv shapes and calculate buffer size and offsets for scale_invs + std::vector> scale_inv_shapes; + std::vector scale_inv_dptrs; + size_t buffer_size = 0; + std::vector scale_inv_offsets; + constexpr size_t scale_elem_size = 1; + for (auto& tensor : tensors) { + NVTEBasicTensor scale_inv; + 
if (rowwise) { + scale_inv = tensor.get_rowwise_scale_inv(); + } else { + scale_inv = tensor.get_columnwise_scale_inv(); + } + auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape); + buffer_size = roundup(buffer_size, 16); // align to 16B + scale_inv_offsets.push_back(buffer_size); + buffer_size += product(scale_inv_shape) * scale_elem_size; + scale_inv_shapes.emplace_back(scale_inv_shape); + scale_inv_dptrs.push_back(scale_inv.data_ptr); + } + + // Allocate full buffer + auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); + + for (size_t i = 0; i < tensors.size(); ++i) { + auto& tensor = tensors[i]; + void* scale_inv_dptr = scale_inv_dptrs[i]; + void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]); + auto input_shape = nvte_shape_to_vector(tensor.shape()); + + // Reconstruct input only to avoid swizzling both directions if not needed. + // Use any 8 bit type, it's irrelevant. + transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); + transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + if (rowwise) { + input_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, + scale_inv_shapes[i]); + output_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, + input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, + transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + // Set the swizzled scaling factor to the original tensor. + tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, + scale_inv_shapes[i]); + } else { + input_cu.set_columnwise_data(tensor.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, + input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, + scale_inv_shapes[i]); + output_cu.set_columnwise_data(tensor.columnwise_dptr(), + transformer_engine::DType::kFloat8E4M3, input_shape); + output_cu.set_columnwise_scale_inv( + swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + // Set the swizzled scaling factor to the original tensor. + tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, + transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + } + + input_tensors.emplace_back(input_cu.data()); + output_tensors.emplace_back(output_cu.data()); + wrappers.emplace_back(std::move(input_cu)); + wrappers.emplace_back(std::move(output_cu)); + } + + // Launch kernel + nvte_multi_tensor_swizzle_scaling_factors(input_tensors.data(), output_tensors.data(), + input_tensors.size(), at::cuda::getCurrentCUDAStream()); + + return buffer; +} diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 0cfeb81f59..4b26860967 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -13,11 +13,18 @@ #include "transformer_engine/transformer_engine.h" -/* Swizzle the scaling factor of the input tensor. +/*! \brief Swizzle the scaling factor of the input tensor. * * The returned swizzled scaling factor tensor should be kept alive during the GEMM. */ std::optional swizzle_scaling_factors(transformer_engine::TensorWrapper &input, - bool trans); + bool rowwise); + +/*! \brief Swizzle the scaling factor of the input tensors. + * + * The returned swizzled scaling factor tensors should be kept alive during the GEMMs. 
+ */ +std::optional multi_tensor_swizzle_scaling_factors( + std::vector &inputs, bool rowwise); #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ From acd5adf3d83ffd91e90077d97bd05403d9df8d20 Mon Sep 17 00:00:00 2001 From: hx Date: Thu, 7 Aug 2025 05:18:03 +0800 Subject: [PATCH 5/9] [PyTorch] fix input_quantizer usage for save_original_input; fix blockwise FP8 convert_and_update_tensor (#1978) * fix input_quantizer in save_original_input bwd Signed-off-by: Hongxiao Bai * fix get shape of blockwise tensor with only compact colwise data Signed-off-by: Hongxiao Bai * fix blockwise FP8 convert_and_update_tensor Signed-off-by: Hongxiao Bai * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Hongxiao Bai Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Anton Korzh --- tests/pytorch/test_float8blockwisetensor.py | 2 +- transformer_engine/pytorch/csrc/quantizer.cpp | 129 +++++++++++++++++- transformer_engine/pytorch/module/linear.py | 11 +- 3 files changed, 129 insertions(+), 13 deletions(-) diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index 1f23be3626..39062b442b 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -219,7 +219,7 @@ def test_quantize_dequantize_compact_format( rowwise=True, columnwise=dq_columnwise, block_scaling_dim=block_scaling_dim, - all_gather_usage=True, + all_gather_usage=(block_scaling_dim == 1), ) self._test_quantize_dequantize( quantizer=quantizer, diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index fc5f99dcb9..0c75789ed9 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -671,13 +671,128 @@ std::pair Float8BlockQuantizer::convert_and_update_te const DType dtype = tensor.attr("_fp8_dtype").cast(); bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast(); - // Check the data matches quantizer usages - NVTE_CHECK(!tensor.attr("_rowwise_data").is_none() == rowwise_usage, - "Float8BlockwiseQTensor does not match quantizer usages (has_rowwise_data=", - !tensor.attr("_rowwise_data").is_none(), ", rowwise_usage=", rowwise_usage); - NVTE_CHECK(!tensor.attr("_columnwise_data").is_none() == columnwise_usage, - "Float8BlockwiseQTensor does not match quantizer usages (has_columnwise_data=", - !tensor.attr("_columnwise_data").is_none(), ", columnwise_usage=", columnwise_usage); + // Extract buffers from Python tensor + auto get_tensor = [&tensor](const char* name) -> std::optional { + auto attr_py = tensor.attr(name); + if (attr_py.is_none()) { + return std::nullopt; + } + return attr_py.cast(); + }; + auto rowwise_data = get_tensor("_rowwise_data"); + auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv"); + auto columnwise_data = get_tensor("_columnwise_data"); + auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv"); + NVTE_CHECK(rowwise_data || columnwise_data, "FP8BlockwiseTensor has no data."); + + // Tensor options and dimensions + at::TensorOptions opts; + at::TensorOptions scale_opts; + opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); + + auto get_columnwise_shape = [&columnwise_data](bool all_gather_usage) -> 
std::vector { + if (!columnwise_data) { + return std::vector(); + } + if (all_gather_usage) { + return getTensorShape(*columnwise_data); + } + std::vector shape = getTensorShape(*columnwise_data); + std::vector shape_transposed(shape.size()); + for (size_t i = 0; i + 1 < shape.size(); ++i) { + shape_transposed[i] = shape[i + 1]; + } + if (shape.size() > 0) { + shape_transposed[shape.size() - 1] = shape[0]; + } + return shape_transposed; + }; + std::vector shape; + if (rowwise_data) { + shape = getTensorShape(*rowwise_data); + if (columnwise_data) { + auto expected_shape = get_columnwise_shape(all_gather_usage); + NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape, + ") and column-wise data (shape=", expected_shape, ") do not match"); + } + } else { + shape = get_columnwise_shape(all_gather_usage); + } + std::vector torch_shape; + for (auto s : shape) { + torch_shape.emplace_back(static_cast(s)); + } + + // Coerce row-wise data + if (rowwise_usage) { + if (!rowwise_data) { + rowwise_data = at::empty(torch_shape, opts); + tensor.attr("_rowwise_data") = *rowwise_data; + } + if (!rowwise_scale_inv) { + auto scale_shape = get_scale_shape(shape, false); + size_t sinv0 = scale_shape[0]; + size_t sinv1 = scale_shape[1]; + rowwise_scale_inv = + at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); + tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; + } + } else { // rowwise_usage == false + if (rowwise_data) { + rowwise_data.reset(); + tensor.attr("_rowwise_data") = py::none(); + } + if (rowwise_scale_inv) { + rowwise_scale_inv.reset(); + tensor.attr("_rowwise_scale_inv") = py::none(); + } + } + + // Coerce column-wise data + if (columnwise_usage) { + std::vector columnwise_shape; + std::vector torch_columnwise_shape; + if (torch_shape.size() > 0) { + if (!all_gather_usage) { + torch_columnwise_shape.reserve(torch_shape.size()); + columnwise_shape.reserve(shape.size()); + torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); + columnwise_shape.push_back(shape[shape.size() - 1]); + for (size_t i = 0; i < torch_shape.size() - 1; ++i) { + torch_columnwise_shape.push_back(torch_shape[i]); + columnwise_shape.push_back(shape[i]); + } + } else { + // assert we are doing 1D scaling + NVTE_CHECK(block_scaling_dim == 1, + "Compact columnwise format is not supported for 128x128 2D block scaling."); + torch_columnwise_shape = torch_shape; + columnwise_shape = shape; + } + } + if (!columnwise_data) { + columnwise_data = at::empty(torch_columnwise_shape, opts); + tensor.attr("_columnwise_data") = *columnwise_data; + } + if (!columnwise_scale_inv) { + auto scale_shape = get_scale_shape(shape, true); + size_t sinv0 = scale_shape[0]; + size_t sinv1 = scale_shape[1]; + columnwise_scale_inv = + at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); + tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; + } + } else { // columnwise_usage == false + if (columnwise_data) { + columnwise_data.reset(); + tensor.attr("_columnwise_data") = py::none(); + } + if (columnwise_scale_inv) { + columnwise_scale_inv.reset(); + tensor.attr("_columnwise_scale_inv") = py::none(); + } + } auto ret = TensorWrapper(is_2D_scaled ? 
NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 4131c4e61a..2562d3e7d4 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -601,13 +601,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: # Quantize input tensor quantizer = ctx.input_quantizer - if ctx.backward_input_needs_gather and isinstance( - quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) - ): + if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # All-gather is not supported with FP8 column-wise data - quantizer.set_usage(rowwise=True, columnwise=False) + quantizer.set_usage( + rowwise=True, + columnwise=not ctx.backward_input_needs_gather, + ) else: - quantizer.set_usage(rowwise=True, columnwise=True) + quantizer.set_usage(rowwise=False, columnwise=True) inputmat = quantizer(inputmat) else: if isinstance(inputmat, QuantizedTensorBase): From 54f092d3e922707eed58cf11b80aea12e3f32cc0 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 6 Aug 2025 17:23:55 -0400 Subject: [PATCH 6/9] Revert "[JAX] Disable TE Norm Custom Calls" (#2035) Revert "[JAX] Disable TE Norm Custom Calls (#1993)" This reverts commit 6c970612715e2a493a2468256c05ce40a11e8556. --------- Signed-off-by: Phuong Nguyen Signed-off-by: Anton Korzh --- transformer_engine/jax/cpp_extensions/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 0d19785a0f..fcc2108cca 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -34,7 +34,7 @@ class BasePrimitive(metaclass=ABCMeta): _is_enabled = True # Default list of primitives to disable for all recipes - _default_disable_names = ["GemmPrimitive", "NormFwdPrimitive", "NormBwdPrimitive"] + _default_disable_names = ["GemmPrimitive"] @classmethod def enabled(cls): From 45d207b3506174383ad12aaaea4b7e9009b60ad6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 Aug 2025 21:42:26 +0000 Subject: [PATCH 7/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Anton Korzh --- .../pytorch/distributed/test_linear_comms.py | 192 +++++++++++------- .../include/transformer_engine/ubnext.h | 33 +-- transformer_engine/common/ubnext.cu | 2 +- .../common/util/pybind_helper.h | 1 + .../pytorch/cpp_extensions/__init__.py | 2 +- .../pytorch/cpp_extensions/symm_allocator.py | 139 ++++++++----- transformer_engine/pytorch/module/base.py | 8 +- .../pytorch/module/layernorm_linear.py | 34 +++- transformer_engine/pytorch/module/linear.py | 34 +++- 9 files changed, 286 insertions(+), 159 deletions(-) diff --git a/tests/pytorch/distributed/test_linear_comms.py b/tests/pytorch/distributed/test_linear_comms.py index 414d6c200b..8e5ddd3843 100644 --- a/tests/pytorch/distributed/test_linear_comms.py +++ b/tests/pytorch/distributed/test_linear_comms.py @@ -12,134 +12,165 @@ import uuid import math + def main(): # Parse command-line arguments - parser = argparse.ArgumentParser(description="Run a linear layer with Transformer Engine, CUDA Graphs, and Tensor Parallelism") - parser.add_argument('--in_features', type=int, default=8192, help='Input feature size') - parser.add_argument('--out_features', type=int, 
default=8192, help='Output feature size') - parser.add_argument('--batch_size', type=int, default=2048, help='Batch size') - parser.add_argument('--cuda_graph', action='store_true', help='Use CUDA Graphs (pass this flag to enable)') - parser.add_argument('--validate', action='store_true', help='Validate allreduce ubnext') - parser.add_argument('--comm_type', type=str, default="sym", help='Comm type: nccl,sym,ub') - parser.add_argument('--sym_type', type=str, default="multimem_all_reduce", help='sym type: one_shot, two_shot, multimem_all_reduce, ub_custom') - parser.add_argument('--iterations', type=int, default=1000, help='Number of iterations') - parser.add_argument('--tp_size', type=int, default=None, help='Tensor parallelism size (defaults to number of GPUs)') + parser = argparse.ArgumentParser( + description=( + "Run a linear layer with Transformer Engine, CUDA Graphs, and Tensor Parallelism" + ) + ) + parser.add_argument("--in_features", type=int, default=8192, help="Input feature size") + parser.add_argument("--out_features", type=int, default=8192, help="Output feature size") + parser.add_argument("--batch_size", type=int, default=2048, help="Batch size") + parser.add_argument( + "--cuda_graph", action="store_true", help="Use CUDA Graphs (pass this flag to enable)" + ) + parser.add_argument("--validate", action="store_true", help="Validate allreduce ubnext") + parser.add_argument("--comm_type", type=str, default="sym", help="Comm type: nccl,sym,ub") + parser.add_argument( + "--sym_type", + type=str, + default="multimem_all_reduce", + help="sym type: one_shot, two_shot, multimem_all_reduce, ub_custom", + ) + parser.add_argument("--iterations", type=int, default=1000, help="Number of iterations") + parser.add_argument( + "--tp_size", + type=int, + default=None, + help="Tensor parallelism size (defaults to number of GPUs)", + ) args = parser.parse_args() # Check CUDA availability and get device count if not torch.cuda.is_available(): raise RuntimeError("CUDA is not available. Test requires NVIDIA GPUs.") - + num_devices = torch.cuda.device_count() if num_devices == 0: raise RuntimeError("No CUDA devices found.") - + # Set tensor parallelism size - tp_size = args.tp_size if args.tp_size is not None else int(os.environ.get('WORLD_SIZE', num_devices)) - + tp_size = ( + args.tp_size if args.tp_size is not None else int(os.environ.get("WORLD_SIZE", num_devices)) + ) + # Initialize distributed environment for each GPU - myrank = int(os.environ.get('RANK', 0)) - local_rank = int(os.environ.get('LOCAL_RANK', 0)) - world_size = int(os.environ.get('WORLD_SIZE', 1)) - local_size = int(os.environ.get('LOCAL_WORLD_SIZE', str(torch.cuda.device_count()))) + myrank = int(os.environ.get("RANK", 0)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + local_size = int(os.environ.get("LOCAL_WORLD_SIZE", str(torch.cuda.device_count()))) num_nodes = world_size // local_size if num_nodes > 1: - assert ("MASTER_ADDR" in os.environ - ), "Multi-node run requires MASTER_ADDR to be set in the environment." + assert ( + "MASTER_ADDR" in os.environ + ), "Multi-node run requires MASTER_ADDR to be set in the environment." 
# Set device device = torch.device(f"cuda:{local_rank}") # Initialize torch.distributed for tensor parallelism # Only set defaults if not already set by torchrun - if 'MASTER_ADDR' not in os.environ: - os.environ['MASTER_ADDR'] = 'localhost' - if 'MASTER_PORT' not in os.environ: - os.environ['MASTER_PORT'] = '29500' + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = "localhost" + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = "29500" torch.cuda.set_device(device) torch.distributed.init_process_group( - backend='nccl', - world_size=tp_size, - rank=myrank % tp_size, - device_id=device + backend="nccl", world_size=tp_size, rank=myrank % tp_size, device_id=device ) torch.distributed.barrier(group=torch.distributed.group.WORLD) # Transformer Engine handles tensor parallelism internally when distributed is initialized # Set environment variable for tensor parallelism size - os.environ['NVTE_TP_SIZE'] = str(tp_size) - + os.environ["NVTE_TP_SIZE"] = str(tp_size) + ub_cfgs = { - "proj_fprop": {"method": "pipeline","num_splits":1,"is_reduce_scatter":True,"num_sm":32,"atomic_gemm":False,"aggregate":False, - "cga_size":4,"set_sm_margin":False,"fp8_buf":False,"use_ce":False} - } + "proj_fprop": { + "method": "pipeline", + "num_splits": 1, + "is_reduce_scatter": True, + "num_sm": 32, + "atomic_gemm": False, + "aggregate": False, + "cga_size": 4, + "set_sm_margin": False, + "fp8_buf": False, + "use_ce": False, + } + } # Initialize model with BF16 precision - + modelseq = te.Linear( - in_features=int(args.in_features/tp_size), + in_features=int(args.in_features / tp_size), out_features=args.out_features, bias=False, device=device, - params_dtype=torch.bfloat16 + params_dtype=torch.bfloat16, ) - if (args.comm_type=='sym' and os.environ.get("NVTE_USE_UB_FOR_UBNEXT")) or args.comm_type=='ub': + if ( + args.comm_type == "sym" and os.environ.get("NVTE_USE_UB_FOR_UBNEXT") + ) or args.comm_type == "ub": te.module.base.initialize_ub( - [args.batch_size,args.out_features], - tp_size, - use_fp8=False, - dtype=torch.bfloat16, - bootstrap_backend="nccl", - ub_cfgs=ub_cfgs + [args.batch_size, args.out_features], + tp_size, + use_fp8=False, + dtype=torch.bfloat16, + bootstrap_backend="nccl", + ub_cfgs=ub_cfgs, ) modelpar = None - if args.comm_type=='sym' or args.comm_type=='nccl' : + if args.comm_type == "sym" or args.comm_type == "nccl": modelpar = te.Linear( in_features=args.in_features, out_features=args.out_features, bias=False, device=device, params_dtype=torch.bfloat16, - tp_size = tp_size, + tp_size=tp_size, parallel_mode="row", tp_group=torch.distributed.group.WORLD, - symmetric_ar_type=args.sym_type if args.comm_type=='sym' else None + symmetric_ar_type=args.sym_type if args.comm_type == "sym" else None, ) - if(args.comm_type=='ub'): + if args.comm_type == "ub": modelpar = te.Linear( in_features=args.in_features, out_features=args.out_features, bias=False, device=device, params_dtype=torch.bfloat16, - tp_size = tp_size, + tp_size=tp_size, parallel_mode="row", tp_group=torch.distributed.group.WORLD, sequence_parallel=True, ub_overlap_rs=True, - ub_name="proj" + ub_name="proj", ) # Create CUDA stream stream = torch.cuda.Stream() # Check for environment variable to override pool size - + allocator = None if args.comm_type == "sym" and args.validate: - pool_size = int(os.environ.get('NVTE_UB_SYMM_POOL_SIZE', 64)) * 1024 * 1024 - allocator = te.cpp_extensions.symm_allocator.SymmAllocator(pool_size, torch.device(device),torch.distributed.group.WORLD) + pool_size = 
int(os.environ.get("NVTE_UB_SYMM_POOL_SIZE", 64)) * 1024 * 1024 + allocator = te.cpp_extensions.symm_allocator.SymmAllocator( + pool_size, torch.device(device), torch.distributed.group.WORLD + ) # Run tensor comparison tests only for symmetric communication if args.comm_type == "sym" and args.validate: - + # Test different tensor sizes from 64 to 1024*1024 elements all_max_deltas = [] all_num_different = [] all_total_elements = [] all_sizes = [] - + size = 64 while size <= 1024 * 1024: # Allocate tensors @@ -154,48 +185,60 @@ def main(): abs_diff = torch.abs(tlamport - tmain) max_delta = torch.max(abs_diff).item() num_different = torch.sum(tlamport != tmain).item() - + # Store statistics all_max_deltas.append(max_delta) all_num_different.append(num_different) all_total_elements.append(tlamport.numel()) all_sizes.append(size) - + # Free tensor (memory returned to pool) del t, tlamport, tmain, abs_diff - + # Double the size for next iteration size *= 2 - + # Print summary statistics if myrank == 0: print("\n=== Tensor Comparison Summary ===") total_elements_tested = sum(all_total_elements) total_different_elements = sum(all_num_different) overall_max_delta = max(all_max_deltas) - - print(f"Tested sizes: {len(all_sizes)} different tensor sizes from {all_sizes[0]} to {all_sizes[-1]} elements") + + print( + f"Tested sizes: {len(all_sizes)} different tensor sizes from {all_sizes[0]} to" + f" {all_sizes[-1]} elements" + ) print(f"Total elements tested: {total_elements_tested}") print(f"Total different elements: {total_different_elements}") - print(f"Overall error rate: {(total_different_elements / total_elements_tested) * 100:.6f}%") + print( + "Overall error rate:" + f" {(total_different_elements / total_elements_tested) * 100:.6f}%" + ) print(f"Maximum delta across all tests: {overall_max_delta}") - + if total_different_elements > 0 or overall_max_delta > 0: print("\nPer-size breakdown:") for i, size in enumerate(all_sizes): error_rate = (all_num_different[i] / all_total_elements[i]) * 100 - print(f" Size {size:7d}: {all_num_different[i]:6d}/{all_total_elements[i]:7d} different ({error_rate:6.3f}%), max_delta: {all_max_deltas[i]:.6f}") + print( + f" Size {size:7d}:" + f" {all_num_different[i]:6d}/{all_total_elements[i]:7d} different" + f" ({error_rate:6.3f}%), max_delta: {all_max_deltas[i]:.6f}" + ) print("================================\n") - torch.distributed.barrier(group=torch.distributed.group.WORLD) torch.cuda.synchronize() - + for logbatch in range(int(math.log2(args.batch_size)) + 1): batch = 2**logbatch - if args.comm_type=='ub' and batch < tp_size: batch = tp_size + if args.comm_type == "ub" and batch < tp_size: + batch = tp_size # Create input tensor - inp = torch.randn(batch, int(args.in_features/tp_size), device=device, dtype=torch.bfloat16) + inp = torch.randn( + batch, int(args.in_features / tp_size), device=device, dtype=torch.bfloat16 + ) # Warm-up run modelseq(inp) modelpar(inp) @@ -204,11 +247,11 @@ def main(): with torch.cuda.stream(stream): # Create CUDA Graph gseq = torch.cuda.CUDAGraph() - gpar = torch.cuda.CUDAGraph() + gpar = torch.cuda.CUDAGraph() with torch.cuda.graph(gseq): output = modelseq(inp) with torch.cuda.graph(gpar): - output = modelpar(inp) + output = modelpar(inp) # Warm-up the graph for _ in range(5): gseq.replay() @@ -231,7 +274,7 @@ def main(): torch.cuda.synchronize() end_time = time.time() seq_elapsed = end_time - start_time - + torch.distributed.barrier(group=torch.distributed.group.WORLD) 
torch.distributed.barrier(group=torch.distributed.group.WORLD) torch.cuda.synchronize() @@ -248,10 +291,12 @@ def main(): torch.cuda.synchronize() end_time = time.time() par_elapsed = end_time - start_time - nccl_elapsed = (par_elapsed-seq_elapsed) + nccl_elapsed = par_elapsed - seq_elapsed # Calculate and print elapsed time (only on rank 0) if myrank == 0: - print(f"Batch{batch},{(seq_elapsed/ args.iterations) * 1e6:.4f}us,{(par_elapsed/ args.iterations) * 1e6:.4f} us,{(nccl_elapsed/ args.iterations) * 1e6:.4f}") + print( + f"Batch{batch},{(seq_elapsed/ args.iterations) * 1e6:.4f}us,{(par_elapsed/ args.iterations) * 1e6:.4f} us,{(nccl_elapsed/ args.iterations) * 1e6:.4f}" + ) if args.cuda_graph: # needed or NCCL would hang del gseq, gpar @@ -259,7 +304,8 @@ def main(): # Cleanup torch.distributed.destroy_process_group() + if __name__ == "__main__": # Generate a unique run ID for distributed initialization - os.environ['RUN_ID'] = str(uuid.uuid4()) + os.environ["RUN_ID"] = str(uuid.uuid4()) main() diff --git a/transformer_engine/common/include/transformer_engine/ubnext.h b/transformer_engine/common/include/transformer_engine/ubnext.h index d0d2b8c2ac..116c602e9f 100644 --- a/transformer_engine/common/include/transformer_engine/ubnext.h +++ b/transformer_engine/common/include/transformer_engine/ubnext.h @@ -4,25 +4,28 @@ * See LICENSE for license information. ************************************************************************/ - #ifndef TRANSFORMER_ENGINE_UBNEXT_H_ - #define TRANSFORMER_ENGINE_UBNEXT_H_ +#ifndef TRANSFORMER_ENGINE_UBNEXT_H_ +#define TRANSFORMER_ENGINE_UBNEXT_H_ - #include "transformer_engine.h" +#include "transformer_engine.h" - namespace transformer_engine { +namespace transformer_engine { - #ifdef __cplusplus - extern "C" { - #endif +#ifdef __cplusplus +extern "C" { +#endif - void allreduce_2shot_mc(int ranks, int myrank, void* uc0ptr,void* mc0ptr, void* mcptr_in,void* mcptr_out, size_t bytes,cudaStream_t stream); - void allreduce_2shot_mc_lamport(int ranks, int myrank,void* uc0ptr,void* mc0ptr,void* ucptr_out, void* mcptr_in, - void* mcptr_out,void* clear_ptr,size_t bytes,bool poisoned,cudaStream_t stream); - void allreduce_2shot_uc(int ranks, int myrank,void* uc0ptr,void* ucptr_in, void* ucptr_out,size_t bytes,cudaStream_t stream); +void allreduce_2shot_mc(int ranks, int myrank, void* uc0ptr, void* mc0ptr, void* mcptr_in, + void* mcptr_out, size_t bytes, cudaStream_t stream); +void allreduce_2shot_mc_lamport(int ranks, int myrank, void* uc0ptr, void* mc0ptr, void* ucptr_out, + void* mcptr_in, void* mcptr_out, void* clear_ptr, size_t bytes, + bool poisoned, cudaStream_t stream); +void allreduce_2shot_uc(int ranks, int myrank, void* uc0ptr, void* ucptr_in, void* ucptr_out, + size_t bytes, cudaStream_t stream); - #ifdef __cplusplus - } - #endif +#ifdef __cplusplus } +#endif +} // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/ubnext.cu b/transformer_engine/common/ubnext.cu index 15b80f51e4..3be15d8dab 100644 --- a/transformer_engine/common/ubnext.cu +++ b/transformer_engine/common/ubnext.cu @@ -395,4 +395,4 @@ extern "C" void allreduce_2shot_mc_lamport(int ranks, int myrank, void *uc0ptr, CUDACHECK( cudaLaunchKernelExC(&cfg, (void *)(userbuffers_fp16_sum_inplace_gpu_mc_lamport), kernelArgs)); } -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 
efcca73222..cb2a7ab6c3 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -12,6 +12,7 @@ #include #include #include + #include "cuda_runtime.h" #define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ diff --git a/transformer_engine/pytorch/cpp_extensions/__init__.py b/transformer_engine/pytorch/cpp_extensions/__init__.py index 075eea0158..07d150b0f0 100644 --- a/transformer_engine/pytorch/cpp_extensions/__init__.py +++ b/transformer_engine/pytorch/cpp_extensions/__init__.py @@ -7,4 +7,4 @@ from .fused_attn import * from .gemm import * -from .symm_allocator import * \ No newline at end of file +from .symm_allocator import * diff --git a/transformer_engine/pytorch/cpp_extensions/symm_allocator.py b/transformer_engine/pytorch/cpp_extensions/symm_allocator.py index 8c10b9218a..72f3bf47b5 100644 --- a/transformer_engine/pytorch/cpp_extensions/symm_allocator.py +++ b/transformer_engine/pytorch/cpp_extensions/symm_allocator.py @@ -7,6 +7,7 @@ import torch.distributed._symmetric_memory as symm_mem from ctypes import pythonapi, c_void_p, py_object + def to_capsule(ptr): # Set the return type to py_object to get a Python object (PyCapsule) pythonapi.PyCapsule_New.restype = py_object @@ -15,26 +16,37 @@ def to_capsule(ptr): capsule = pythonapi.PyCapsule_New(ptr, None, None) return capsule + class SymmTensor(torch.Tensor): """Custom tensor subclass that uses custom memory""" + @staticmethod - def __new__(cls, pool: torch.Tensor, offset: int, shape: torch.Size, dtype: torch.dtype, allocator: 'SymmAllocator'): + def __new__( + cls, + pool: torch.Tensor, + offset: int, + shape: torch.Size, + dtype: torch.dtype, + allocator: "SymmAllocator", + ): # Calculate number of elements and bytes num_elements = torch.Size(shape).numel() element_size = torch.tensor(0, dtype=dtype).element_size() nbytes = element_size * num_elements - + # Validate pool assert pool.dtype == torch.uint8, f"Expected uint8 pool, got {pool.dtype}" - assert pool.numel() >= offset + nbytes, f"Pool too small: {pool.numel()} bytes, need {offset + nbytes}" - + assert ( + pool.numel() >= offset + nbytes + ), f"Pool too small: {pool.numel()} bytes, need {offset + nbytes}" + # Slice the pool to get the required bytes - byte_slice = pool[offset:offset + nbytes] - + byte_slice = pool[offset : offset + nbytes] + # Reinterpret the uint8 bytes as the target dtype tensor = byte_slice.view(dtype=dtype) tensor = tensor.view(*shape) - + # Initialize as a subclass of torch.Tensor self = torch.Tensor._make_subclass(cls, tensor) if not isinstance(allocator, SymmAllocator): @@ -44,50 +56,53 @@ def __new__(cls, pool: torch.Tensor, offset: int, shape: torch.Size, dtype: torc self._offset = offset self._size = nbytes return self - + def __del__(self): """Custom deallocator to return memory to the pool.""" - if hasattr(self, '_allocator') and hasattr(self, '_ptr'): + if hasattr(self, "_allocator") and hasattr(self, "_ptr"): self._allocator.free(self._ptr) + class SymmAllocator: def __init__(self, size_bytes: int, device: torch.device, dist_group: torch.distributed.group): """Initialize the allocator with a preallocated memory pool.""" # Preallocate the memory pool using torch.empty - self.reg0_size = 1024 # NVL72*8 plus up to 112 flags + self.reg0_size = 1024 # NVL72*8 plus up to 112 flags self.device = device self.world_size = torch.distributed.get_world_size(dist_group) self.myrank = torch.distributed.get_rank(dist_group) self.dist_group = dist_group from ..module.base import get_ub + if 
os.environ.get("NVTE_USE_UB_FOR_UBNEXT"): self.ub_obj = get_ub("ubnext") self.internal_pool = self.ub_obj.get_buffer(False).reshape(-1) self.mc0_ptr = self.ub_obj.init_ubnext() self.pool_size = self.internal_pool.numel() else: - alignment = 2 * 1024 * 1024 # memory is allocated in 2MB pages anyways + alignment = 2 * 1024 * 1024 # memory is allocated in 2MB pages anyways self.pool_size = int((size_bytes + alignment - 1) / alignment) * alignment self.internal_pool = symm_mem.empty(self.pool_size, dtype=torch.uint8, device=device) self.hdl0 = symm_mem.rendezvous(self.internal_pool, dist_group) self.mc0_ptr = self.hdl0.multicast_ptr self.internal_pool.fill_(0) - self.internal_pool.view(torch.int64)[:self.world_size].copy_(torch.tensor(self.hdl0.buffer_ptrs).view(torch.int64)) - #self.hdl0.barrier(channel=0) + self.internal_pool.view(torch.int64)[: self.world_size].copy_( + torch.tensor(self.hdl0.buffer_ptrs).view(torch.int64) + ) + # self.hdl0.barrier(channel=0) # Synchronize all processes before proceeding torch.distributed.barrier(group=dist_group) - + # Track the raw pointer to the pool self.pool_ptr = self.internal_pool.data_ptr() # Track allocated segments: (offset, size) self.allocated: List[Tuple[int, int]] = [] # Track free segments: (offset, size) - self.freelist: List[Tuple[int, int]] = [(self.reg0_size, self.pool_size-self.reg0_size)] + self.freelist: List[Tuple[int, int]] = [(self.reg0_size, self.pool_size - self.reg0_size)] self.nextpoisoned = None self.tensors = weakref.WeakSet() self.lock = Lock() - def allocate(self, nbytes: int) -> Tuple[Optional[int], Optional[torch.Tensor]]: """Allocate nbytes from the pool, returning a pointer and pool reference.""" @@ -99,12 +114,13 @@ def allocate(self, nbytes: int) -> Tuple[Optional[int], Optional[torch.Tensor]]: if size > nbytes: self.freelist.append((offset + nbytes, size - nbytes)) return self.pool_ptr + offset, self.internal_pool - return None,None + return None, None - # No suitable free segment found - raise MemoryError(f"Preallocated pool exhausted: requested {nbytes} bytes, " - f"available segments: {self.freelist}") + raise MemoryError( + f"Preallocated pool exhausted: requested {nbytes} bytes, " + f"available segments: {self.freelist}" + ) def free(self, ptr: int): """Free the memory at ptr, returning it to the pool.""" @@ -119,7 +135,7 @@ def free(self, ptr: int): return # Ignore invalid pointers silently pass - + raise ValueError(f"Invalid pointer {ptr} not found in allocated segments") def _merge_free_segments(self): @@ -139,24 +155,27 @@ def _merge_free_segments(self): merged.append((current_offset, current_size)) self.freelist = merged - def create_tensor(self, shape: Tuple[int, ...], dtype: torch.dtype = torch.float32) -> Optional[torch.Tensor]: + def create_tensor( + self, shape: Tuple[int, ...], dtype: torch.dtype = torch.float32 + ) -> Optional[torch.Tensor]: """Create a PooledTensor using memory from the pool.""" nbytes = torch.tensor(0, dtype=dtype).element_size() * torch.Size(shape).numel() ptr, pool = self.allocate(nbytes) - if ptr is None: return None + if ptr is None: + return None offset = ptr - self.pool_ptr tensor = SymmTensor(pool, offset, torch.Size(shape), dtype, self) self.tensors.add(tensor) return tensor - + def allreduce_uc(self, tensor_in: torch.Tensor) -> torch.Tensor: """Performs in-place allreduce on the given SymmTensor using best algo""" assert tensor_in.device == self.device, "Tensor device mismatch with allocator device" - #tensor_out = self.create_tensor(tensor_in.shape, tensor_in.dtype) + # 
tensor_out = self.create_tensor(tensor_in.shape, tensor_in.dtype) ucptr_in = tensor_in.data_ptr() - #mcptr_out = tensor_out.data_ptr() + # mcptr_out = tensor_out.data_ptr() nbytes = tensor_in.numel() * tensor_in.element_size() # Import your pybind module if not imported @@ -167,19 +186,19 @@ def allreduce_uc(self, tensor_in: torch.Tensor) -> torch.Tensor: self.myrank, to_capsule(self.internal_pool.data_ptr()), to_capsule(ucptr_in), - to_capsule(ucptr_in),#out - nbytes + to_capsule(ucptr_in), # out + nbytes, ) return tensor_in - + def allreduce(self, tensor_in: torch.Tensor) -> torch.Tensor: """Performs in-place allreduce on the given SymmTensor using best algo""" assert tensor_in.device == self.device, "Tensor device mismatch with allocator device" - #tensor_out = self.create_tensor(tensor_in.shape, tensor_in.dtype) + # tensor_out = self.create_tensor(tensor_in.shape, tensor_in.dtype) mcptr_in = self.mc0_ptr + (tensor_in.data_ptr() - self.internal_pool.data_ptr()) - #mcptr_out = self.hdl.multicast_ptr + (tensor_out.data_ptr() - self.internal_pool.data_ptr()) + # mcptr_out = self.hdl.multicast_ptr + (tensor_out.data_ptr() - self.internal_pool.data_ptr()) nbytes = tensor_in.numel() * tensor_in.element_size() # Import your pybind module if not imported @@ -191,11 +210,11 @@ def allreduce(self, tensor_in: torch.Tensor) -> torch.Tensor: to_capsule(self.internal_pool.data_ptr()), to_capsule(self.mc0_ptr), to_capsule(mcptr_in), - to_capsule(mcptr_in),#out - nbytes + to_capsule(mcptr_in), # out + nbytes, ) return tensor_in - + def allreduce_lamport(self, tensor_in: torch.Tensor) -> torch.Tensor: """ Performs allreduce using 2-shot multicast Lamport variant: @@ -205,31 +224,33 @@ def allreduce_lamport(self, tensor_in: torch.Tensor) -> torch.Tensor: - Returns `tensor_out`. 
""" assert tensor_in.device == self.device, "Tensor device mismatch with allocator device" - if self.mc0_ptr is None or self.mc0_ptr == 0: return self.allreduce_uc(tensor_in) + if self.mc0_ptr is None or self.mc0_ptr == 0: + return self.allreduce_uc(tensor_in) from transformer_engine_torch import allreduce_2shot_mc_lamport # Allocate output tensor of same shape/dtype - tensor_out = self.nextpoisoned + tensor_out = self.nextpoisoned poisonedout = True - if self.nextpoisoned is None or self.nextpoisoned.shape!=tensor_in.shape: + if self.nextpoisoned is None or self.nextpoisoned.shape != tensor_in.shape: if self.nextpoisoned is not None: del self.nextpoisoned self.nextpoisoned = None tensor_out = self.create_tensor(tensor_in.shape, tensor_in.dtype) poisonedout = False - if tensor_out is None: return self.allreduce(tensor_in) + if tensor_out is None: + return self.allreduce(tensor_in) - # alllcate potential output for next allreduce (speculative) and poison it now + # alllcate potential output for next allreduce (speculative) and poison it now self.nextpoisoned = self.create_tensor(tensor_in.shape, tensor_in.dtype) - + # Calculate mcptr_in and mcptr_out with offset relative to internal_pool offset = tensor_in.data_ptr() - self.internal_pool.data_ptr() mcptr_in = self.mc0_ptr + offset mcptr_out = self.mc0_ptr + (tensor_out.data_ptr() - self.internal_pool.data_ptr()) # Use clear_ptr to clear output memory before reduction; here we use tensor_out - #clear_ptr = self.nextpoisoned.data_ptr() if self.nextpoisoned is not None else 0 + # clear_ptr = self.nextpoisoned.data_ptr() if self.nextpoisoned is not None else 0 nbytes = tensor_in.numel() * tensor_in.element_size() @@ -244,21 +265,27 @@ def allreduce_lamport(self, tensor_in: torch.Tensor) -> torch.Tensor: to_capsule(mcptr_out), to_capsule(self.nextpoisoned.data_ptr()) if self.nextpoisoned is not None else None, nbytes, - poisonedout + poisonedout, ) return tensor_out - -_allocator_map: Dict[torch.distributed.group, Tuple[int, 'SymmAllocator']] = {} -def ubsymm_request_allocator(dist_group: torch.distributed.group, shape:Optional[Tuple[int, ...]] = None, dtype: torch.dtype = torch.bfloat16) -> None: + +_allocator_map: Dict[torch.distributed.group, Tuple[int, "SymmAllocator"]] = {} + + +def ubsymm_request_allocator( + dist_group: torch.distributed.group, + shape: Optional[Tuple[int, ...]] = None, + dtype: torch.dtype = torch.bfloat16, +) -> None: if shape is not None: num_elements = torch.Size(shape).numel() element_size = torch.tensor(0, dtype=dtype).element_size() tensor_size = num_elements * element_size else: tensor_size = 0 - + if dist_group not in _allocator_map: if os.environ.get("NVTE_USE_UB_FOR_UBNEXT"): assert not _allocator_map, "Current UBNEXT-UB bypass supports only one process group." 
@@ -269,21 +296,27 @@ def ubsymm_request_allocator(dist_group: torch.distributed.group, shape:Optional max_size = max(old_size, tensor_size) _allocator_map[dist_group] = (max_size, None) -def ubsymm_get_sym_tensor(shape: Tuple[int, ...], dtype: torch.dtype, dist_group: torch.distributed.group) -> torch.Tensor: + +def ubsymm_get_sym_tensor( + shape: Tuple[int, ...], dtype: torch.dtype, dist_group: torch.distributed.group +) -> torch.Tensor: if dtype != torch.bfloat16: - return None # Unsupported dtype, do fallback to nccl - if dist_group not in _allocator_map: return None # No allocator requested earlier, do fallback to nccl + return None # Unsupported dtype, do fallback to nccl + if dist_group not in _allocator_map: + return None # No allocator requested earlier, do fallback to nccl (max_size, allocator) = _allocator_map[dist_group] if allocator is None: - new_max_size = int(os.environ.get('NVTE_UB_SYMM_POOL_SIZE', ((6*max_size+1048575)/1024/1024) ) ) + new_max_size = int( + os.environ.get("NVTE_UB_SYMM_POOL_SIZE", ((6 * max_size + 1048575) / 1024 / 1024)) + ) allocator = SymmAllocator( new_max_size * 1024 * 1024, - torch.device(f'cuda:{torch.cuda.current_device()}'), - dist_group + torch.device(f"cuda:{torch.cuda.current_device()}"), + dist_group, ) _allocator_map[dist_group] = (new_max_size, allocator) return allocator.create_tensor(shape, dtype) + def ubsymm_allreduce(tensor_in: SymmTensor) -> SymmTensor: return tensor_in._allocator.allreduce_lamport(tensor_in) - \ No newline at end of file diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 976afac5e0..61bf5ed9fa 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -264,7 +264,7 @@ def initialize_ub( # Add "ubnext" to bulk methods if environment variable is set if os.environ.get("NVTE_USE_UB_FOR_UBNEXT"): methods["bulk"].append("ubnext") - + # AG-RS overlap pairs of layers forming a tensor-parallel block ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"} rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()} @@ -370,7 +370,11 @@ def add_ub( ) else: ub_obj = tex.CommOverlap( - shape if name != "ubnext" else (int(os.environ.get('NVTE_UB_SYMM_POOL_SIZE', 64)), 1024*1024), #Communication buffer shape + ( + shape + if name != "ubnext" + else (int(os.environ.get("NVTE_UB_SYMM_POOL_SIZE", 64)), 1024 * 1024) + ), # Communication buffer shape buffer_dtype if name != "ubnext" else torch.uint8, # Communication buffer data type helper, # Helper for torch.distributed callbacks during bootstrapping tp_size, # Tensor-parallel group size (may be different than local_size) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 14f5bf1c3e..9356c25bdc 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -76,7 +76,7 @@ general_gemm, ubsymm_request_allocator, ubsymm_get_sym_tensor, - ubsymm_allreduce + ubsymm_allreduce, ) __all__ = ["LayerNormLinear"] @@ -330,8 +330,15 @@ def forward( out_shape[-1] = out_features reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device) symm_out = None - if symmetric_ar_type == 'ub_custom': - symm_out = ubsymm_get_sym_tensor( (list(inp.shape)[0], out_features,), activation_dtype, tp_group) + if symmetric_ar_type == "ub_custom": + symm_out = ubsymm_get_sym_tensor( + ( + list(inp.shape)[0], + out_features, + ), + 
activation_dtype, + tp_group, + ) # ------------------------------------------------------ # Forward GEMM # Note: y = x * w^T @@ -377,8 +384,14 @@ def forward( if symm_out is not None: out = ubsymm_allreduce(symm_out) else: - fallback_symmetric = "multimem_all_reduce" if symmetric_ar_type == "ub_custom" else symmetric_ar_type - out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=fallback_symmetric) + fallback_symmetric = ( + "multimem_all_reduce" + if symmetric_ar_type == "ub_custom" + else symmetric_ar_type + ) + out, _ = symmetric_all_reduce( + out, tp_group, all_reduce_type=fallback_symmetric + ) else: out, _ = allreduce(out, tp_group) nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") @@ -1247,8 +1260,15 @@ def __init__( 7, 0, ), "Torch version must be at least 2.7 to use symmetric memory" - if self.symmetric_ar_type == 'ub_custom': - ubsymm_request_allocator(self.tp_group, (int(os.environ.get('NVTE_UB_MAXBATCH',64)), self.out_features,), params_dtype) + if self.symmetric_ar_type == "ub_custom": + ubsymm_request_allocator( + self.tp_group, + ( + int(os.environ.get("NVTE_UB_MAXBATCH", 64)), + self.out_features, + ), + params_dtype, + ) self.eps = eps layer_norm_weight = torch.nn.Parameter( torch.empty(self.in_features, device=device, dtype=params_dtype) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 2562d3e7d4..ca833141ba 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -55,7 +55,7 @@ general_gemm, ubsymm_request_allocator, ubsymm_get_sym_tensor, - ubsymm_allreduce + ubsymm_allreduce, ) from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo @@ -295,8 +295,15 @@ def forward( reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device) symm_out = None - if symmetric_ar_type == 'ub_custom': - symm_out = ubsymm_get_sym_tensor( (list(inp.shape)[0], out_features,), activation_dtype, tp_group) + if symmetric_ar_type == "ub_custom": + symm_out = ubsymm_get_sym_tensor( + ( + list(inp.shape)[0], + out_features, + ), + activation_dtype, + tp_group, + ) # ------------------------------------------------------ # Forward GEMM # Note: y = x * w^T @@ -337,8 +344,14 @@ def forward( if symm_out is not None: out = ubsymm_allreduce(symm_out) else: - fallback_symmetric = "multimem_all_reduce" if symmetric_ar_type == "ub_custom" else symmetric_ar_type - out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=fallback_symmetric) + fallback_symmetric = ( + "multimem_all_reduce" + if symmetric_ar_type == "ub_custom" + else symmetric_ar_type + ) + out, _ = symmetric_all_reduce( + out, tp_group, all_reduce_type=fallback_symmetric + ) else: out, _ = allreduce(out, tp_group) nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") @@ -1164,8 +1177,15 @@ def __init__( 7, 0, ), "Torch version must be at least 2.7 to use symmetric memory" - if self.symmetric_ar_type == 'ub_custom': - ubsymm_request_allocator(self.tp_group, (int(os.environ.get('NVTE_UB_MAXBATCH',64)), self.out_features,), params_dtype) + if self.symmetric_ar_type == "ub_custom": + ubsymm_request_allocator( + self.tp_group, + ( + int(os.environ.get("NVTE_UB_MAXBATCH", 64)), + self.out_features, + ), + params_dtype, + ) # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() From 214b89a5e700250a0921b650afd98ae54705dab4 Mon Sep 17 00:00:00 2001 From: Anton Korzh Date: Thu, 7 Aug 2025 09:51:45 -0700 Subject: [PATCH 8/9] fix output 
shape Signed-off-by: Anton Korzh --- .../pytorch/cpp_extensions/symm_allocator.py | 2 +- transformer_engine/pytorch/module/layernorm_linear.py | 7 +++---- transformer_engine/pytorch/module/linear.py | 7 +++---- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/symm_allocator.py b/transformer_engine/pytorch/cpp_extensions/symm_allocator.py index 72f3bf47b5..0eae2818c4 100644 --- a/transformer_engine/pytorch/cpp_extensions/symm_allocator.py +++ b/transformer_engine/pytorch/cpp_extensions/symm_allocator.py @@ -156,7 +156,7 @@ def _merge_free_segments(self): self.freelist = merged def create_tensor( - self, shape: Tuple[int, ...], dtype: torch.dtype = torch.float32 + self, shape: torch.Size, dtype: torch.dtype = torch.float32 ) -> Optional[torch.Tensor]: """Create a PooledTensor using memory from the pool.""" nbytes = torch.tensor(0, dtype=dtype).element_size() * torch.Size(shape).numel() diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 9356c25bdc..913d823745 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -331,11 +331,10 @@ def forward( reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device) symm_out = None if symmetric_ar_type == "ub_custom": + out_shape_list = list(tuple(inp.shape)) + out_shape_list[-1] = out_features symm_out = ubsymm_get_sym_tensor( - ( - list(inp.shape)[0], - out_features, - ), + torch.Size(out_shape_list), activation_dtype, tp_group, ) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index ca833141ba..600e436ac7 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -296,11 +296,10 @@ def forward( symm_out = None if symmetric_ar_type == "ub_custom": + out_shape_list = list(tuple(inp.shape)) + out_shape_list[-1] = out_features symm_out = ubsymm_get_sym_tensor( - ( - list(inp.shape)[0], - out_features, - ), + torch.Size(out_shape_list), activation_dtype, tp_group, ) From 2b8b8c2bde59e4c74eb0fb5b76a7d4b799a620f8 Mon Sep 17 00:00:00 2001 From: Anton Korzh Date: Thu, 7 Aug 2025 14:18:49 -0700 Subject: [PATCH 9/9] tp1 fix Signed-off-by: Anton Korzh --- transformer_engine/pytorch/module/layernorm_linear.py | 4 ++-- transformer_engine/pytorch/module/linear.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 913d823745..ea37cb332e 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -330,7 +330,7 @@ def forward( out_shape[-1] = out_features reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device) symm_out = None - if symmetric_ar_type == "ub_custom": + if symmetric_ar_type == "ub_custom" and parallel_mode == "row" and tp_size > 1: out_shape_list = list(tuple(inp.shape)) out_shape_list[-1] = out_features symm_out = ubsymm_get_sym_tensor( @@ -1259,7 +1259,7 @@ def __init__( 7, 0, ), "Torch version must be at least 2.7 to use symmetric memory" - if self.symmetric_ar_type == "ub_custom": + if self.symmetric_ar_type == "ub_custom" and parallel_mode == "row" and tp_size > 1: ubsymm_request_allocator( self.tp_group, ( diff --git a/transformer_engine/pytorch/module/linear.py 
b/transformer_engine/pytorch/module/linear.py index 600e436ac7..18fe6347f5 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -295,7 +295,7 @@ def forward( reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device) symm_out = None - if symmetric_ar_type == "ub_custom": + if symmetric_ar_type == "ub_custom" and parallel_mode == "row" and tp_size > 1: out_shape_list = list(tuple(inp.shape)) out_shape_list[-1] = out_features symm_out = ubsymm_get_sym_tensor( @@ -1176,7 +1176,7 @@ def __init__( 7, 0, ), "Torch version must be at least 2.7 to use symmetric memory" - if self.symmetric_ar_type == "ub_custom": + if self.symmetric_ar_type == "ub_custom" and parallel_mode == "row" and tp_size > 1: ubsymm_request_allocator( self.tp_group, (
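
For reference, a hedged usage sketch of the ubsymm_* helpers introduced by this series: request an allocator for a process group, obtain a tensor backed by the symmetric pool, and allreduce it in place. It assumes torch.distributed is already initialized with the NCCL backend and that each rank has set its CUDA device; the shape, the 64 MiB pool size, and the use of dist.group.WORLD are illustrative, not library defaults beyond what the patch shows.

import os
import torch
import torch.distributed as dist
from transformer_engine.pytorch.cpp_extensions.symm_allocator import (
    ubsymm_request_allocator,
    ubsymm_get_sym_tensor,
    ubsymm_allreduce,
)

os.environ.setdefault("NVTE_UB_SYMM_POOL_SIZE", "64")  # pool size in MiB

group = dist.group.WORLD
shape = (64, 8192)

# Register the expected tensor size up front; the allocator and its pool are
# created lazily on the first ubsymm_get_sym_tensor call for this group.
ubsymm_request_allocator(group, shape, torch.bfloat16)

symm = ubsymm_get_sym_tensor(shape, torch.bfloat16, group)
if symm is None:
    # Non-bf16 dtype, no allocator requested, or pool exhausted: fall back to NCCL.
    out = torch.zeros(shape, dtype=torch.bfloat16, device="cuda")
    dist.all_reduce(out, group=group)
else:
    symm.copy_(torch.randn(shape, dtype=torch.bfloat16, device=symm.device))
    out = ubsymm_allreduce(symm)  # Lamport/multicast allreduce over the symmetric pool
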
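At the module level, the Linear integration is driven by symmetric_ar_type="ub_custom", and after the tp1 fix it only engages for a row-parallel layer with tp_size > 1 (and torch >= 2.7 for symmetric memory); otherwise the code falls back to the multimem_all_reduce symmetric path or a plain allreduce. The following construction sketch is assembled from those guards and is not an official recipe; the sizes, the NVTE_UB_MAXBATCH value, and the assumption that each rank feeds its shard of the input features are illustrative.

import os
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

os.environ.setdefault("NVTE_UB_MAXBATCH", "64")  # batch dim used to size the pool request

tp_size = dist.get_world_size()
layer = te.Linear(
    in_features=8192,              # global input features; row parallelism shards them
    out_features=8192,
    bias=False,
    params_dtype=torch.bfloat16,
    device="cuda",
    tp_size=tp_size,
    parallel_mode="row",
    tp_group=dist.group.WORLD,
    symmetric_ar_type="ub_custom",
)

# Each rank is assumed to hold its local shard of the feature dimension.
x = torch.randn(64, 8192 // tp_size, dtype=torch.bfloat16, device="cuda")
y = layer(x)  # GEMM output lands in a symmetric-pool tensor and is reduced via
              # ubsymm_allreduce; if no symmetric buffer is available, the layer
              # falls back to the multimem_all_reduce symmetric allreduce.
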