diff --git a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp
index 5cdab7ba7e0..388819b957a 100644
--- a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp
+++ b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp
@@ -361,7 +361,9 @@ void initBindings(nb::module_& m)
         nb::call_guard<nb::gil_scoped_release>());
     nb::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(nb::init<size_t, uint32_t, uint32_t, at::Device, bool>(), nb::call_guard<nb::gil_scoped_release>())
+        .def(nb::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), nb::arg("buf_size"),
+            nb::arg("group_size"), nb::arg("group_rank"), nb::arg("split_color"), nb::arg("device_idx"),
+            nb::arg("mn_nvlink"), nb::call_guard<nb::gil_scoped_release>())
         .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer,
             nb::call_guard<nb::gil_scoped_release>())
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer,
diff --git a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp
index 574249b6a23..469aafe6476 100644
--- a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp
@@ -455,7 +455,9 @@ void initBindings(pybind11::module_& m)
         py::call_guard<py::gil_scoped_release>());
     py::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(py::init<size_t, uint32_t, uint32_t, at::Device, bool>(), py::call_guard<py::gil_scoped_release>())
+        .def(py::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), py::arg("buf_size"),
+            py::arg("group_size"), py::arg("group_rank"), py::arg("split_color"), py::arg("device_idx"),
+            py::arg("mn_nvlink"), py::call_guard<py::gil_scoped_release>())
         .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer,
             py::call_guard<py::gil_scoped_release>())
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer,
diff --git a/cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp b/cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp
index 950215e7542..9be590c7fce 100644
--- a/cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp
+++ b/cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp
@@ -20,7 +20,7 @@
 #include "tensorrt_llm/common/cudaDriverWrapper.h"
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/logger.h"
-#include "tensorrt_llm/runtime/utils/mpiUtils.h"
+
 #include
 #include
 #include
@@ -38,7 +38,7 @@ inline size_t roundUp(size_t val, size_t gran)
 } // namespace

 McastDeviceMemory::McastDeviceMemory(
-    size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink)
+    size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink)
     : mIsMNNvlink(mnNvlink)
     , mDeviceIdx(deviceIdx)
     , mGroupSize(groupSize)
@@ -48,6 +48,7 @@ McastDeviceMemory::McastDeviceMemory(
     , mAllocationSize(0)
     , mMcPtr(0)
     , mMcHandle(0)
+    , mGroupComm(tensorrt_llm::mpi::MpiComm::session().split(splitColor, mGroupRank))
 {
     TLLM_CUDA_CHECK(cudaSetDevice(mDeviceIdx));
@@ -62,9 +63,12 @@ McastDeviceMemory::McastDeviceMemory(
     // From pytorch implementation for alignment
     constexpr size_t kSignalPadAlignment = 16UL;
     mSignalPadOffset = roundUp(mBufSize, kSignalPadAlignment);
+    int const world_rank{tensorrt_llm::mpi::MpiComm::session().getRank()};
+
     TLLM_LOG_DEBUG(
-        "[McastDeviceMemory] Rank: %u, Group size: %u, isMultiNode: %d, device_idx: %d, Signal pad offset: %zu",
-        mGroupRank, mGroupSize, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);
+        "[McastDeviceMemory] World Rank: %u, Group Rank: %u, Group size: %u, GroupSplitColor: %u, isMultiNode: %d, "
+        "device_idx: %d, Signal pad offset: %zu",
+        world_rank, mGroupRank, mGroupSize, splitColor, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);

     if (mIsMNNvlink)
     {
@@ -127,9 +131,6 @@ McastDeviceMemory::~McastDeviceMemory()

 void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
 {
-
-    auto const& mpi_comm = tensorrt_llm::mpi::MpiComm::session();
-
     CUmemAllocationHandleType const handle_type = CU_MEM_HANDLE_TYPE_FABRIC;
     CUmemAllocationProp prop = {};
     prop.requestedHandleTypes = handle_type;
@@ -156,7 +157,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
     // All gather
     cudaMallocHost(&exphndl, mGroupSize * sizeof(CUmemFabricHandle));
     memcpy(exphndl + mGroupRank * sizeof(CUmemFabricHandle), &myhndl, sizeof(CUmemFabricHandle));
-    mpi_comm.allgather(
+    mGroupComm.allgather(
         exphndl + mGroupRank * sizeof(CUmemFabricHandle), exphndl, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR);
     cudaDeviceSynchronize();
@@ -175,7 +176,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
         TLLM_CU_CHECK(cuMemExportToShareableHandle((void*) fabric_handle, mMcHandle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
     }
     // Broadcast
-    mpi_comm.bcast(fabric_handle, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR, 0);
+    mGroupComm.bcast(fabric_handle, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR, 0);
     cudaDeviceSynchronize();
     if (mGroupRank != 0)
     {
@@ -210,12 +211,9 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)

 void McastDeviceMemory::allocNvlsMcastMem(size_t bufSize)
 {
-    // Create a std::set to include all ranks in range (0, group_size)
-    std::set<int> ranks;
-    for (uint32_t i = 0; i < mGroupSize; ++i)
-    {
-        ranks.insert(i);
-    }
+    // Get the world ranks for ranks in this group
+    auto ranks_ = tensorrt_llm::mpi::getWorldRanks(mGroupComm);
+    std::set<int> ranks(ranks_.begin(), ranks_.end());
     // Reuse existing implementation
     mNvlsHandle = tensorrt_llm::runtime::ipcNvlsAllocate(bufSize, ranks);
     mMcHandle = mNvlsHandle->mc_handle;
diff --git a/cpp/tensorrt_llm/runtime/mcastDeviceMemory.h b/cpp/tensorrt_llm/runtime/mcastDeviceMemory.h
index 4afcc05223d..d9428b4126c 100644
--- a/cpp/tensorrt_llm/runtime/mcastDeviceMemory.h
+++ b/cpp/tensorrt_llm/runtime/mcastDeviceMemory.h
@@ -17,6 +17,7 @@
 #include "tensorrt_llm/common/mcastDevMemUtils.h"
 #include "tensorrt_llm/runtime/ipcNvlsMemory.h"
+#include "tensorrt_llm/runtime/utils/mpiUtils.h"
 #include
 #include
 #include
@@ -42,7 +43,8 @@ class McastDeviceMemory
     McastDeviceMemory(McastDeviceMemory const&) = delete;
     McastDeviceMemory& operator=(McastDeviceMemory const&) = delete;

-    McastDeviceMemory(size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink);
+    McastDeviceMemory(
+        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink);

     // We don't register the pointer in these two functions since we don't expect any python-level code would call
     // to obtain the raw pointers.
@@ -98,6 +100,8 @@ class McastDeviceMemory
     CUmemGenericAllocationHandle mMcHandle;
     std::vector<CUmemGenericAllocationHandle> mUcHandles;

+    tensorrt_llm::mpi::MpiComm mGroupComm; //!< The MPI communicator for the group
+
     // Host array of pointers
     std::vector<CUdeviceptr> mUcPtrs;
     std::vector<CUdeviceptr> mSignalPads;
diff --git a/cpp/tensorrt_llm/runtime/mcastGPUBuffer.h b/cpp/tensorrt_llm/runtime/mcastGPUBuffer.h
index 941ddb1a46a..4c011a790ba 100644
--- a/cpp/tensorrt_llm/runtime/mcastGPUBuffer.h
+++ b/cpp/tensorrt_llm/runtime/mcastGPUBuffer.h
@@ -34,12 +34,14 @@ class McastGPUBuffer
     //! \param bufSize The total size of the buffer in bytes.
     //! \param groupSize The number of ranks in the communication group.
     //! \param groupRank The rank of the local process within the group.
+    //! \param splitColor The color of the split for topology split.
     //! \param device The CUDA device for buffer allocation.
     //! \param mnNvlink Flag indicating if multi-node NVLink is used.
-    McastGPUBuffer(size_t bufSize, uint32_t groupSize, uint32_t groupRank, at::Device device, bool mnNvlink)
-        : mMcastDeviceMemory(bufSize, groupSize, groupRank, device.index(), mnNvlink)
+    McastGPUBuffer(
+        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, uint32_t deviceIdx, bool mnNvlink)
+        : mMcastDeviceMemory(bufSize, groupSize, groupRank, splitColor, deviceIdx, mnNvlink)
         , mBufSize(bufSize)
-        , mLocalDevice(device)
+        , mLocalDevice(at::Device(at::DeviceType::CUDA, deviceIdx))
     {
     }

@@ -49,7 +51,7 @@ class McastGPUBuffer
     //! \param dtype The data type of the tensor elements.
     //! \param storageOffset The offset in elements from the start of the buffer.
     //! \return An ATen tensor wrapping the unicast buffer section.
-    at::Tensor getUCBuffer(uint32_t rank, c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset)
+    at::Tensor getUCBuffer(uint32_t rank, std::vector<int64_t> sizes, torch::ScalarType dtype, int64_t storageOffset)
     {
         size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies<size_t>());
         size_t const elementSize = c10::elementSize(dtype);
@@ -59,7 +61,10 @@ class McastGPUBuffer
         auto* dataPtr = static_cast<uint8_t*>(mMcastDeviceMemory.getUnicastPtr(rank)) + storageOffset * elementSize;
         auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice);
-        return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor();
+        return at::for_blob(dataPtr, c10::IntArrayRef(sizes))
+            .options(options)
+            .target_device(mLocalDevice)
+            .make_tensor();
     }

     //! \brief Returns a PyTorch tensor view of the multicast buffer portion.
     //! \param dtype The data type of the tensor elements.
     //! \param storageOffset The offset in elements from the start of the buffer.
     //! \return An ATen tensor wrapping the multicast buffer section.
-    at::Tensor getMCBuffer(c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset)
+    at::Tensor getMCBuffer(std::vector<int64_t> sizes, torch::ScalarType dtype, int64_t storageOffset)
     {
         size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies<size_t>());
         size_t const elementSize = c10::elementSize(dtype);
@@ -77,7 +82,10 @@ class McastGPUBuffer
         auto* dataPtr = static_cast<uint8_t*>(mMcastDeviceMemory.getMulticastPtr()) + storageOffset * elementSize;
         auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice);
-        return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor();
+        return at::for_blob(dataPtr, c10::IntArrayRef(sizes))
+            .options(options)
+            .target_device(mLocalDevice)
+            .make_tensor();
     }

 private:
diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py
index b3811204dfa..c5749681040 100644
--- a/tensorrt_llm/_torch/distributed/ops.py
+++ b/tensorrt_llm/_torch/distributed/ops.py
@@ -1,4 +1,3 @@
-import logging
 import math
 import os
 import platform
@@ -8,7 +7,7 @@
 import torch
 from torch import nn

-from tensorrt_llm._utils import mpi_barrier
+from tensorrt_llm._utils import mpi_comm
 from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer
 from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
                                      AllReduceStrategy, MoEAllReduceParams)
@@ -17,7 +16,6 @@
 from tensorrt_llm.plugin.plugin import CustomAllReduceHelper

 _thread_local = threading.local()
-logger = logging.getLogger(__name__)


 def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
@@ -55,11 +53,15 @@ def allocate_low_presicion_allreduce_workspace(mapping: Mapping) -> None:
 def get_allreduce_mnnvl_workspace(
     mapping: Mapping, dtype: torch.dtype
 ) -> Tuple[McastGPUBuffer, torch.Tensor, torch.Tensor, int]:
+
     if not hasattr(_thread_local,
                    f'allreduce_mnnvl_workspaces_{mapping.pp_rank}'):
         setattr(_thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}',
                 {})
-
+    # Support topology split
+    comm = mpi_comm().Split(
+        int(mapping.pp_rank * mapping.cp_size + mapping.cp_rank),
+        mapping.tp_rank)
     force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1"

     allreduce_mnnvl_workspaces = getattr(
@@ -77,7 +79,9 @@ def get_allreduce_mnnvl_workspace(
             buffer_size_in_bytes,
             mapping.tp_size,
             mapping.tp_rank,
-            torch.device("cuda", mapping.local_rank),
+            # Split the communicator according to the topology
+            mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
+            mapping.local_rank,
             True,  # mnNvlink
         )

@@ -87,7 +91,7 @@ def get_allreduce_mnnvl_workspace(
         buffer.fill_(-0.0)
         # CPU barrier since we assume this should not be called in cuda graph
         torch.cuda.synchronize()
-        mpi_barrier()
+        comm.Barrier()

         # This is a buffer to maintain the state of this allreduce Op
         # Should have the same lifetime with self._buffer
@@ -458,12 +462,7 @@ def __init__(self,

         # Initialize MNNVL AllReduce if needed
         if self.strategy in (AllReduceStrategy.AUTO, AllReduceStrategy.MNNVL):
-            if self.mapping.tp_size != self.mapping.world_size:
-                logger.debug(
-                    f"MNNVLAllReduce is disabled due to tp_size:{self.mapping.tp_size} "
-                    f"!= world_size:{self.mapping.world_size}")
-                self.mnnvl_allreduce = None
-            elif MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
+            if MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
                 try:
                     self.mnnvl_allreduce = MNNVLAllReduce(
                         self.mapping, dtype) if dtype else None
diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index 0ca4d28085b..09b42c6fee4 100644
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -771,12 +771,11 @@ def _compute_mlp_tp_size(self, intermediate_size: int,
             self.mapping.tp_size,
         )
-        if tp > self.mapping.gpus_per_node and not self.allreduce.is_mnnvl(
-        ):
+        if tp > self.mapping.gpus_per_node:
             mlp_tp_size = math.gcd(
                 tp,
                 self.mapping.gpus_per_node,
-            )  # Avoid costly inter-node TP when MNNVL is not supported
+            )  # Avoid costly inter-node TP
         else:
             mlp_tp_size = tp
         return mlp_tp_size
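
For reference, a minimal usage sketch of the constructor exposed by the bindings above. The keyword
names mirror the new nb::arg/py::arg entries; the group sizes, ranks, and buffer size are
placeholders rather than values taken from this change, and an actual run requires a multi-rank MPI
launch on MNNVL-capable GPUs:

    # Hedged sketch only: in TensorRT-LLM, get_allreduce_mnnvl_workspace() derives
    # these values from Mapping at runtime; the literals here are placeholders.
    import torch
    from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer

    tp_size, tp_rank = 8, 0              # tensor-parallel group size / rank (placeholders)
    pp_rank, cp_size, cp_rank = 0, 1, 0  # pipeline / context-parallel coordinates (placeholders)

    mcast_buffer = McastGPUBuffer(
        buf_size=16 * 1024 * 1024,       # placeholder byte count
        group_size=tp_size,
        group_rank=tp_rank,
        # Ranks passing the same color land in the same sub-communicator, so each
        # (pp_rank, cp_rank) slice gets its own multicast group.
        split_color=pp_rank * cp_size + cp_rank,
        device_idx=torch.cuda.current_device(),
        mn_nvlink=True,
    )

    # Tensor views over the unicast and multicast regions (signatures per mcastGPUBuffer.h above).
    uc = mcast_buffer.get_uc_buffer(tp_rank, [1024], torch.bfloat16, 0)
    mc = mcast_buffer.get_mc_buffer([1024], torch.bfloat16, 0)

The same color expression (pp_rank * cp_size + cp_rank) is what the updated
get_allreduce_mnnvl_workspace() passes both to McastGPUBuffer and to mpi_comm().Split(), so the
Python-side barrier and the C++-side MpiComm::session().split() operate on matching sub-groups.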