@@ -53,6 +53,10 @@ static constexpr SizeType32 kPrimaryLevel = 0;
53
53
54
54
static constexpr SizeType32 kSecondaryLevel = 1 ;
55
55
56
+ // Extra block buffer allocated for SWA to be able to always keep "window size"
57
+ // tokens held in the blocks.
58
+ static constexpr SizeType32 kSWAExtraBlock = 1 ;
59
+
56
60
class KVCacheBlock ;
57
61
class BlockManager ;
58
62
class KVCacheManager ;
@@ -88,8 +92,8 @@ struct WindowSizeMetadata
88
92
SizeType32 allottedSecondaryBlocks; // Number of secondary blocks allotted to the windowSize
89
93
SizeType32 absolutePoolsOffset; // cumulative number of pools up to manager
90
94
SizeType32 numPools; // number of managed pools
91
- SizeType32 maxTokenNum ; // Maximum token length (including bubble )
92
- SizeType32 maxBlocksPerSeq;
95
+ SizeType32 maxTokensPerSeq ; // Maximum token length per sequence (TODO: account for streamLLM )
96
+ SizeType32 maxBlocksPerSeq; // Maximum number of blocks per sequence
93
97
SizeType32 maxNumBlocks; // Number of primary+secondary blocks allotted to the windowSize
94
98
SizeType32 temporaryAttentionWindow; // Temporary kv cache length per sequence.
95
99
// Only needed when chunked context + sliding window attention are used
@@ -99,9 +103,9 @@ struct WindowSizeMetadata
99
103
{
100
104
return tensorrt_llm::common::fmtstr (
101
105
" WindowSizeMetadata{ .allottedPrimaryBlocks=%d, .allottedSecondaryBlocks=%d, .absolutePoolsOffset=%d, "
102
- " .numPools=%d, .maxTokenNum =%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d }" ,
103
- allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq ,
104
- maxNumBlocks, temporaryAttentionWindow);
106
+ " .numPools=%d, .maxTokensPerSeq =%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d }" ,
107
+ allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokensPerSeq ,
108
+ maxBlocksPerSeq, maxNumBlocks, temporaryAttentionWindow);
105
109
}
106
110
};
107
111
@@ -203,6 +207,7 @@ class KVCacheBlock
203
207
using IdType = std::int32_t ;
204
208
205
209
static constexpr IdType kCachedBlocksRootId = -1 ;
210
+ static constexpr IdType kInvalidBlockId = -2 ;
206
211
207
212
explicit KVCacheBlock (IdType blockId, kernels::KVCacheIndex blockIdx);
208
213
@@ -335,14 +340,7 @@ class GenerationRequest
335
340
, mNumTokens(numTokens)
336
341
, mBeamWidth(beamWidth)
337
342
, mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
338
- // min window size + sink bubble length
339
- // Why use the minimum window size:
340
- // Chunked Prefill + Reuse calls `setPrepopulatedPromptLen()` which sets
341
- // `mContextCurrentPosition` - this cannot be done for some windows sizes and
342
- // not for others, the state needs to remain identical for all window sizes. So
343
- // we currently resort to strictly disabling the reuse code path for all window
344
- // sizes at once or enable it for all window sizes at once.
345
- , mCyclicThreshold(windowSizeToMetadata.cbegin()->second.maxTokenNum)
343
+ , mNumFrontBlocksRemoved(0 )
346
344
{
347
345
auto const numWindowSizes = windowSizeToMetadata.size ();
348
346
mCacheBlockIds .reserve (numWindowSizes);
@@ -385,6 +383,11 @@ class GenerationRequest
385
383
return mNumTokens ;
386
384
}
387
385
386
+ [[nodiscard]] SizeType32 getNumFrontBlocksRemoved () const
387
+ {
388
+ return mNumFrontBlocksRemoved ;
389
+ }
390
+
388
391
[[nodiscard]] SizeType32 getBeamWidth () const
389
392
{
390
393
return mBeamWidth ;
@@ -422,6 +425,26 @@ class GenerationRequest
422
425
{
423
426
beamBlockIds.clear ();
424
427
}
428
+ mNumFrontBlocksRemoved = 0 ;
429
+ }
430
+
431
+ void removeFrontBlock (SizeType32 windowSize)
432
+ {
433
+ for (auto & beamBlockIds : mCacheBlockIds .at (windowSize))
434
+ {
435
+ if (mNumFrontBlocksRemoved < static_cast <SizeType32>(beamBlockIds.size ()))
436
+ {
437
+ // Doesn't actually remove from mCacheBlockIds like removeLastBlock,
438
+ // block id is set to -1 instead because we preserve the blocks
439
+ // for reuse when reuse is enabled.
440
+ beamBlockIds[mNumFrontBlocksRemoved ] = KVCacheBlock::kInvalidBlockId ;
441
+ }
442
+ else
443
+ {
444
+ TLLM_LOG_WARNING (" RequestID %d: removeFrontBlock called but nothing to remove" , mRequestId );
445
+ }
446
+ }
447
+ ++mNumFrontBlocksRemoved ;
425
448
}
426
449
427
450
void removeLastBlock (SizeType32 windowSize)
@@ -442,14 +465,6 @@ class GenerationRequest
442
465
return mKvCacheRetentionConfig .getDecodeDurationMs ();
443
466
}
444
467
445
- // @brief Check whether the sequence uses cyclic KV cache.
446
- // @return `true` if we have begun overwriting the beginning of the sequence's KV cache.
447
- // @details If `true`, we cannot store the sequence's KV cache for reuse.
448
- [[nodiscard]] bool isCyclic () const
449
- {
450
- return mNumTokens >= mCyclicThreshold ;
451
- }
452
-
453
468
private:
454
469
// Request id of the sequence
455
470
LlmRequest::RequestIdType mRequestId ;
@@ -463,9 +478,8 @@ class GenerationRequest
463
478
std::unordered_map<SizeType32, runtime::ITensor::SharedPtr> mCacheBlockIndices ;
464
479
// The retention priority to assign to decode blocks
465
480
executor::KvCacheRetentionConfig mKvCacheRetentionConfig ;
466
-
467
- // Number of tokens at which the KV Cache begins sliding [for the minimum attention window]
468
- SizeType32 mCyclicThreshold ;
481
+ // Number of front blocks removed from the sequence
482
+ SizeType32 mNumFrontBlocksRemoved ;
469
483
};
470
484
471
485
// attach metadata to a pool pointer
@@ -533,7 +547,7 @@ class WindowBlockManager
533
547
534
548
explicit WindowBlockManager (nvinfer1::DataType dtype, SizeType32 windowSize,
535
549
std::vector<SizeType32> const & managedLayers, std::vector<SizeType32> const & numKvHeadsPerLayer,
536
- SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool,
550
+ SizeType32 sizePerHead, SizeType32 tokensPerBlock, bool isSWA, SizeType32 blocksInPrimaryPool,
537
551
SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
538
552
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
539
553
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse);
@@ -567,14 +581,26 @@ class WindowBlockManager
567
581
void storeNewBlock (GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest);
568
582
569
583
// ! \brief Release blocks of the sequence.
570
- void releaseBlocks (GenerationRequest& sequence);
584
+ // ! \details When llmRequest is provided and reuse is enabled, blocks will be stored.
585
+ void releaseBlocks (GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest = std::nullopt );
571
586
572
587
// ! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
573
588
void schedulingReleaseBlocks (LlmRequest::RequestIdType requestId);
574
589
590
+ // ! \brief Update cache offsets for last block
591
+ void updateLastCacheBlockOffsets (GenerationRequest& seq);
592
+
575
593
// ! \brief Release last block in the sequence
576
594
void releaseLastBlock (GenerationRequest& sequence);
577
595
596
+ // ! \brief Detach block from the sequence
597
+ void detachBlock (GenerationRequest& sequence, bool isEnableBlockReuse);
598
+
599
+ // ! \brief Check and add a block to the sequence if needed.
600
+ // ! \details Out-of-window blocks will be detached. If reuse is enabled,
601
+ // ! the detached block will be stored via offload.
602
+ void addBlockIfNeeded (GenerationRequest& sequence, bool isEnableBlockReuse);
603
+
578
604
[[nodiscard]] SizeType32 getWindowSize () const noexcept
579
605
{
580
606
return mWindowSize ;
@@ -585,7 +611,7 @@ class WindowBlockManager
585
611
return mLogPrefix ;
586
612
}
587
613
588
- [[nodiscard]] SizeType32 getNumFreeBlocks () const noexcept ;
614
+ [[nodiscard]] SizeType32 getNumFreeBlocks (SizeType32 cacheLevel = kPrimaryLevel ) const noexcept ;
589
615
590
616
[[nodiscard]] SizeType32 getNumAllocTotalBlocks () const
591
617
{
@@ -715,7 +741,8 @@ class WindowBlockManager
715
741
// ! \brief Store blocks in cached blocks.
716
742
// ! \param blockKeys Key of each block.
717
743
// ! \param blockIds Id of each block.
718
- void storeBlocks (std::vector<BlockKey> const & blockKeys, std::vector<KVCacheBlock::IdType> const & blockIds);
744
+ // ! \return Number of actual blocks stored.
745
+ SizeType32 storeBlocks (std::vector<BlockKey> const & blockKeys, std::vector<KVCacheBlock::IdType> const & blockIds);
719
746
720
747
[[nodiscard]] bool verifyQueueIntegrity ();
721
748
@@ -796,6 +823,8 @@ class WindowBlockManager
796
823
SizeType32 mSchedulingNumFreeBlocks ;
797
824
// Number of tokens per one block
798
825
SizeType32 mTokensPerBlock ;
826
+ // Whether this window is SWA
827
+ bool mIsSWA ;
799
828
// List of all blocks by idx
800
829
std::vector<BlockPtr> mAllBlocksById ;
801
830
// Dummy block acting as root for BlockToken searches
@@ -917,19 +946,20 @@ class BlockManager
917
946
918
947
void startScheduling ();
919
948
920
- [[nodiscard]] std::map<SizeType32, SizeType32> getNumFreeBlocksPerWindowSize () const
949
+ [[nodiscard]] std::map<SizeType32, SizeType32> getNumFreeBlocksPerWindowSize (
950
+ SizeType32 cacheLevel = kPrimaryLevel ) const
921
951
{
922
952
std::map<SizeType32, SizeType32> numFreeBlocksPerWindowSize;
923
953
for (auto const & [windowSize, manager] : mWindowBlockManagers )
924
954
{
925
- numFreeBlocksPerWindowSize[windowSize] = manager.getNumFreeBlocks ();
955
+ numFreeBlocksPerWindowSize[windowSize] = manager.getNumFreeBlocks (cacheLevel );
926
956
}
927
957
return numFreeBlocksPerWindowSize;
928
958
}
929
959
930
- [[nodiscard]] SizeType32 getNumFreeBlocks () const
960
+ [[nodiscard]] SizeType32 getNumFreeBlocks (SizeType32 cacheLevel = kPrimaryLevel ) const
931
961
{
932
- return sumWindows ([](auto const & manager) { return manager.getNumFreeBlocks (); });
962
+ return sumWindows ([cacheLevel ](auto const & manager) { return manager.getNumFreeBlocks (cacheLevel ); });
933
963
}
934
964
935
965
[[nodiscard]] bool schedulingHasFreeBlocks (SizeType32 numRequired, SizeType32 windowSize) const
@@ -1088,14 +1118,6 @@ class BlockManager
1088
1118
// ! \brief Store newest block for reuse
1089
1119
void storeNewBlock (GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest);
1090
1120
1091
- [[nodiscard]] static bool isUseOneMoreBlock (
1092
- SizeType32 windowSize, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth)
1093
- {
1094
- bool const isCyclicWindowSize = maxSequenceLength.has_value () && maxSequenceLength.value () > windowSize;
1095
- bool const isBeamSearch = maxBeamWidth > 1 ;
1096
- return isCyclicWindowSize && isBeamSearch;
1097
- }
1098
-
1099
1121
// ! \brief Perform per-request bookkeeping
1100
1122
void refreshBlocks ();
1101
1123
@@ -1114,12 +1136,12 @@ class BlockManager
1114
1136
// ! \brief Update cache offsets for blocks initiated from sequence
1115
1137
void updateSequenceCacheBlockOffsets (GenerationRequest& seq, SizeType32 windowSize);
1116
1138
1117
- // ! \brief Update cache offsets for last block
1118
- void updateLastCacheBlockOffsets (GenerationRequest& seq, SizeType32 windowSize);
1119
-
1120
1139
// ! \brief Update cache offsets for block at index
1121
1140
void updateCacheBlockOffsetsAtIdx (GenerationRequest& seq, SizeType32 windowSize, SizeType32 blockIdx);
1122
1141
1142
+ // ! \brief Add block to the sequence if needed
1143
+ void addBlockIfNeeded (GenerationRequest& sequence, bool isEnableBlockReuse);
1144
+
1123
1145
private:
1124
1146
[[nodiscard]] WindowBlockManager const & windowManagerByLayer (SizeType32 layerIdx) const
1125
1147
{
0 commit comments