
Commit c0907a5

add some comments
Signed-off-by: Chuang Zhu <[email protected]>
1 parent 399252d commit c0907a5

8 files changed: 60 additions & 36 deletions


cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 37 additions & 12 deletions
@@ -75,7 +75,6 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques
bool CacheFormatter::needSendCache(
    CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx)
{
-    // int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
    auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
    if (targetInfo.mDupHeadFactor <= 1)
    {
@@ -91,12 +90,17 @@ bool CacheFormatter::needSendCache(
        selfTpRankInDpGroup = selfTpRank % selfTPNumInDPGroup;
    }

+    // Only TP ranks with TP rank % dupHeadFactor == 0 need to send the cache.
    return selfTpRankInDpGroup % targetInfo.mDupHeadFactor == 0;
}

void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::CacheState const& selfConfig,
    BaseCacheFormatter::CacheState const& destConfig)
{
+    // TODO: VSWA does not support uneven layers per PP.
+    // If gen PP and context PP differ, the cache formatter only supports an alternating-window layout like gpt-oss,
+    // where one layer uses SWA and the next uses full attention.
+
    auto numPools = cacheManager->getBlockManager().getNumPools();
    auto layerNum = cacheManager->getBlockManager().getNumLayers();

@@ -163,6 +167,7 @@ void CacheFormatter::format(TransferSession& session)
    auto const& destConfig = session.getOtherState().getCacheState().value();
    auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx();
    auto& bufferManager = session.getBufferManager();
+    // Some TP ranks don't need to send the cache, since duplicated heads are not needed.
    if (!needSendCache(selfConfig, destConfig, selfIdx))
    {
        return;
@@ -214,7 +219,7 @@ void CacheFormatter::format(TransferSession& session)
    int blockNum = 0;

    size_t allCacheBlockSize = 0;
-
+    // Gather the cache blocks of the request.
    std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> inputKvCacheBlocks;
    for (auto poolIdx = 0; poolIdx < numPools; poolIdx++)
    {
@@ -224,6 +229,7 @@ void CacheFormatter::format(TransferSession& session)
            "window size already exists, which is not supported");
        inputKvCacheBlocks.emplace(window, std::vector<runtime::ITensor::SharedPtr>());
        auto maxBlockThisWindow = window / selfConfig.getModelConfig().mTokensPerBlock;
+        // Only blocks within the window will be sent.
        SizeType32 blockNumThisWindow = 0;
        for (auto it = blockRange.begin(); it != blockRange.end(); ++it)
        {
@@ -278,6 +284,14 @@ void CacheFormatter::format(TransferSession& session)
        return;
    }

+    // Formatter flow:
+    // 1. Gather the cache blocks of the request.
+    // 2. Compute the buffer size for each target.
+    // 3. Prepare the pre-allocated buffer for each target according to the buffer size.
+    // 4. Call splitKVCacheDispatch to split the cache blocks according to the different parallelism and gather the
+    //    cache blocks into the corresponding buffers.
+    // 5. Send the buffer to the corresponding target. Ideally, we send only once (one buffer) per target.
+
    auto cacheBufferId = mCacheTransBufferManager->assignBufferIndexForSend();
    int peerDuplicateHeadFactor = targetInfo.mPeerDupHeadFactor;
    auto targetNum = connections.size();
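
To make the five-step formatter flow above concrete, here is a minimal C++ sketch of the same shape. All helper names and numbers are made up for illustration; the real implementation uses splitKVCacheDispatch, the transfer-session connections, and the CacheTransBufferManager shown in this diff.

// Minimal sketch of the formatter flow described above (hypothetical helpers, not the real API).
#include <cstddef>
#include <cstdio>
#include <numeric>
#include <vector>

using Block = std::vector<float>;

int main()
{
    // 1. Gather the cache blocks of the request (here: 4 blocks of 8 elements each).
    std::vector<Block> blocks(4, Block(8, 1.0f));

    // 2. Compute the buffer size for each target; targets with more layers get proportionally more elements.
    std::vector<size_t> layersPerTarget{2, 6}; // uneven layers per PP rank
    size_t totalElems = blocks.size() * blocks[0].size();
    size_t totalLayers = std::accumulate(layersPerTarget.begin(), layersPerTarget.end(), size_t{0});

    // 3. Prepare one pre-allocated buffer per target according to the buffer size.
    std::vector<std::vector<float>> sendBuffers;
    for (size_t layers : layersPerTarget)
    {
        sendBuffers.emplace_back(totalElems * layers / totalLayers);
    }

    // 4. Split the gathered blocks into the per-target buffers (stand-in for splitKVCacheDispatch).
    size_t cursor = 0;
    for (auto& buf : sendBuffers)
    {
        for (float& v : buf)
        {
            v = blocks[cursor / 8][cursor % 8];
            ++cursor;
        }
    }

    // 5. "Send" each buffer exactly once per target.
    for (size_t i = 0; i < sendBuffers.size(); ++i)
    {
        std::printf("send buffer %zu with %zu elements\n", i, sendBuffers[i].size());
    }
    return 0;
}
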
@@ -286,7 +300,7 @@ void CacheFormatter::format(TransferSession& session)
    int selfAttentionLayerNum
        = selfConfig.getParallelConfig()
              .mAttentionLayerNumPerPP[selfIdx / selfConfig.getParallelConfig().mTensorParallelism];
-
+    // Since the layer num per PP rank may differ, we need to compute the buffer size for each target.
    auto getBufferSizeForTarget = [&]()
    {
        std::vector<size_t> bufferSizeForTarget(targetNum, 0);
@@ -419,7 +433,7 @@ void CacheFormatter::format(TransferSession& session)
        }
        else
        {
-            // concurrency num
+            // The concurrency num should be <= bufferCoverTargetNum to avoid data races.
            auto concurrencyNum
                = std::min(std::max(static_cast<size_t>(1), bufferCoverTargetNum), connections.size());

@@ -505,6 +519,7 @@ void CacheFormatter::unformat(TransferSession& session)
    TLLM_CHECK(!outputBuffersPerWindow.empty());
    if (outputBuffersPerWindow.size() > 1)
    {
+        // We only support a limited set of cases for VSWA.
        if (selfConfig.getParallelConfig().mPipelineParallelism != destConfig.getParallelConfig().mPipelineParallelism)
        {
            checkAlternateWindow(mCacheManager, selfConfig, destConfig);
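
The "limited case for VSWA" above relies on checkAlternateWindow, which accepts only an alternating-window layout when context and gen pipeline parallelism differ. A rough sketch of that alternating-pattern check, under the assumption that exactly two window sizes alternate layer by layer (the real check inspects the cache manager's pools and layers):

#include <cassert>
#include <vector>

// Returns true if the per-layer window sizes strictly alternate between two values,
// e.g. {SWA, full, SWA, full, ...} as in gpt-oss style models (illustrative only).
bool isAlternatingWindow(std::vector<int> const& windowPerLayer)
{
    if (windowPerLayer.size() < 2)
    {
        return true;
    }
    int a = windowPerLayer[0];
    int b = windowPerLayer[1];
    if (a == b)
    {
        return false;
    }
    for (size_t i = 0; i < windowPerLayer.size(); ++i)
    {
        if (windowPerLayer[i] != (i % 2 == 0 ? a : b))
        {
            return false;
        }
    }
    return true;
}

int main()
{
    assert(isAlternatingWindow({128, 1 << 30, 128, 1 << 30}));  // SWA / full attention, alternating
    assert(!isAlternatingWindow({128, 128, 1 << 30, 1 << 30})); // grouped, not alternating
    return 0;
}
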
@@ -603,6 +618,13 @@ void CacheFormatter::unformat(TransferSession& session)
            ctxReqId);
        return;
    }
+    // Unformat flow:
+    // 1. Gather the cache blocks of the request.
+    // 2. Compute the buffer size for each target.
+    // 3. Prepare the pre-allocated buffer for each target according to the buffer size.
+    // 4. Receive the buffer from the corresponding target. Ideally, we receive only once (one buffer) for each
+    //    target.
+    // 5. Call concatKvCacheV2Dispatch to concatenate the cache blocks according to the different parallelism.

    runtime::ITensor::SharedPtr recvBufferTemp;
    std::vector<runtime::ITensor::SharedPtr> recvSplitCaches;
@@ -615,7 +637,7 @@ void CacheFormatter::unformat(TransferSession& session)
    int selfAttentionLayerNum
        = selfConfig.getParallelConfig()
              .mAttentionLayerNumPerPP[selfIdx / selfConfig.getParallelConfig().mTensorParallelism];
-    auto getTargetBufferEleSzie = [&]()
+    auto getTargetBufferEleSize = [&]()
    {
        if (outputBuffersPerWindow.size() > 1)
        {
@@ -627,14 +649,17 @@
            // TODO: LayerNumbufferTargetNum for VWSA
            return std::make_pair(bufferSizeForTarget, std::vector<SizeType32>(targetNum, 0));
        }
-        size_t valideTpSize = pickUpConnections.size() / targetInfo.mDomainPPSize;
-        TLLM_CHECK_WITH_INFO(cacheBlockSizeSum % valideTpSize == 0,
-            "cacheBlockSizeSum must be divisible by valideTpSize %ld", valideTpSize);
-        TLLM_CHECK_WITH_INFO((cacheBlockSizeSum % (selfAttentionLayerNum * valideTpSize)) == 0,
-            "cacheBlockSizeSum must be divisible by valideTpSize %ld * selfAttentionLayerNum %d", valideTpSize,
+        // For duplicated heads, gen will not receive from the TP ranks that hold duplicated heads, and will not
+        // prepare buffers for them.
+        size_t validTpSize = pickUpConnections.size() / targetInfo.mDomainPPSize;
+        TLLM_CHECK_WITH_INFO(cacheBlockSizeSum % validTpSize == 0,
+            "cacheBlockSizeSum must be divisible by validTpSize %ld", validTpSize);
+        TLLM_CHECK_WITH_INFO((cacheBlockSizeSum % (selfAttentionLayerNum * validTpSize)) == 0,
+            "cacheBlockSizeSum must be divisible by validTpSize %ld * selfAttentionLayerNum %d", validTpSize,
            selfAttentionLayerNum);
        TLLM_CHECK(targetNum == pickUpConnections.size());
-        size_t baseEleSize = cacheBlockSizeSum / (valideTpSize * selfAttentionLayerNum);
+        // The sum of the buffer sizes is cacheBlockSizeSum.
+        size_t baseEleSize = cacheBlockSizeSum / (validTpSize * selfAttentionLayerNum);

        std::vector<size_t> bufferEleSizes(targetNum, 0);
        std::vector<SizeType32> LayerNumbufferTargetNum(targetNum, 0);
@@ -647,7 +672,7 @@
        }
        return std::make_pair(bufferEleSizes, LayerNumbufferTargetNum);
    };
-    auto [bufferEleSizes, LayerNumbufferTargetNum] = getTargetBufferEleSzie();
+    auto [bufferEleSizes, LayerNumbufferTargetNum] = getTargetBufferEleSize();

    size_t remainNoCoverTargetNum = 0;
    size_t bufferCoverTargetNum = 0;
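
The buffer sizing in getTargetBufferEleSize can be reproduced in isolation: divide the total gathered element count by the number of valid TP peers and the local attention-layer count to get a per-layer base size, then give each PP peer a buffer proportional to the layers it owns. A hedged sketch with made-up numbers, omitting the VSWA branch:

#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    // Made-up example: 2 valid TP peers x 2 PP peers, 4 local attention layers, 1024 elements in total.
    size_t cacheBlockSizeSum = 1024;
    size_t validTpSize = 2;
    int selfAttentionLayerNum = 4;
    std::vector<int> layerNumPerPeerPP{1, 3}; // uneven layers per peer PP rank

    assert(cacheBlockSizeSum % (validTpSize * selfAttentionLayerNum) == 0);
    size_t baseEleSize = cacheBlockSizeSum / (validTpSize * selfAttentionLayerNum); // elements per layer per TP peer

    // One buffer per (TP peer, PP peer) target, sized by the layers that PP peer owns.
    std::vector<size_t> bufferEleSizes;
    for (size_t tp = 0; tp < validTpSize; ++tp)
    {
        for (int layers : layerNumPerPeerPP)
        {
            bufferEleSizes.push_back(baseEleSize * static_cast<size_t>(layers));
        }
    }

    size_t total = 0;
    for (size_t s : bufferEleSizes)
    {
        total += s;
    }
    assert(total == cacheBlockSizeSum); // the buffer sizes sum back to cacheBlockSizeSum
    return 0;
}
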

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp

Lines changed: 1 addition & 0 deletions
@@ -369,6 +369,7 @@ void MLACacheFormatter::unformat(TransferSession& session)
    auto selfAttentionLayerNum
        = selfConfig.getParallelConfig()
              .mAttentionLayerNumPerPP[selfIdx / selfConfig.getParallelConfig().mTensorParallelism];
+    TLLM_CHECK_WITH_INFO(selfAttentionLayerNum != 0, "selfAttentionLayerNum should not be 0");
    auto getBufferSizeForTarget = [&]()
    {
        std::vector<size_t> bufferEleSizes(targetNum, 0);

cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp

Lines changed: 9 additions & 3 deletions
@@ -35,8 +35,14 @@ std::string genUniqueAgentName()
    return std::string(hostname) + "_" + std::to_string(pid) + "_" + std::to_string(counter++);
}

+// The NIXL connection is special and differs from the UCX and MPI connections, since NIXL only supports one-sided
+// communication. Gen sends the buffer metadata to context when it sends the requestInfo, but does not send the buffer
+// offset, since the unformatter has not been called yet and the cache size and offset are not known. We assume the
+// recv_size is the same as the send_size, and compute the buffer offset according to the layer num of the self PP
+// rank and the previous PP ranks' layer nums, since the buffer size ratio equals the layer num ratio except in the
+// VSWA case.
+
auto computeSendOffsetRatio(
-    CacheState const& peerCacheState, size_t peerIdx, CacheState const& selfCacheState, int valideConnectionIdx)
+    CacheState const& peerCacheState, int peerIdx, CacheState const& selfCacheState, int valideConnectionIdx)
{
    auto peerTargetInfo = targetIRanks(selfCacheState, peerCacheState, peerIdx);
    // int ppRank = valideConnectionIdx % peerTargetInfo.mDomainPPSize;
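
Since the buffer size per PP rank is proportional to its layer count (outside the VSWA case), the sender can derive the receive-side offset purely from layer counts. A small illustrative sketch of that prefix-sum idea; the function name and numbers are hypothetical, and the real computeSendOffsetRatio also folds in the target-rank mapping:

#include <cstdio>
#include <vector>

// Offset of a PP rank's slice, expressed as a fraction of the whole receive buffer.
// Assumes the buffer size per PP rank is proportional to its layer count (non-VSWA case).
double sendOffsetRatio(std::vector<int> const& layerNumPerPP, int ppRank)
{
    int total = 0;
    for (int n : layerNumPerPP)
    {
        total += n;
    }
    int prefix = 0;
    for (int r = 0; r < ppRank; ++r)
    {
        prefix += layerNumPerPP[r];
    }
    return static_cast<double>(prefix) / static_cast<double>(total);
}

int main()
{
    std::vector<int> layerNumPerPP{15, 16}; // e.g. uneven layers per PP rank
    for (int r = 0; r < 2; ++r)
    {
        std::printf("PP rank %d starts at ratio %.3f of the buffer\n", r, sendOffsetRatio(layerNumPerPP, r));
    }
    return 0;
}
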
@@ -352,7 +358,7 @@ batch_manager::kv_cache_manager::CacheTransBufferManager* AgentConnectionManager
    return mCacheTransBufferManager;
}

-AgentConnection* AgentConnectionManager::connect(std::string const& remoteAgentName, std::string const& connecitonInfo,
+AgentConnection* AgentConnectionManager::connect(std::string const& remoteAgentName, std::string const& connectionInfo,
    std::optional<std::string> metadata, bool isSender)
{

@@ -393,7 +399,7 @@ AgentConnection* AgentConnectionManager::connect(std::string const& remoteAgentN
            TLLM_CHECK_WITH_INFO(!isSender, "Sender shouldn't call connectRemoteAgent");
            TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "mAgentName: %s connect to %s with connectRemoteAgent",
                mAgentName.c_str(), remoteAgentName.c_str());
-            m_Agent->connectRemoteAgent(remoteAgentName, connecitonInfo);
+            m_Agent->connectRemoteAgent(remoteAgentName, connectionInfo);
        }
    }
    else

cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu

Lines changed: 5 additions & 14 deletions
@@ -69,6 +69,9 @@ TargetRanksInfo TargetRanksInfoForDP(
    TLLM_CHECK(peerNumLayerPerPP.size() == peerPPNum);
    TLLM_CHECK(selfNumLayerPerPP.size() == selfPPNum);
    int selfStartLayerId = 0;
+    // Global start layer id for the self PP rank, which is the sum of the layer nums of the previous PP ranks.
+    // Compute the target PP ranks and the layer num to fetch from each target PP rank, according to
+    // [global start layer id, global end layer id).

    for (int pp_rank = 0; pp_rank < selfPPRank; pp_rank++)
    {
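
The comment above boils down to intersecting the self rank's global layer interval [start, end) with each peer PP rank's interval to decide which peers to fetch from and how many layers from each. A minimal sketch with made-up layer counts (the real TargetRanksInfoForDP also handles the TP and DP mapping):

#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    // Hypothetical example: self side has 2 PP ranks, peer side has 4 PP ranks.
    std::vector<int> selfNumLayerPerPP{16, 16};
    std::vector<int> peerNumLayerPerPP{8, 8, 8, 8};
    int selfPPRank = 1;

    // Global start layer id of the self PP rank = sum of the layer nums of the previous self PP ranks.
    int selfStart = 0;
    for (int r = 0; r < selfPPRank; ++r)
    {
        selfStart += selfNumLayerPerPP[r];
    }
    int selfEnd = selfStart + selfNumLayerPerPP[selfPPRank]; // [selfStart, selfEnd)

    // Intersect with each peer PP rank's global layer interval.
    int peerStart = 0;
    for (size_t peer = 0; peer < peerNumLayerPerPP.size(); ++peer)
    {
        int peerEnd = peerStart + peerNumLayerPerPP[peer];
        int overlap = std::max(0, std::min(selfEnd, peerEnd) - std::max(selfStart, peerStart));
        if (overlap > 0)
        {
            std::printf("fetch %d layers from peer PP rank %zu\n", overlap, peer);
        }
        peerStart = peerEnd;
    }
    return 0;
}
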
@@ -515,8 +518,6 @@ nvinfer1::Dims makeShapeFromCacheState(kv_cache::CacheState const& cacheState)
        cacheState.getAttentionConfig().mKvFactor, blockSize});
}

-// MLA Head 1: One thread block per [(2), tokens, dimsPerHead]
-
__device__ __forceinline__ void getLayerIdInDomainPPandRankInDomainPP(int layerId, int DomainPPSize,
    uint64_t* prefixLayerNumDevPtr, int& layerIdInDomainPP, int& rankInDomainPP, int& layerNumInSpecPP)
{
@@ -542,6 +543,8 @@ __device__ __forceinline__ void getLayerIdInDomainPPandRankInDomainPP(int layerI
    layerNumInSpecPP = sharedLayerNumInSpecPP;
}

+// MLA Head 1: One thread block per [(2), tokens, dimsPerHead]
+
template <typename T, int subWarpSize, int vecSizeByte>
__global__ void splitKVCacheForMLAKernel(T const** __restrict__ inputBlocks, T** __restrict__ outputCaches,
    int tokensPerBlock, int numLayers, int headNum, int dimsPerHead, int inputBlockNum, int DomainPPSize,
@@ -638,19 +641,12 @@ __global__ void splitKVCacheKernel(T const** __restrict__ inputBlocks, T** __res
    for (int layerId = blockIdx.x; layerId < numLayers; layerId += gridDim.x)
    {

-        // if(peer PPrank ==threadIdx.x; peerPPRank <DomainPPSize)
-        // if( layerId>xx[peeRank] &&layerId<xx[peerPPRank+1])
-        // peerPPrank , layerIdInDomainPP = layerId - xx[peerPPrank]
        int layerIdInDomainPP{};
        int rankInDomainPP{};
        int layerNumInSpecPP{};
        getLayerIdInDomainPPandRankInDomainPP(
            layerId, DomainPPSize, prefixLayerNumDevPtr, layerIdInDomainPP, rankInDomainPP, layerNumInSpecPP);

-        // if (threadIdx.x == 0){
-        //     printf("splitKVCacheKernel: layerId:%d, layerIdInDomainPP:%d, rankInDomainPP:%d,
-        //     layerNumInSpecPP:%d\n", layerId, layerIdInDomainPP, rankInDomainPP, layerNumInSpecPP);
-        // }
#pragma unroll 1

        for (int headId = subWarpGroupId; headId < headNum; headId += subWarpGroupNum)
@@ -893,11 +889,6 @@ __global__ void concatKVCacheKernel(T const** __restrict__ inputCaches, T** __re
        getLayerIdInDomainPPandRankInDomainPP(
            layerId, DomainPPSize, prefixLayerNumDevPtr, layerIdInDomainPP, rankInDomainPP, layerNumInSpecPP);

-        // if (threadIdx.x == 0){
-        //     printf("concatKVCacheKernel: layerId:%d, layerIdInDomainPP:%d, rankInDomainPP:%d,
-        //     layerNumInSpecPP:%d\n", layerId, layerIdInDomainPP, rankInDomainPP, layerNumInSpecPP);
-        // }
-
#pragma unroll 1
        for (int headId = subWarpGroupId; headId < headNum; headId += subWarpGroupNum)
        {

cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
#include <nanobind/stl/optional.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/unique_ptr.h>
+#include <nanobind/stl/vector.h>
#include <nanobind/trampoline.h>
#include <torch/extension.h>

cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp

Lines changed: 2 additions & 6 deletions
@@ -99,7 +99,7 @@ TEST_F(RequestInfoTest, Basic)
    }
    auto state = std::make_unique<texec::DataTransceiverState>();
    state->setCommState(texec::kv_cache::CommState{12, "127.0.0.1"});
-    state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 8, {4}, nvinfer1::DataType::kFLOAT});
+    state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 8, {10}, nvinfer1::DataType::kFLOAT});
    RequestInfo info{1, *state};
    auto info2 = serializeDeserialize(info);
    EXPECT_EQ(info, info2);
@@ -167,7 +167,7 @@ class MockDataSender : public DataSender
        ON_CALL(*this, recvRequestInfo)
            .WillByDefault(Return(RequestInfo{0,
                texec::DataTransceiverState{
-                    texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 1, {4}, nvinfer1::DataType::kFLOAT},
+                    texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 1, {10}, nvinfer1::DataType::kFLOAT},
                    texec::kv_cache::CommState{std::vector<SizeType32>{0}, 0}}}));
        ON_CALL(*this, getCounterpartsCount).WillByDefault(Return(1));
    }
@@ -983,8 +983,6 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
                startLayerId += mAttentionLayerNumPerPP[ppRank];
            }
        }
-        // TLLM_LOG_INFO(tensorrt_llm::mpi::MpiComm::world().getRank(), " fillBlockData startLayerId:%d
-        // layerSizethisRank:%d", startLayerId, layerSizeThisRank);
        int headSizePerRank = mCacheState->getModelConfig().mNbKvHeadsPerLayer.at(0);
        int startHeadId = headSizePerRank * (mTpRank / mDupHeadFactor);
        bool enableDP = mCacheState->getParallelConfig().mEnableAttentionDP;
@@ -1061,8 +1059,6 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
            }
        }

-        // TLLM_LOG_INFO(tensorrt_llm::mpi::MpiComm::world().getRank(), " verifyBlockData startLayerId:%d
-        // layerSizethisRank:%d", startLayerId, layerSizethisRank);
        int headSizePerRank = mCacheState->getModelConfig().mNbKvHeadsPerLayer.at(0);
        int startHeadId = headSizePerRank * (mTpRank / mDupHeadFactor);
        bool enableDP = mCacheState->getParallelConfig().mEnableAttentionDP;

tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py

Lines changed: 1 addition & 0 deletions
@@ -102,6 +102,7 @@ def __init__(self, mapping: Mapping, dist: Distributed,
        head_dim = kv_cache_manager.head_dim
        tokens_per_block = kv_cache_manager.tokens_per_block
        dtype = kv_cache_manager.dtype
+        # Get the layer num per PP rank, which is required by the cache transceiver.
        pp_layer_num = len(kv_cache_manager.pp_layers)
        pp_layer_num_per_pp_rank = dist.pp_allgather(pp_layer_num)
        self.impl = CacheTransceiverCpp(kv_cache_manager.impl,

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 4 additions & 1 deletion
@@ -716,7 +716,8 @@ def test_disaggregated_ctxpp4_genpp4(disaggregated_test_root, llm_venv,
                    cwd=llm_venv.get_working_directory())


-@pytest.mark.skip_less_device(4)
+# TinyLlama with pp4 will have an uneven number of layers per PP rank.
+@pytest.mark.skip_less_device(8)
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                         indirect=True)
def test_disaggregated_ctxpp4_gentp4(disaggregated_test_root, llm_venv,
@@ -802,12 +803,14 @@ def test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu_mtp(
                    cwd=llm_venv.get_working_directory())


+@pytest.mark.skip_less_device(4)
@skip_no_hopper
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'],
                         indirect=True)
def test_disaggregated_deepseek_v3_lite_fp8_ctxpp2_gentp2_one_mtp(
        disaggregated_test_root, disaggregated_example_root, llm_venv,
        deepseek_v3_model_root):
+    # Adding one MTP layer: PP rank 0 will have 15 layers, PP rank 1 will have 16 layers.
    src_dst_dict = {
        deepseek_v3_model_root:
        f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/fp8",
