@@ -20,7 +20,7 @@
 #include "tensorrt_llm/common/cudaDriverWrapper.h"
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/logger.h"
-#include "tensorrt_llm/runtime/utils/mpiUtils.h"
+
 #include <cstddef>
 #include <cstdint>
 #include <cuda_runtime_api.h>
@@ -38,7 +38,7 @@ inline size_t roundUp(size_t val, size_t gran)
 } // namespace

 McastDeviceMemory::McastDeviceMemory(
-    size_t bufSize, uint32_t groupSize, uint32_t groupRank, int deviceIdx, bool mnNvlink)
+    size_t bufSize, uint32_t groupSize, uint32_t groupRank, uint32_t splitColor, int deviceIdx, bool mnNvlink)
     : mIsMNNvlink(mnNvlink)
     , mDeviceIdx(deviceIdx)
     , mGroupSize(groupSize)
@@ -48,6 +48,7 @@ McastDeviceMemory::McastDeviceMemory(
     , mAllocationSize(0)
     , mMcPtr(0)
     , mMcHandle(0)
+    , mGroupComm(tensorrt_llm::mpi::MpiComm::session().split(splitColor, mGroupRank))
 {

     TLLM_CUDA_CHECK(cudaSetDevice(mDeviceIdx));
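// Illustration (editor's sketch, not part of the diff): the new mGroupComm member is built by
// splitting the session communicator. Assuming MpiComm::split() wraps MPI_Comm_split, every rank
// that passes the same splitColor lands in the same sub-communicator, ordered by the key argument
// (here mGroupRank). A hypothetical plain-MPI equivalent:

#include <cstdint>
#include <mpi.h>

MPI_Comm makeGroupComm(uint32_t splitColor, uint32_t groupRank)
{
    MPI_Comm groupComm = MPI_COMM_NULL;
    // Ranks sharing 'splitColor' form one communicator; 'groupRank' fixes their order within it.
    MPI_Comm_split(MPI_COMM_WORLD, static_cast<int>(splitColor), static_cast<int>(groupRank), &groupComm);
    return groupComm;
}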
@@ -62,9 +63,12 @@ McastDeviceMemory::McastDeviceMemory(
     // From pytorch implementation for alignment
     constexpr size_t kSignalPadAlignment = 16UL;
     mSignalPadOffset = roundUp(mBufSize, kSignalPadAlignment);
+    int const world_rank{tensorrt_llm::mpi::MpiComm::session().getRank()};
+
     TLLM_LOG_DEBUG(
-        "[McastDeviceMemory] Rank: %u, Group size: %u, isMultiNode: %d, device_idx: %d, Signal pad offset: %zu",
-        mGroupRank, mGroupSize, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);
+        "[McastDeviceMemory] World Rank: %u, Group Rank: %u, Group size: %u, GroupSplitColor: %u, isMultiNode: %d, "
+        "device_idx: %d, Signal pad offset: %zu",
+        world_rank, mGroupRank, mGroupSize, splitColor, mIsMNNvlink, mDeviceIdx, mSignalPadOffset);

     if (mIsMNNvlink)
     {
@@ -127,9 +131,6 @@ McastDeviceMemory::~McastDeviceMemory()

 void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
 {
-
-    auto const& mpi_comm = tensorrt_llm::mpi::MpiComm::session();
-
     CUmemAllocationHandleType const handle_type = CU_MEM_HANDLE_TYPE_FABRIC;
     CUmemAllocationProp prop = {};
     prop.requestedHandleTypes = handle_type;
@@ -156,7 +157,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
     // All gather
     cudaMallocHost(&exphndl, mGroupSize * sizeof(CUmemFabricHandle));
     memcpy(exphndl + mGroupRank * sizeof(CUmemFabricHandle), &myhndl, sizeof(CUmemFabricHandle));
-    mpi_comm.allgather(
+    mGroupComm.allgather(
         exphndl + mGroupRank * sizeof(CUmemFabricHandle), exphndl, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR);
     cudaDeviceSynchronize();

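// Illustration (editor's sketch, not part of the diff): the exchange above, expressed with plain
// MPI. Each rank of the split communicator contributes its own CUmemFabricHandle and receives one
// per group rank; after the switch from the session communicator to mGroupComm, only the ranks of
// this multicast group participate in the collective.

#include <cuda.h>
#include <mpi.h>

void allGatherFabricHandles(MPI_Comm groupComm, CUmemFabricHandle const& mine, CUmemFabricHandle* all)
{
    // Every rank sends sizeof(CUmemFabricHandle) bytes and fills one slot per group rank in 'all'.
    MPI_Allgather(&mine, sizeof(CUmemFabricHandle), MPI_BYTE, all, sizeof(CUmemFabricHandle), MPI_BYTE, groupComm);
}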
@@ -175,7 +176,7 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)
         TLLM_CU_CHECK(cuMemExportToShareableHandle((void*) fabric_handle, mMcHandle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
     }
     // Broadcast
-    mpi_comm.bcast(fabric_handle, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR, 0);
+    mGroupComm.bcast(fabric_handle, sizeof(CUmemFabricHandle), mpi::MpiType::kCHAR, 0);
     cudaDeviceSynchronize();
     if (mGroupRank != 0)
     {
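// Illustration (editor's sketch, not part of the diff): rank 0 of the group creates and exports
// the multicast object; the fabric handle is then broadcast inside the group so the remaining
// ranks can import it. With mGroupComm the root is rank 0 of this group rather than world rank 0,
// which is what lets several groups run this setup concurrently. Plain-MPI equivalent:

#include <cuda.h>
#include <mpi.h>

void broadcastMcHandle(MPI_Comm groupComm, CUmemFabricHandle* handle)
{
    MPI_Bcast(handle, sizeof(CUmemFabricHandle), MPI_BYTE, /*root=*/0, groupComm);
}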
@@ -210,12 +211,9 @@ void McastDeviceMemory::allocMnMcastMem(size_t bufSize)

 void McastDeviceMemory::allocNvlsMcastMem(size_t bufSize)
 {
-    // Create a std::set to include all ranks in range (0, group_size)
-    std::set<int> ranks;
-    for (uint32_t i = 0; i < mGroupSize; ++i)
-    {
-        ranks.insert(i);
-    }
+    // Get the world ranks for ranks in this group
+    auto ranks_ = tensorrt_llm::mpi::getWorldRanks(mGroupComm);
+    std::set<int> ranks(ranks_.begin(), ranks_.end());
     // Reuse existing implementation
     mNvlsHandle = tensorrt_llm::runtime::ipcNvlsAllocate(bufSize, ranks);
     mMcHandle = mNvlsHandle->mc_handle;
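// Illustration (editor's sketch, not part of the diff): ipcNvlsAllocate() expects world ranks,
// while mGroupComm numbers its members 0..groupSize-1, so the group's ranks must be translated
// back to MPI_COMM_WORLD. Assuming getWorldRanks() behaves like MPI_Group_translate_ranks:

#include <mpi.h>
#include <vector>

std::vector<int> worldRanksOf(MPI_Comm groupComm)
{
    int groupSize = 0;
    MPI_Comm_size(groupComm, &groupSize);

    MPI_Group worldGroup;
    MPI_Group group;
    MPI_Comm_group(MPI_COMM_WORLD, &worldGroup);
    MPI_Comm_group(groupComm, &group);

    std::vector<int> groupRanks(groupSize);
    std::vector<int> worldRanks(groupSize);
    for (int i = 0; i < groupSize; ++i)
    {
        groupRanks[i] = i;
    }
    // Map each rank of 'group' onto its rank in MPI_COMM_WORLD.
    MPI_Group_translate_ranks(group, groupSize, groupRanks.data(), worldGroup, worldRanks.data());

    MPI_Group_free(&group);
    MPI_Group_free(&worldGroup);
    return worldRanks;
}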