4 changes: 3 additions & 1 deletion cpp/tensorrt_llm/nanobind/runtime/bindings.cpp
@@ -361,7 +361,9 @@ void initBindings(nb::module_& m)
nb::call_guard<nb::gil_scoped_release>());

nb::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
.def(nb::init<size_t, uint32_t, uint32_t, at::Device, bool>(), nb::call_guard<nb::gil_scoped_release>())
.def(nb::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), nb::arg("buf_size"),
nb::arg("group_size"), nb::arg("group_rank"), nb::arg("split_color"), nb::arg("device_idx"),
nb::arg("mn_nvlink"), nb::call_guard<nb::gil_scoped_release>())
.def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer,
nb::call_guard<nb::gil_scoped_release>())
.def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer,
4 changes: 3 additions & 1 deletion cpp/tensorrt_llm/pybind/runtime/bindings.cpp
@@ -455,7 +455,9 @@ void initBindings(pybind11::module_& m)
py::call_guard<py::gil_scoped_release>());

py::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
.def(py::init<size_t, uint32_t, uint32_t, at::Device, bool>(), py::call_guard<py::gil_scoped_release>())
.def(py::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), py::arg("buf_size"),
py::arg("group_size"), py::arg("group_rank"), py::arg("split_color"), py::arg("device_idx"),
py::arg("mn_nvlink"), py::call_guard<py::gil_scoped_release>())
.def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer,
py::call_guard<py::gil_scoped_release>())
.def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer,
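For context, a minimal sketch of how the rebound constructor (identical in the nanobind and pybind modules above) could be called from Python after this change; the topology values are illustrative assumptions, while the module path matches the import used in `ops.py` below:

```python
# Hypothetical usage sketch of the updated constructor; values are illustrative only.
from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer

# Example topology: 2 pipeline stages x 4-way tensor parallel, no context parallel.
pp_rank, cp_rank, cp_size = 0, 0, 1
tp_size, tp_rank, local_rank = 4, 0, 0

buf = McastGPUBuffer(
    buf_size=16 * 1024 * 1024,                # total buffer size in bytes
    group_size=tp_size,                       # ranks in the communication group
    group_rank=tp_rank,                       # this rank's position within the group
    split_color=pp_rank * cp_size + cp_rank,  # one color per TP group
    device_idx=local_rank,                    # CUDA device index (replaces at::Device)
    mn_nvlink=True,                           # enable the multi-node NVLink path
)
```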
28 changes: 13 additions & 15 deletions cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp
@@ -20,7 +20,7 @@
#include "tensorrt_llm/common/cudaDriverWrapper.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

#include <cstddef>
#include <cstdint>
#include <cuda_runtime_api.h>
@@ -38,7 +38,7 @@ inline size_t roundUp(size_t val, size_t gran)
} // namespace

McastDeviceMemory::McastDeviceMemory(
size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink)
size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink)
: mIsMNNvlink(mnNvlink)
, mDeviceIdx(deviceIdx)
, mGroupSize(groupSize)
@@ -48,6 +48,7 @@ McastDeviceMemory::McastDeviceMemory(
, mAllocationSize(0)
, mMcPtr(0)
, mMcHandle(0)
, mGroupComm(tensorrt_llm::mpi::MpiComm::session().split(splitColor, mGroupRank))
{

TLLM_CUDA_CHECK(cudaSetDevice(mDeviceIdx));
@@ -62,9 +63,12 @@ McastDeviceMemory::McastDeviceMemory(
// From pytorch implementation for alignment
constexpr size_t kSignalPadAlignment = 16UL;
mSignalPadOffset = roundUp(mBufSize, kSignalPadAlignment);
int const world_rank{tensorrt_llm::mpi::MpiComm::session().getRank()};

TLLM_LOG_DEBUG(
"[McastDeviceMemory] Rank: %u, Group size: %u, isMultiNode: %d, device_idx: %d, Signal pad offset: %zu",
mGroupRank, mGroupSize, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);
"[McastDeviceMemory] World Rank: %u, Group Rank: %u, Group size: %u, GroupSplitColor: %u, isMultiNode: %d, "
"device_idx: %d, Signal pad offset: %zu",
world_rank, mGroupRank, mGroupSize, splitColor, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);

if (mIsMNNvlink)
{
@@ -127,9 +131,6 @@ McastDeviceMemory::~McastDeviceMemory()

void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
{

auto const& mpi_comm = tensorrt_llm::mpi::MpiComm::session();

CUmemAllocationHandleType const handle_type = CU_MEM_HANDLE_TYPE_FABRIC;
CUmemAllocationProp prop = {};
prop.requestedHandleTypes = handle_type;
@@ -156,7 +157,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
// All gather
cudaMallocHost(&exphndl, mGroupSize * sizeof(CUmemFabricHandle));
memcpy(exphndl + mGroupRank * sizeof(CUmemFabricHandle), &myhndl, sizeof(CUmemFabricHandle));
mpi_comm.allgather(
mGroupComm.allgather(
exphndl + mGroupRank * sizeof(CUmemFabricHandle), exphndl, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR);
cudaDeviceSynchronize();

@@ -175,7 +176,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
TLLM_CU_CHECK(cuMemExportToShareableHandle((void*) fabric_handle, mMcHandle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
}
// Broadcast
mpi_comm.bcast(fabric_handle, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR, 0);
mGroupComm.bcast(fabric_handle, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR, 0);
cudaDeviceSynchronize();
if (mGroupRank != 0)
{
@@ -210,12 +211,9 @@

void McastDeviceMemory::allocNvlsMcastMem(size_t bufSize)
{
// Create a std::set to include all ranks in range (0, group_size)
std::set<int> ranks;
for (uint32_t i = 0; i < mGroupSize; ++i)
{
ranks.insert(i);
}
// Get the world ranks for ranks in this group
auto ranks_ = tensorrt_llm::mpi::getWorldRanks(mGroupComm);
std::set<int> ranks(ranks_.begin(), ranks_.end());
// Reuse existing implementation
mNvlsHandle = tensorrt_llm::runtime::ipcNvlsAllocate(bufSize, ranks);
mMcHandle = mNvlsHandle->mc_handle;
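To illustrate the group-communicator semantics the new `splitColor` argument relies on, here is a small sketch using mpi4py (an assumed stand-in for the internal `MpiComm::session().split(splitColor, mGroupRank)` call): ranks that pass the same color land in the same sub-communicator, ordered by key, so the fabric-handle allgather and bcast above only involve the ranks of one group rather than the whole world communicator.

```python
# Sketch of MPI split semantics (mpi4py used here purely for illustration).
from mpi4py import MPI

world = MPI.COMM_WORLD
world_rank = world.Get_rank()

# Illustrative topology: 8 world ranks = 2 pipeline stages x 4-way tensor parallel.
tp_size = 4
color = world_rank // tp_size   # plays the role of splitColor
key = world_rank % tp_size      # plays the role of mGroupRank

group = world.Split(color, key)  # analogous to MpiComm::session().split(color, rank)
# An allgather of per-rank handles now only involves the ranks of this group.
handles = group.allgather(world_rank)
print(f"world {world_rank}: group rank {group.Get_rank()}, members {handles}")
```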
6 changes: 5 additions & 1 deletion cpp/tensorrt_llm/runtime/mcastDeviceMemory.h
@@ -17,6 +17,7 @@

#include "tensorrt_llm/common/mcastDevMemUtils.h"
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <cstddef>
#include <cstdint>
#include <cuda.h>
@@ -42,7 +43,8 @@ class McastDeviceMemory
McastDeviceMemory(McastDeviceMemory const&) = delete;
McastDeviceMemory& operator=(McastDeviceMemory const&) = delete;

McastDeviceMemory(size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink);
McastDeviceMemory(
size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink);

// We don't register the pointer in these two functions since we don't expect any Python-level code
// to call them to obtain the raw pointers.
@@ -98,6 +100,8 @@ class McastDeviceMemory
CUmemGenericAllocationHandle mMcHandle;
std::vector<CUmemGenericAllocationHandle> mUcHandles;

tensorrt_llm::mpi::MpiComm mGroupComm; //!< The MPI communicator for the group

// Host array of pointers
std::vector<CUdeviceptr> mUcPtrs;
std::vector<CUdeviceptr> mSignalPads;
22 changes: 15 additions & 7 deletions cpp/tensorrt_llm/runtime/mcastGPUBuffer.h
@@ -34,12 +34,14 @@ class McastGPUBuffer
//! \param bufSize The total size of the buffer in bytes.
//! \param groupSize The number of ranks in the communication group.
//! \param groupRank The rank of the local process within the group.
//! \param splitColor The color used to split the communicator by topology.
//! \param device The CUDA device for buffer allocation.
//! \param mnNvlink Flag indicating if multi-node NVLink is used.
McastGPUBuffer(size_t bufSize, uint32_t groupSize, uint32_t groupRank, at::Device device, bool mnNvlink)
: mMcastDeviceMemory(bufSize, groupSize, groupRank, device.index(), mnNvlink)
McastGPUBuffer(
size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, uint32_t deviceIdx, bool mnNvlink)
: mMcastDeviceMemory(bufSize, groupSize, groupRank, splitColor, deviceIdx, mnNvlink)
, mBufSize(bufSize)
, mLocalDevice(device)
, mLocalDevice(at::Device(at::DeviceType::CUDA, deviceIdx))
{
}

@@ -49,7 +51,7 @@ class McastGPUBuffer
//! \param dtype The data type of the tensor elements.
//! \param storageOffset The offset in elements from the start of the buffer.
//! \return An ATen tensor wrapping the unicast buffer section.
at::Tensor getUCBuffer(uint32_t rank, c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset)
at::Tensor getUCBuffer(uint32_t rank, std::vector<long int> sizes, torch::ScalarType dtype, int64_t storageOffset)
{
size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies<size_t>());
size_t const elementSize = c10::elementSize(dtype);
@@ -59,15 +61,18 @@
auto* dataPtr = static_cast<uint8_t*>(mMcastDeviceMemory.getUnicastPtr(rank)) + storageOffset * elementSize;

auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice);
return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor();
return at::for_blob(dataPtr, c10::IntArrayRef(sizes))
.options(options)
.target_device(mLocalDevice)
.make_tensor();
}

//! \brief Returns a PyTorch tensor view of the multicast buffer portion.
//! \param sizes The desired shape (dimensions) of the tensor.
//! \param dtype The data type of the tensor elements.
//! \param storageOffset The offset in elements from the start of the buffer.
//! \return An ATen tensor wrapping the multicast buffer section.
at::Tensor getMCBuffer(c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset)
at::Tensor getMCBuffer(std::vector<long int> sizes, torch::ScalarType dtype, int64_t storageOffset)
{
size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies<size_t>());
size_t const elementSize = c10::elementSize(dtype);
@@ -77,7 +82,10 @@
auto* dataPtr = static_cast<uint8_t*>(mMcastDeviceMemory.getMulticastPtr()) + storageOffset * elementSize;

auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice);
return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor();
return at::for_blob(dataPtr, c10::IntArrayRef(sizes))
.options(options)
.target_device(mLocalDevice)
.make_tensor();
}

private:
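A short sketch of how the tensor views could be obtained from Python, assuming `buf` is the `McastGPUBuffer` from the earlier example; the shape, dtype, and offsets are illustrative assumptions:

```python
# Hypothetical usage of the tensor-view accessors bound above; values illustrative.
import torch

shape = [128, 1024]
# Unicast view of rank 0's slice of the buffer, starting at element offset 0.
uc = buf.get_uc_buffer(0, shape, torch.bfloat16, 0)
# Multicast view over the same region; stores through it reach all ranks in the group.
mc = buf.get_mc_buffer(shape, torch.bfloat16, 0)
```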
23 changes: 11 additions & 12 deletions tensorrt_llm/_torch/distributed/ops.py
@@ -1,4 +1,3 @@
import logging
import math
import os
import platform
@@ -8,7 +7,7 @@
import torch
from torch import nn

from tensorrt_llm._utils import mpi_barrier
from tensorrt_llm._utils import mpi_comm
from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
AllReduceStrategy, MoEAllReduceParams)
@@ -17,7 +16,6 @@
from tensorrt_llm.plugin.plugin import CustomAllReduceHelper

_thread_local = threading.local()
logger = logging.getLogger(__name__)


def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
@@ -55,11 +53,15 @@ def allocate_low_presicion_allreduce_workspace(mapping: Mapping) -> None:
def get_allreduce_mnnvl_workspace(
mapping: Mapping, dtype: torch.dtype
) -> Tuple[McastGPUBuffer, torch.Tensor, torch.Tensor, int]:

if not hasattr(_thread_local,
f'allreduce_mnnvl_workspaces_{mapping.pp_rank}'):
setattr(_thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}',
{})

# Support topology split
comm = mpi_comm().Split(
int(mapping.pp_rank * mapping.cp_size + mapping.cp_rank),
mapping.tp_rank)
force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1"

allreduce_mnnvl_workspaces = getattr(
@@ -77,7 +79,9 @@ def get_allreduce_mnnvl_workspace(
buffer_size_in_bytes,
mapping.tp_size,
mapping.tp_rank,
torch.device("cuda", mapping.local_rank),
# Split the communicator according to the topology
mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
mapping.local_rank,
True, # mnNvlink
)

@@ -87,7 +91,7 @@ def get_allreduce_mnnvl_workspace(
buffer.fill_(-0.0)
# CPU barrier since we assume this should not be called in cuda graph
torch.cuda.synchronize()
mpi_barrier()
comm.Barrier()

# This is a buffer to maintain the state of this allreduce Op
# Should have the same lifetime with self._buffer
@@ -458,12 +462,7 @@ def __init__(self,
# Initialize MNNVL AllReduce if needed
if self.strategy in (AllReduceStrategy.AUTO,
AllReduceStrategy.MNNVL):
if self.mapping.tp_size != self.mapping.world_size:
logger.debug(
f"MNNVLAllReduce is disabled due to tp_size:{self.mapping.tp_size} "
f"!= world_size:{self.mapping.world_size}")
self.mnnvl_allreduce = None
elif MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
if MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
try:
self.mnnvl_allreduce = MNNVLAllReduce(
self.mapping, dtype) if dtype else None
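A worked example of the color expression used above, with assumed topology numbers: each `(pp_rank, cp_rank)` pair maps to a distinct color, so every tensor-parallel group gets its own communicator, workspace, and barrier.

```python
# Worked example of the split color used above; topology numbers are assumptions.
pp_size, cp_size, tp_size = 2, 2, 2

groups = {}
for pp_rank in range(pp_size):
    for cp_rank in range(cp_size):
        for tp_rank in range(tp_size):
            color = pp_rank * cp_size + cp_rank  # same expression as in ops.py
            groups.setdefault(color, []).append((pp_rank, cp_rank, tp_rank))

# Each color collects exactly one TP group:
# 0 -> pp0/cp0, 1 -> pp0/cp1, 2 -> pp1/cp0, 3 -> pp1/cp1
for color, members in sorted(groups.items()):
    print(color, members)
```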
5 changes: 2 additions & 3 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -771,12 +771,11 @@ def _compute_mlp_tp_size(self, intermediate_size: int,
self.mapping.tp_size,
)

if tp > self.mapping.gpus_per_node and not self.allreduce.is_mnnvl(
):
if tp > self.mapping.gpus_per_node:
mlp_tp_size = math.gcd(
tp,
self.mapping.gpus_per_node,
) # Avoid costly inter-node TP when MNNVL is not supported
) # Avoid costly inter-node TP
else:
mlp_tp_size = tp
return mlp_tp_size
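A small numeric sketch of the updated branch (values are assumptions): when the candidate TP degree exceeds `gpus_per_node`, it is reduced to `gcd(tp, gpus_per_node)` so the MLP TP group stays within a single node.

```python
# Numeric sketch of the gcd fallback above; values are illustrative assumptions.
import math

gpus_per_node = 8
for tp in (4, 8, 16, 24):
    mlp_tp = math.gcd(tp, gpus_per_node) if tp > gpus_per_node else tp
    print(tp, "->", mlp_tp)  # 4 -> 4, 8 -> 8, 16 -> 8, 24 -> 8
```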