@@ -75,7 +75,6 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques
 bool CacheFormatter::needSendCache(
     CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx)
 {
-    // int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
     auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
     if (targetInfo.mDupHeadFactor <= 1)
     {
@@ -91,12 +90,17 @@ bool CacheFormatter::needSendCache(
         selfTpRankInDpGroup = selfTpRank % selfTPNumInDPGroup;
     }
+    // Only TP ranks where rank % dupHeadFactor == 0 need to send the cache.
     return selfTpRankInDpGroup % targetInfo.mDupHeadFactor == 0;
 }
 
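As a rough illustration of the rule above (not code from this PR), the standalone sketch below assumes a hypothetical TP group of 8 ranks whose KV heads are duplicated with dupHeadFactor == 2, and prints which ranks actually send their cache:

// Illustrative sketch only; tpSize and dupHeadFactor are assumed values, not taken from the code.
#include <cstdio>

int main()
{
    int const tpSize = 8;        // hypothetical TP size inside one DP group
    int const dupHeadFactor = 2; // every KV head is held by 2 consecutive TP ranks
    for (int rank = 0; rank < tpSize; ++rank)
    {
        // Mirrors the selfTpRankInDpGroup % mDupHeadFactor == 0 check in needSendCache above.
        bool const sends = (rank % dupHeadFactor) == 0;
        std::printf("TP rank %d: %s\n", rank, sends ? "sends cache" : "skips (duplicate heads)");
    }
    return 0;
}

With these values, ranks 0, 2, 4, and 6 send; the other four hold only duplicated heads and return early from format().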
 void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::CacheState const& selfConfig,
     BaseCacheFormatter::CacheState const& destConfig)
 {
+    // TODO: VSWA does not support an uneven number of layers per PP rank.
+    // If the gen PP size and the context PP size differ, the cache formatter only supports an alternating
+    // window layout like gpt-oss, where one layer uses SWA and the next uses full attention.
+
     auto numPools = cacheManager->getBlockManager().getNumPools();
     auto layerNum = cacheManager->getBlockManager().getNumLayers();
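For context, a minimal standalone sketch of the alternating-window layout that the TODO refers to; the window values are assumptions for illustration and this is not the actual checkAlternateWindow logic:

// Assumed per-layer window sizes: SWA layers and full-attention layers strictly interleave.
#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    int const fullWindow = 1 << 20; // stand-in for "full attention"
    std::vector<int> const layerWindows = {128, fullWindow, 128, fullWindow}; // assumed per-layer windows
    // In a strict alternation, layer i and layer i + 2 share a window while layer i + 1 differs.
    for (std::size_t i = 0; i + 2 < layerWindows.size(); ++i)
    {
        assert(layerWindows[i] == layerWindows[i + 2]);
        assert(layerWindows[i] != layerWindows[i + 1]);
    }
    return 0;
}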
@@ -163,6 +167,7 @@ void CacheFormatter::format(TransferSession& session)
     auto const& destConfig = session.getOtherState().getCacheState().value();
     auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx();
     auto& bufferManager = session.getBufferManager();
+    // Some TP ranks do not need to send the cache, since their duplicated heads are not needed.
     if (!needSendCache(selfConfig, destConfig, selfIdx))
     {
         return;
@@ -214,7 +219,7 @@ void CacheFormatter::format(TransferSession& session)
     int blockNum = 0;
 
     size_t allCacheBlockSize = 0;
-
+    // gather cache blocks of the request.
     std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> inputKvCacheBlocks;
     for (auto poolIdx = 0; poolIdx < numPools; poolIdx++)
     {
@@ -224,6 +229,7 @@ void CacheFormatter::format(TransferSession& session)
             "window size already exists, which is not supported");
         inputKvCacheBlocks.emplace(window, std::vector<runtime::ITensor::SharedPtr>());
         auto maxBlockThisWindow = window / selfConfig.getModelConfig().mTokensPerBlock;
+        // Only blocks within the window will be sent.
         SizeType32 blockNumThisWindow = 0;
         for (auto it = blockRange.begin(); it != blockRange.end(); ++it)
         {
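A small worked example of the maxBlockThisWindow formula above, with assumed numbers (they do not come from the PR):

#include <cstdio>

int main()
{
    int const window = 4096;       // hypothetical attention window, in tokens
    int const tokensPerBlock = 64; // hypothetical mTokensPerBlock
    int const maxBlockThisWindow = window / tokensPerBlock;
    std::printf("at most %d blocks are sent for this window\n", maxBlockThisWindow); // 64
    return 0;
}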
@@ -278,6 +284,14 @@ void CacheFormatter::format(TransferSession& session)
         return;
     }
 
+    // Formatter flow:
+    // 1. Gather the cache blocks of the request.
+    // 2. Compute the buffer size for each target.
+    // 3. Prepare the pre-allocated buffer for each target according to that buffer size.
+    // 4. Call splitKVCacheDispatch to split the cache blocks according to the differing parallelism and gather
+    //    the cache blocks into the corresponding buffers.
+    // 5. Send the buffer to the corresponding target. Ideally, we send only once (one buffer) per target.
+
     auto cacheBufferId = mCacheTransBufferManager->assignBufferIndexForSend();
     int peerDuplicateHeadFactor = targetInfo.mPeerDupHeadFactor;
     auto targetNum = connections.size();
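The five steps in the new comment can be summarized by the standalone sketch below. Every type and helper in it (Block, Target, gatherBlocks, bufferSizes, splitAndPack, sendBuffer) is a simplified stand-in rather than the real TensorRT-LLM API; only the step order and the per-window grouping of blocks mirror the code:

#include <cstddef>
#include <map>
#include <vector>

struct Block {};  // one KV-cache block
struct Target {}; // one destination rank/connection

// step 1: gather the request's blocks, grouped by attention-window size
std::map<int, std::vector<Block>> gatherBlocks()
{
    return {{4096, {Block{}, Block{}}}};
}

// step 2: per-target buffer sizes (these may differ when layers per PP rank differ)
std::vector<std::size_t> bufferSizes(std::size_t targetNum)
{
    return std::vector<std::size_t>(targetNum, 1024);
}

// step 4: split blocks by parallelism and pack them into each target's buffer
void splitAndPack(std::map<int, std::vector<Block>> const&, std::vector<std::vector<char>>&) {}

// step 5: transfer one buffer to one target (ideally a single send per target)
void sendBuffer(Target const&, std::vector<char> const&) {}

void formatSketch(std::vector<Target> const& targets)
{
    auto blocks = gatherBlocks();
    auto sizes = bufferSizes(targets.size());
    std::vector<std::vector<char>> buffers; // step 3: pre-allocate one buffer per target
    for (auto size : sizes)
    {
        buffers.emplace_back(size);
    }
    splitAndPack(blocks, buffers);
    for (std::size_t i = 0; i < targets.size(); ++i)
    {
        sendBuffer(targets[i], buffers[i]);
    }
}

int main()
{
    formatSketch(std::vector<Target>(4));
    return 0;
}

Pre-allocating one buffer per target is what allows step 5 to be a single send per connection.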
@@ -286,7 +300,7 @@ void CacheFormatter::format(TransferSession& session)
     int selfAttentionLayerNum
         = selfConfig.getParallelConfig()
             .mAttentionLayerNumPerPP[selfIdx / selfConfig.getParallelConfig().mTensorParallelism];
-
+    // Since the layer count per PP rank may differ, we need to compute the buffer size for each target.
     auto getBufferSizeForTarget = [&]()
     {
         std::vector<size_t> bufferSizeForTarget(targetNum, 0);
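To see why the buffer size is computed per target, the sketch below assumes a peer with an uneven mAttentionLayerNumPerPP and scales each target's buffer by its layer count; the numbers and the proportional rule are illustrative assumptions, not the exact getBufferSizeForTarget logic:

#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<int> const peerLayersPerPP = {9, 9, 7}; // hypothetical mAttentionLayerNumPerPP on the peer
    std::size_t const bytesPerLayer = 2048;             // hypothetical per-layer cache footprint for this request
    for (std::size_t target = 0; target < peerLayersPerPP.size(); ++target)
    {
        // A target covering more attention layers needs a proportionally larger buffer.
        std::size_t const bufferSize = static_cast<std::size_t>(peerLayersPerPP[target]) * bytesPerLayer;
        std::printf("target %zu: %d layers -> %zu bytes\n", target, peerLayersPerPP[target], bufferSize);
    }
    return 0;
}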
@@ -419,7 +433,7 @@ void CacheFormatter::format(TransferSession& session)
         }
         else
         {
-            // concurrency num
+            // The concurrency num should be <= bufferCoverTargetNum to avoid data races.
             auto concurrencyNum
                 = std::min(std::max(static_cast<size_t>(1), bufferCoverTargetNum), connections.size());
 
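A worked example of the clamp above with assumed values: the number of in-flight sends never exceeds bufferCoverTargetNum (so a pre-allocated buffer is not reused while a transfer may still read it), but is at least 1 and at most the connection count:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main()
{
    std::size_t const connectionCount = 8; // hypothetical number of connections
    for (std::size_t bufferCoverTargetNum : {std::size_t(0), std::size_t(3), std::size_t(16)})
    {
        // Same formula as in the diff, applied to assumed inputs.
        auto const concurrencyNum
            = std::min(std::max(static_cast<std::size_t>(1), bufferCoverTargetNum), connectionCount);
        std::printf("bufferCoverTargetNum=%zu -> concurrencyNum=%zu\n", bufferCoverTargetNum, concurrencyNum);
    }
    return 0; // prints 1, 3, and 8 for the three cases
}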
@@ -505,6 +519,7 @@ void CacheFormatter::unformat(TransferSession& session)
     TLLM_CHECK(!outputBuffersPerWindow.empty());
     if (outputBuffersPerWindow.size() > 1)
     {
+        // We only support a limited set of cases for VSWA.
         if (selfConfig.getParallelConfig().mPipelineParallelism != destConfig.getParallelConfig().mPipelineParallelism)
         {
             checkAlternateWindow(mCacheManager, selfConfig, destConfig);
@@ -603,6 +618,13 @@ void CacheFormatter::unformat(TransferSession& session)
             ctxReqId);
         return;
     }
+    // Unformat flow:
+    // 1. Gather the cache blocks of the request.
+    // 2. Compute the buffer size for each target.
+    // 3. Prepare the pre-allocated buffer for each target according to that buffer size.
+    // 4. Receive the buffer from the corresponding target. Ideally, we receive only once (one buffer) per
+    //    target.
+    // 5. Call concatKvCacheV2Dispatch to concatenate the cache blocks according to the differing parallelism.
 
     runtime::ITensor::SharedPtr recvBufferTemp;
     std::vector<runtime::ITensor::SharedPtr> recvSplitCaches;
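As with format(), the receive side can be summarized by a standalone sketch; receiveBuffer and concatIntoCache below are simplified stand-ins (the real concatenation call is concatKvCacheV2Dispatch), and steps 1-3 are omitted because they mirror the send side:

#include <cstddef>
#include <vector>

struct Source {}; // one sending rank/connection

// step 4: ideally a single receive per source, into a buffer of the precomputed size
std::vector<char> receiveBuffer(Source const&, std::size_t expectedSize)
{
    return std::vector<char>(expectedSize);
}

// step 5: stand-in for concatKvCacheV2Dispatch, which rebuilds this rank's cache layout
void concatIntoCache(std::vector<std::vector<char>> const&) {}

void unformatSketch(std::vector<Source> const& sources, std::vector<std::size_t> const& bufferSizes)
{
    // Steps 1-3 (gather local blocks, size and pre-allocate buffers) mirror the send side and are omitted.
    std::vector<std::vector<char>> received;
    for (std::size_t i = 0; i < sources.size(); ++i)
    {
        received.push_back(receiveBuffer(sources[i], bufferSizes[i]));
    }
    concatIntoCache(received);
}

int main()
{
    unformatSketch(std::vector<Source>(2), {1024, 1024});
    return 0;
}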
@@ -615,7 +637,7 @@ void CacheFormatter::unformat(TransferSession& session)
     int selfAttentionLayerNum
         = selfConfig.getParallelConfig()
             .mAttentionLayerNumPerPP[selfIdx / selfConfig.getParallelConfig().mTensorParallelism];
-    auto getTargetBufferEleSzie = [&]()
+    auto getTargetBufferEleSize = [&]()
     {
         if (outputBuffersPerWindow.size() > 1)
         {
@@ -627,14 +649,17 @@ void CacheFormatter::unformat(TransferSession& session)
             // TODO: LayerNumbufferTargetNum for VWSA
             return std::make_pair(bufferSizeForTarget, std::vector<SizeType32>(targetNum, 0));
         }
-        size_t valideTpSize = pickUpConnections.size() / targetInfo.mDomainPPSize;
-        TLLM_CHECK_WITH_INFO(cacheBlockSizeSum % valideTpSize == 0,
-            "cacheBlockSizeSum must be divisible by valideTpSize %ld", valideTpSize);
-        TLLM_CHECK_WITH_INFO((cacheBlockSizeSum % (selfAttentionLayerNum * valideTpSize)) == 0,
-            "cacheBlockSizeSum must be divisible by valideTpSize %ld * selfAttentionLayerNum %d", valideTpSize,
+        // For duplicated heads, the gen side will not receive from TP ranks that hold duplicate heads, and will
+        // not prepare buffers for them.
+        size_t validTpSize = pickUpConnections.size() / targetInfo.mDomainPPSize;
+        TLLM_CHECK_WITH_INFO(cacheBlockSizeSum % validTpSize == 0,
+            "cacheBlockSizeSum must be divisible by validTpSize %ld", validTpSize);
+        TLLM_CHECK_WITH_INFO((cacheBlockSizeSum % (selfAttentionLayerNum * validTpSize)) == 0,
+            "cacheBlockSizeSum must be divisible by validTpSize %ld * selfAttentionLayerNum %d", validTpSize,
             selfAttentionLayerNum);
         TLLM_CHECK(targetNum == pickUpConnections.size());
-        size_t baseEleSize = cacheBlockSizeSum / (valideTpSize * selfAttentionLayerNum);
+        // The buffer sizes sum to cacheBlockSizeSum.
+        size_t baseEleSize = cacheBlockSizeSum / (validTpSize * selfAttentionLayerNum);
 
         std::vector<size_t> bufferEleSizes(targetNum, 0);
         std::vector<SizeType32> LayerNumbufferTargetNum(targetNum, 0);
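A worked example of the element-count arithmetic above, with assumed values: the divisibility checks guarantee that the gathered cache elements split evenly across the valid (non-duplicate) TP peers and this rank's attention layers:

#include <cassert>
#include <cstddef>
#include <cstdio>

int main()
{
    std::size_t const cacheBlockSizeSum = 1 << 20; // hypothetical total element count of the gathered blocks
    std::size_t const validTpSize = 2;             // plays the role of pickUpConnections.size() / mDomainPPSize
    int const selfAttentionLayerNum = 16;          // hypothetical attention layer count on this PP rank
    assert(cacheBlockSizeSum % validTpSize == 0);
    assert(cacheBlockSizeSum % (validTpSize * selfAttentionLayerNum) == 0);
    std::size_t const baseEleSize = cacheBlockSizeSum / (validTpSize * selfAttentionLayerNum);
    std::printf("baseEleSize = %zu elements per (TP peer, layer) slice\n", baseEleSize); // 32768 here
    return 0;
}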
@@ -647,7 +672,7 @@ void CacheFormatter::unformat(TransferSession& session)
         }
         return std::make_pair(bufferEleSizes, LayerNumbufferTargetNum);
     };
-    auto [bufferEleSizes, LayerNumbufferTargetNum] = getTargetBufferEleSzie();
+    auto [bufferEleSizes, LayerNumbufferTargetNum] = getTargetBufferEleSize();
 
     size_t remainNoCoverTargetNum = 0;
     size_t bufferCoverTargetNum = 0;