Commit ad0f413

Merge branch 'release/1.1.0rc2' into supportFP8BlockWideEp_release
2 parents e68c7bb + 9d6e87a commit ad0f413

File tree

13 files changed: +139 −43 lines changed


cpp/tensorrt_llm/nanobind/runtime/bindings.cpp

Lines changed: 3 additions & 1 deletion
@@ -340,7 +340,9 @@ void initBindings(nb::module_& m)
         "Reset the current virtual memory allocator and stop allocating virtual memory for CUDA allocations");
 
     nb::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(nb::init<size_t, uint32_t, uint32_t, at::Device, bool>())
+        .def(nb::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), nb::arg("buf_size"),
+            nb::arg("group_size"), nb::arg("group_rank"), nb::arg("split_color"), nb::arg("device_idx"),
+            nb::arg("mn_nvlink"))
         .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer)
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer);

cpp/tensorrt_llm/pybind/runtime/bindings.cpp

Lines changed: 3 additions & 1 deletion
@@ -434,7 +434,9 @@ void initBindings(pybind11::module_& m)
         "Reset the current virtual memory allocator and stop allocating virtual memory for CUDA allocations");
 
     py::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
-        .def(py::init<size_t, uint32_t, uint32_t, at::Device, bool>())
+        .def(py::init<size_t, uint32_t, uint32_t, uint32_t, uint32_t, bool>(), py::arg("buf_size"),
+            py::arg("group_size"), py::arg("group_rank"), py::arg("split_color"), py::arg("device_idx"),
+            py::arg("mn_nvlink"))
         .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer)
         .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer);

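Both bindings now take the extra split_color and device_idx arguments by keyword instead of an at::Device. A minimal sketch of reaching the new constructor from Python; the values below are illustrative, and the real call site is get_allreduce_mnnvl_workspace() in tensorrt_llm/_torch/distributed/ops.py further down:

    # Hypothetical values for illustration only.
    from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer

    mcast_buffer = McastGPUBuffer(
        buf_size=12_000_000,   # total buffer size in bytes
        group_size=8,          # ranks in the communication group
        group_rank=0,          # this process's rank within the group
        split_color=0,         # topology split color, e.g. pp_rank * cp_size + cp_rank
        device_idx=0,          # local CUDA device index
        mn_nvlink=True,        # multi-node NVLink path
    )
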
cpp/tensorrt_llm/runtime/mcastDeviceMemory.cpp

Lines changed: 13 additions & 15 deletions
@@ -20,7 +20,7 @@
 #include "tensorrt_llm/common/cudaDriverWrapper.h"
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/logger.h"
-#include "tensorrt_llm/runtime/utils/mpiUtils.h"
+
 #include <cstddef>
 #include <cstdint>
 #include <cuda_runtime_api.h>
@@ -38,7 +38,7 @@ inline size_t roundUp(size_t val, size_t gran)
 } // namespace
 
 McastDeviceMemory::McastDeviceMemory(
-    size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink)
+    size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink)
     : mIsMNNvlink(mnNvlink)
     , mDeviceIdx(deviceIdx)
     , mGroupSize(groupSize)
@@ -48,6 +48,7 @@ McastDeviceMemory::McastDeviceMemory(
     , mAllocationSize(0)
     , mMcPtr(0)
     , mMcHandle(0)
+    , mGroupComm(tensorrt_llm::mpi::MpiComm::session().split(splitColor, mGroupRank))
 {
 
     TLLM_CUDA_CHECK(cudaSetDevice(mDeviceIdx));
@@ -62,9 +63,12 @@ McastDeviceMemory::McastDeviceMemory(
     // From pytorch implementation for alignment
     constexpr size_t kSignalPadAlignment = 16UL;
     mSignalPadOffset = roundUp(mBufSize, kSignalPadAlignment);
+    int const world_rank{tensorrt_llm::mpi::MpiComm::session().getRank()};
+
     TLLM_LOG_DEBUG(
-        "[McastDeviceMemory] Rank: %u, Group size: %u, isMultiNode: %d, device_idx: %d, Signal pad offset: %zu",
-        mGroupRank, mGroupSize, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);
+        "[McastDeviceMemory] World Rank: %u, Group Rank: %u, Group size: %u, GroupSplitColor: %u, isMultiNode: %d, "
+        "device_idx: %d, Signal pad offset: %zu",
+        world_rank, mGroupRank, mGroupSize, splitColor, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);
 
     if (mIsMNNvlink)
     {
@@ -127,9 +131,6 @@ McastDeviceMemory::~McastDeviceMemory()
 
 void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
 {
-
-    auto const& mpi_comm = tensorrt_llm::mpi::MpiComm::session();
-
     CUmemAllocationHandleType const handle_type = CU_MEM_HANDLE_TYPE_FABRIC;
     CUmemAllocationProp prop = {};
     prop.requestedHandleTypes = handle_type;
@@ -156,7 +157,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
     // All gather
     cudaMallocHost(&exphndl, mGroupSize * sizeof(CUmemFabricHandle));
     memcpy(exphndl + mGroupRank * sizeof(CUmemFabricHandle), &myhndl, sizeof(CUmemFabricHandle));
-    mpi_comm.allgather(
+    mGroupComm.allgather(
         exphndl + mGroupRank * sizeof(CUmemFabricHandle), exphndl, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR);
     cudaDeviceSynchronize();
 
@@ -175,7 +176,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
         TLLM_CU_CHECK(cuMemExportToShareableHandle((void*) fabric_handle, mMcHandle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
     }
     // Broadcast
-    mpi_comm.bcast(fabric_handle, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR, 0);
+    mGroupComm.bcast(fabric_handle, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR, 0);
     cudaDeviceSynchronize();
     if (mGroupRank != 0)
     {
@@ -210,12 +211,9 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
 
 void McastDeviceMemory::allocNvlsMcastMem(size_t bufSize)
 {
-    // Create a std::set to include all ranks in range (0, group_size)
-    std::set<int> ranks;
-    for (uint32_t i = 0; i < mGroupSize; ++i)
-    {
-        ranks.insert(i);
-    }
+    // Get the world ranks for ranks in this group
+    auto ranks_ = tensorrt_llm::mpi::getWorldRanks(mGroupComm);
+    std::set<int> ranks(ranks_.begin(), ranks_.end());
     // Reuse existing implementation
     mNvlsHandle = tensorrt_llm::runtime::ipcNvlsAllocate(bufSize, ranks);
     mMcHandle = mNvlsHandle->mc_handle;

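The substantive change here is that the fabric-handle allgather and broadcast now run on mGroupComm, a communicator obtained by splitting the session communicator with splitColor, rather than on the full session. A rough mpi4py sketch of the same split-then-collect pattern, for illustration only (the C++ code uses the project's own MpiComm wrapper, not mpi4py):

    # Illustrative: mimic MpiComm::session().split(splitColor, groupRank) and the
    # group-local collectives with mpi4py.
    from mpi4py import MPI

    world = MPI.COMM_WORLD
    split_color = 0          # ranks sharing a color end up in the same group
    key = world.rank         # key controls ordering within the new communicator

    group = world.Split(split_color, key)

    # Collectives now involve only the ranks that share split_color.
    handles = group.allgather(f"handle-from-world-rank-{world.rank}")
    root_handle = group.bcast(handles[0], root=0)
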
cpp/tensorrt_llm/runtime/mcastDeviceMemory.h

Lines changed: 5 additions & 1 deletion
@@ -17,6 +17,7 @@
 
 #include "tensorrt_llm/common/mcastDevMemUtils.h"
 #include "tensorrt_llm/runtime/ipcNvlsMemory.h"
+#include "tensorrt_llm/runtime/utils/mpiUtils.h"
 #include <cstddef>
 #include <cstdint>
 #include <cuda.h>
@@ -42,7 +43,8 @@ class McastDeviceMemory
     McastDeviceMemory(McastDeviceMemory const&) = delete;
     McastDeviceMemory& operator=(McastDeviceMemory const&) = delete;
 
-    McastDeviceMemory(size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink);
+    McastDeviceMemory(
+        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink);
 
     // We don't register the pointer in these two functions since we don't expect any python-level code would call
     // to obtain the raw pointers.
@@ -98,6 +100,8 @@ class McastDeviceMemory
     CUmemGenericAllocationHandle mMcHandle;
     std::vector<CUmemGenericAllocationHandle> mUcHandles;
 
+    tensorrt_llm::mpi::MpiComm mGroupComm; //!< The MPI communicator for the group
+
     // Host array of pointers
     std::vector<CUdeviceptr> mUcPtrs;
     std::vector<CUdeviceptr> mSignalPads;

cpp/tensorrt_llm/runtime/mcastGPUBuffer.h

Lines changed: 15 additions & 7 deletions
@@ -34,12 +34,14 @@ class McastGPUBuffer
     //! \param bufSize The total size of the buffer in bytes.
     //! \param groupSize The number of ranks in the communication group.
     //! \param groupRank The rank of the local process within the group.
+    //! \param splitColor The color of the split for topology split.
     //! \param device The CUDA device for buffer allocation.
     //! \param mnNvlink Flag indicating if multi-node NVLink is used.
-    McastGPUBuffer(size_t bufSize, uint32_t groupSize, uint32_t groupRank, at::Device device, bool mnNvlink)
-        : mMcastDeviceMemory(bufSize, groupSize, groupRank, device.index(), mnNvlink)
+    McastGPUBuffer(
+        size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, uint32_t deviceIdx, bool mnNvlink)
+        : mMcastDeviceMemory(bufSize, groupSize, groupRank, splitColor, deviceIdx, mnNvlink)
         , mBufSize(bufSize)
-        , mLocalDevice(device)
+        , mLocalDevice(at::Device(at::DeviceType::CUDA, deviceIdx))
     {
     }
 
@@ -49,7 +51,7 @@ class McastGPUBuffer
     //! \param dtype The data type of the tensor elements.
     //! \param storageOffset The offset in elements from the start of the buffer.
     //! \return An ATen tensor wrapping the unicast buffer section.
-    at::Tensor getUCBuffer(uint32_t rank, c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset)
+    at::Tensor getUCBuffer(uint32_t rank, std::vector<long int> sizes, torch::ScalarType dtype, int64_t storageOffset)
     {
         size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies<size_t>());
         size_t const elementSize = c10::elementSize(dtype);
@@ -59,15 +61,18 @@ class McastGPUBuffer
         auto* dataPtr = static_cast<uint8_t*>(mMcastDeviceMemory.getUnicastPtr(rank)) + storageOffset * elementSize;
 
         auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice);
-        return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor();
+        return at::for_blob(dataPtr, c10::IntArrayRef(sizes))
+            .options(options)
+            .target_device(mLocalDevice)
+            .make_tensor();
     }
 
     //! \brief Returns a PyTorch tensor view of the multicast buffer portion.
     //! \param sizes The desired shape (dimensions) of the tensor.
     //! \param dtype The data type of the tensor elements.
     //! \param storageOffset The offset in elements from the start of the buffer.
    //! \return An ATen tensor wrapping the multicast buffer section.
-    at::Tensor getMCBuffer(c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storageOffset)
+    at::Tensor getMCBuffer(std::vector<long int> sizes, torch::ScalarType dtype, int64_t storageOffset)
     {
         size_t const numel = std::accumulate(sizes.begin(), sizes.end(), 1UL, std::multiplies<size_t>());
         size_t const elementSize = c10::elementSize(dtype);
@@ -77,7 +82,10 @@ class McastGPUBuffer
         auto* dataPtr = static_cast<uint8_t*>(mMcastDeviceMemory.getMulticastPtr()) + storageOffset * elementSize;
 
         auto options = at::TensorOptions().dtype(dtype).device(mLocalDevice);
-        return at::for_blob(dataPtr, sizes).options(options).target_device(mLocalDevice).make_tensor();
+        return at::for_blob(dataPtr, c10::IntArrayRef(sizes))
+            .options(options)
+            .target_device(mLocalDevice)
+            .make_tensor();
     }
 
 private:

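Because sizes is now a plain std::vector, the Python-facing accessors take an ordinary list of dimensions. A hedged sketch of how the tensor views might be obtained from Python; the shapes, dtype, and offsets are assumptions, not values from this commit:

    # Illustrative usage of the tensor-view accessors; all values are made up.
    import torch
    from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer

    # buf_size, group_size, group_rank, split_color, device_idx, mn_nvlink
    buf = McastGPUBuffer(1 << 20, 8, 0, 0, 0, True)

    # Unicast view of rank 0's slice and a multicast view spanning the group.
    uc = buf.get_uc_buffer(0, [256, 512], torch.bfloat16, 0)
    mc = buf.get_mc_buffer([256, 512], torch.bfloat16, 0)
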
tensorrt_llm/_torch/distributed/ops.py

Lines changed: 11 additions & 12 deletions
@@ -1,4 +1,3 @@
-import logging
 import math
 import os
 import platform
@@ -8,7 +7,7 @@
 import torch
 from torch import nn
 
-from tensorrt_llm._utils import mpi_barrier
+from tensorrt_llm._utils import mpi_comm
 from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer
 from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
                                      AllReduceStrategy, MoEAllReduceParams)
@@ -17,7 +16,6 @@
 from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
 
 _thread_local = threading.local()
-logger = logging.getLogger(__name__)
 
 
 def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
@@ -55,11 +53,15 @@ def allocate_low_presicion_allreduce_workspace(mapping: Mapping) -> None:
 def get_allreduce_mnnvl_workspace(
     mapping: Mapping, dtype: torch.dtype
 ) -> Tuple[McastGPUBuffer, torch.Tensor, torch.Tensor, int]:
+
     if not hasattr(_thread_local,
                    f'allreduce_mnnvl_workspaces_{mapping.pp_rank}'):
         setattr(_thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}',
                 {})
-
+    # Support topology split
+    comm = mpi_comm().Split(
+        int(mapping.pp_rank * mapping.cp_size + mapping.cp_rank),
+        mapping.tp_rank)
     force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1"
 
     allreduce_mnnvl_workspaces = getattr(
@@ -77,7 +79,9 @@ def get_allreduce_mnnvl_workspace(
             buffer_size_in_bytes,
             mapping.tp_size,
             mapping.tp_rank,
-            torch.device("cuda", mapping.local_rank),
+            # Split the communicator according to the topology
+            mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
+            mapping.local_rank,
             True,  # mnNvlink
         )
 
@@ -87,7 +91,7 @@ def get_allreduce_mnnvl_workspace(
         buffer.fill_(-0.0)
         # CPU barrier since we assume this should not be called in cuda graph
         torch.cuda.synchronize()
-        mpi_barrier()
+        comm.Barrier()
 
         # This is a buffer to maintain the state of this allreduce Op
         # Should have the same lifetime with self._buffer
@@ -458,12 +462,7 @@ def __init__(self,
         # Initialize MNNVL AllReduce if needed
         if self.strategy in (AllReduceStrategy.AUTO,
                              AllReduceStrategy.MNNVL):
-            if self.mapping.tp_size != self.mapping.world_size:
-                logger.debug(
-                    f"MNNVLAllReduce is disabled due to tp_size:{self.mapping.tp_size} "
-                    f"!= world_size:{self.mapping.world_size}")
-                self.mnnvl_allreduce = None
-            elif MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
+            if MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
                 try:
                     self.mnnvl_allreduce = MNNVLAllReduce(
                         self.mapping, dtype) if dtype else None

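The color passed to Split is pp_rank * cp_size + cp_rank and the key is tp_rank, so exactly the ranks of one tensor-parallel group share a communicator and keep their TP ordering inside it. A small sketch of the resulting grouping, using assumed topology values rather than the real Mapping object:

    # Assumed topology for illustration: pp_size=2, cp_size=1, tp_size=4, ranks laid out pp-major.
    pp_size, cp_size, tp_size = 2, 1, 4

    for world_rank in range(pp_size * cp_size * tp_size):
        pp_rank = world_rank // (cp_size * tp_size)
        cp_rank = (world_rank // tp_size) % cp_size
        tp_rank = world_rank % tp_size
        color = pp_rank * cp_size + cp_rank  # same color -> same MNNVL allreduce communicator
        key = tp_rank                        # key fixes the rank order inside the group
        print(world_rank, "-> color", color, "key", key)
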
tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 2 additions & 3 deletions
@@ -761,12 +761,11 @@ def _compute_mlp_tp_size(self, intermediate_size: int,
             self.mapping.tp_size,
         )
 
-        if tp > self.mapping.gpus_per_node and not self.allreduce.is_mnnvl(
-        ):
+        if tp > self.mapping.gpus_per_node:
             mlp_tp_size = math.gcd(
                 tp,
                 self.mapping.gpus_per_node,
-            )  # Avoid costly inter-node TP when MNNVL is not supported
+            )  # Avoid costly inter-node TP
         else:
             mlp_tp_size = tp
         return mlp_tp_size

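With this change the MLP tensor parallelism is capped at a single node whenever tp exceeds gpus_per_node, independent of MNNVL support. A quick worked example with assumed values:

    # Assumed topology: 16-way TP requested on nodes with 8 GPUs each.
    import math

    tp, gpus_per_node = 16, 8
    mlp_tp_size = math.gcd(tp, gpus_per_node) if tp > gpus_per_node else tp
    print(mlp_tp_size)  # 8 -> MLP TP stays within one node
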
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 9 additions & 2 deletions
@@ -136,8 +136,11 @@ def maybe_get_cuda_graph(self, batch: ScheduledRequests):
     def needs_capture(self, batch_size: int):
         return (batch_size, self.draft_len) not in self.graph_outputs
 
-    def capture(self, batch_size: int, forward_fn: Callable,
-                initial_inputs: Dict[str, Any]):
+    def capture(self,
+                batch_size: int,
+                forward_fn: Callable,
+                initial_inputs: Dict[str, Any],
+                postprocess_fn: Optional[Callable] = None):
         """Captures the forward pass for a given batch size."""
         engine = self._get_engine()
         key = (batch_size, self.draft_len)
@@ -181,8 +184,12 @@ def capture(self, batch_size: int, forward_fn: Callable,
         with with_multi_stream(True), piecewise_cuda_graph(False):
             for _ in range(self.WARMUP_STEPS):
                 forward_fn(capture_inputs)
+                if postprocess_fn is not None:
+                    postprocess_fn(capture_inputs)
             with torch.cuda.graph(graph, pool=self.memory_pool):
                 output = forward_fn(capture_inputs)
+            if postprocess_fn is not None:
+                postprocess_fn(capture_inputs)
 
         self.graphs[key] = graph
         self.graph_outputs[key] = make_weak_ref(output)

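capture() now accepts an optional postprocess_fn that runs after each warmup forward and once more after the graph has been recorded, so capture-time executions do not leave the shared input buffers mutated. A toy, CUDA-free sketch that mirrors only the call ordering added above; it is not the real runner:

    # Toy mirror of the warmup/capture/postprocess ordering.
    from typing import Any, Callable, Dict, Optional

    WARMUP_STEPS = 3

    def capture(forward_fn: Callable, inputs: Dict[str, Any],
                postprocess_fn: Optional[Callable] = None):
        for _ in range(WARMUP_STEPS):
            forward_fn(inputs)
            if postprocess_fn is not None:
                postprocess_fn(inputs)   # undo in-place input mutations between warmup runs
        output = forward_fn(inputs)      # the real runner captures this call into a CUDA graph
        if postprocess_fn is not None:
            postprocess_fn(inputs)       # leave the inputs unchanged after capture as well
        return output

    def toy_forward(x):
        x["position_ids"].append(99)     # pretend the forward mutates its inputs in place
        return len(x["position_ids"])

    def toy_postprocess(x):
        x["position_ids"].pop()          # reverse the mutation

    inputs = {"position_ids": [0, 1, 2, 3]}
    capture(toy_forward, inputs, toy_postprocess)
    print(inputs)  # {'position_ids': [0, 1, 2, 3]} -- restored after every capture-time call
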
tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 25 additions & 1 deletion
@@ -1133,6 +1133,26 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]):
                     self.previous_kv_lens_offsets_cuda[:num_gen_requests])
         return inputs
 
+    def _postprocess_inputs(self, inputs: Dict[str, Any]):
+        """
+        Postprocess to make sure model forward doesn't change the inputs.
+        It is only used in cuda graph capture, because other cases will prepare
+        new inputs before the model forward.
+        """
+        if self.enable_spec_decode and not self._disable_overlap_scheduler:
+            if inputs['attn_metadata'].kv_cache_manager is not None:
+                num_seqs = inputs['attn_metadata'].num_seqs
+                num_ctx_requests = inputs['attn_metadata'].num_contexts
+                num_gen_requests = inputs['attn_metadata'].num_generations
+                num_ctx_tokens = inputs['attn_metadata'].num_ctx_tokens
+                previous_batch_tokens = inputs['input_ids'].shape[
+                    0] - num_ctx_tokens
+                inputs['position_ids'][0, num_ctx_tokens:] -= (
+                    self.previous_pos_id_offsets_cuda[:previous_batch_tokens])
+                inputs['attn_metadata'].kv_lens_cuda[
+                    num_ctx_requests:num_seqs] -= (
+                        self.previous_kv_lens_offsets_cuda[:num_gen_requests])
+
     def _get_all_rank_num_tokens(self, attn_metadata: AttentionMetadata):
         if self.enable_attention_dp:
             return list(self.dist.tp_allgather(attn_metadata.num_tokens))
@@ -2206,8 +2226,12 @@ def capture_forward_fn(inputs: Dict[str, Any]):
                     gather_ids=gather_ids,
                     gather_context_logits=gather_context_logits)
 
+            def capture_postprocess_fn(inputs: Dict[str, Any]):
+                self._postprocess_inputs(inputs)
+
             self.cuda_graph_runner.capture(batch_size,
-                                           capture_forward_fn, inputs)
+                                           capture_forward_fn, inputs,
+                                           capture_postprocess_fn)
 
             # here we don't need to use context since cuda graph capture didn't run kernel.
             # maybe we need a cleaner way to do this.

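_postprocess_inputs subtracts the same previous-batch offsets that _preprocess_inputs applies to position_ids and kv_lens_cuda, so repeated capture-time forwards keep seeing the original values. A minimal sketch of that symmetry with made-up tensors; it presumes the preprocess side performs an in-place add, which this hunk only shows the reverse of:

    # Made-up values illustrating the add/subtract symmetry around a capture-time forward.
    import torch

    position_ids = torch.arange(8)
    offsets = torch.ones(4, dtype=torch.long)

    position_ids[4:] += offsets   # preprocess: shift the generation-token positions
    # ... forward pass would run here during warmup or graph capture ...
    position_ids[4:] -= offsets   # postprocess: restore the buffer for the next capture-time call

    assert torch.equal(position_ids, torch.arange(8))
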
tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 36 additions & 0 deletions
@@ -1871,6 +1871,42 @@ def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
         # task.evaluate(llm,
         #               extra_evaluator_kwargs=dict(apply_chat_template=True))
 
+    def test_nvfp4_multi_gpus_corner_case(self):
+        """
+        This test is used to test the corner case of the NVFP4 model.
+        When using the same value for max_seq_len and max_num_tokens, there will be no
+        enough kv block for the dummy requests in CUDA graph warmup when creating
+        the py_executor before estimating kv cache. Then CUDA graph capture will be
+        triggered when estimating kv cache. This may cause some errors.
+        More info in https://nvbugs/5485325.
+        """
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.80,
+                                        dtype="fp8",
+                                        enable_block_reuse=False)
+        pytorch_config = dict(disable_overlap_scheduler=False,
+                              cuda_graph_config=CudaGraphConfig(
+                                  enable_padding=True, max_batch_size=1024),
+                              moe_config=MoeConfig(backend="TRTLLM"))
+
+        mtp_config = MTPDecodingConfig(num_nextn_predict_layers=1)
+        with LLM(f"{llm_models_root()}/DeepSeek-R1/DeepSeek-R1-FP4",
+                 tensor_parallel_size=8,
+                 pipeline_parallel_size=1,
+                 moe_expert_parallel_size=8,
+                 kv_cache_config=kv_cache_config,
+                 **pytorch_config,
+                 enable_attention_dp=False,
+                 speculative_config=mtp_config,
+                 max_seq_len=5120,
+                 max_num_tokens=5120) as llm:
+
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_mpi_world_size(8)
     @skip_pre_hopper
     @pytest.mark.parametrize(
