Commit 061fd2c

[KV cache manager] Support SWA KV cache reuse
This merge request adds support for SWA KV cache reuse. Before this commit, when dealing with sliding window attention (SWA), the KV cache manager allocated "window size" blocks and reused them in a cyclic manner. That behavior cannot support reuse, because block contents get overwritten. With this commit, the manager writes blocks in a linear manner instead. With linear block writing, as the attention window moves on, the out-of-window (OOW) blocks are detached. For now, for the sake of a correct feature first, we directly offload the OOW block from the primary block pool (GPU memory) to the secondary block pool (host memory). We will improve this in the future by delegating the block movement to the eviction policy.

Signed-off-by: eopXD <[email protected]>
1 parent 5a65080 commit 061fd2c
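
To illustrate the behavior change described in the commit message, here is a minimal standalone C++ sketch (an editorial illustration with made-up numbers, not the kvCacheManager API): under the old scheme a token's block slot is its linear block index modulo the window's block count, so older contents get overwritten; under the new scheme blocks are written linearly, and blocks that fall fully out of the attention window become candidates for detaching and offloading.

// Standalone sketch, not code from this commit: contrasts cyclic vs. linear block
// indexing for a sliding window of 8 tokens with 4 tokens per block.
#include <algorithm>
#include <cstdint>
#include <iostream>

int main()
{
    constexpr int32_t tokensPerBlock = 4;
    constexpr int32_t windowSize = 8; // SWA attention window, in tokens
    constexpr int32_t windowBlocks = windowSize / tokensPerBlock;

    for (int32_t token = 0; token < 20; token += tokensPerBlock)
    {
        int32_t const linearBlock = token / tokensPerBlock;    // new behavior: block is never overwritten
        int32_t const cyclicSlot = linearBlock % windowBlocks; // old behavior: slot reused, contents overwritten
        // With linear writing, blocks whose tokens have all left the window are out of window (OOW).
        int32_t const windowStart = std::max<int32_t>(0, token + tokensPerBlock - windowSize);
        int32_t const firstInWindowBlock = windowStart / tokensPerBlock;
        std::cout << "tokens [" << token << ", " << token + tokensPerBlock << "): linear block " << linearBlock
                  << ", cyclic slot " << cyclicSlot << ", OOW blocks: indices below " << firstInWindowBlock << '\n';
    }
    return 0;
}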

File tree

5 files changed, +968 -331 lines changed


cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 65 additions & 43 deletions
@@ -53,6 +53,10 @@ static constexpr SizeType32 kPrimaryLevel = 0;
 
 static constexpr SizeType32 kSecondaryLevel = 1;
 
+// Extra block buffer allocated for SWA to be able to always keep "window size"
+// tokens held in the blocks.
+static constexpr SizeType32 kSWAExtraBlock = 1;
+
 class KVCacheBlock;
 class BlockManager;
 class KVCacheManager;
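
A plausible reading of the kSWAExtraBlock constant added above (an inference, not text from the commit): keeping a full "window size" of tokens resident can require one block more than windowSize / tokensPerBlock, because an unaligned window straddles an extra block boundary. A small standalone sketch with assumed numbers:

#include <cstdint>
#include <iostream>

// Number of blocks touched by a window of windowSize tokens starting at firstToken.
int32_t blocksSpanned(int32_t firstToken, int32_t windowSize, int32_t tokensPerBlock)
{
    int32_t const lastToken = firstToken + windowSize - 1;
    return lastToken / tokensPerBlock - firstToken / tokensPerBlock + 1;
}

int main()
{
    constexpr int32_t tokensPerBlock = 4;
    constexpr int32_t windowSize = 8;
    std::cout << blocksSpanned(0, windowSize, tokensPerBlock) << '\n'; // aligned window: 2 blocks
    std::cout << blocksSpanned(3, windowSize, tokensPerBlock) << '\n'; // unaligned window: 3 blocks, hence the extra block
    return 0;
}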
@@ -88,8 +92,8 @@ struct WindowSizeMetadata
     SizeType32 allottedSecondaryBlocks;  // Number of secondary blocks allotted to the windowSize
     SizeType32 absolutePoolsOffset;      // cumulative number of pools up to manager
     SizeType32 numPools;                 // number of managed pools
-    SizeType32 maxTokenNum;              // Maximum token length (including bubble)
-    SizeType32 maxBlocksPerSeq;
+    SizeType32 maxTokensPerSeq;          // Maximum token length per sequence (TODO: account for streamLLM)
+    SizeType32 maxBlocksPerSeq;          // Maximum number of blocks per sequence
     SizeType32 maxNumBlocks;             // Number of primary+secondary blocks allotted to the windowSize
     SizeType32 temporaryAttentionWindow; // Temporary kv cache length per sequence.
                                          // Only needed when chunked context + sliding window attention are used
@@ -99,9 +103,9 @@ struct WindowSizeMetadata
     {
         return tensorrt_llm::common::fmtstr(
             "WindowSizeMetadata{ .allottedPrimaryBlocks=%d, .allottedSecondaryBlocks=%d, .absolutePoolsOffset=%d, "
-            ".numPools=%d, .maxTokenNum=%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d }",
-            allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq,
-            maxNumBlocks, temporaryAttentionWindow);
+            ".numPools=%d, .maxTokensPerSeq=%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d }",
+            allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokensPerSeq,
+            maxBlocksPerSeq, maxNumBlocks, temporaryAttentionWindow);
     }
 };
 
@@ -203,6 +207,7 @@ class KVCacheBlock
     using IdType = std::int32_t;
 
     static constexpr IdType kCachedBlocksRootId = -1;
+    static constexpr IdType kInvalidBlockId = -2;
 
     explicit KVCacheBlock(IdType blockId, kernels::KVCacheIndex blockIdx);
 
@@ -335,14 +340,7 @@ class GenerationRequest
         , mNumTokens(numTokens)
         , mBeamWidth(beamWidth)
         , mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
-        // min window size + sink bubble length
-        // Why use the minimum window size:
-        // Chunked Prefill + Reuse calls `setPrepopulatedPromptLen()` which sets
-        // `mContextCurrentPosition` - this cannot be done for some windows sizes and
-        // not for others, the state needs to remain identical for all window sizes. So
-        // we currently resort to strictly disabling the reuse code path for all window
-        // sizes at once or enable it for all window sizes at once.
-        , mCyclicThreshold(windowSizeToMetadata.cbegin()->second.maxTokenNum)
+        , mNumFrontBlocksRemoved(0)
     {
         auto const numWindowSizes = windowSizeToMetadata.size();
         mCacheBlockIds.reserve(numWindowSizes);
@@ -385,6 +383,11 @@ class GenerationRequest
         return mNumTokens;
     }
 
+    [[nodiscard]] SizeType32 getNumFrontBlocksRemoved() const
+    {
+        return mNumFrontBlocksRemoved;
+    }
+
     [[nodiscard]] SizeType32 getBeamWidth() const
     {
         return mBeamWidth;
@@ -422,6 +425,26 @@ class GenerationRequest
         {
             beamBlockIds.clear();
         }
+        mNumFrontBlocksRemoved = 0;
+    }
+
+    void removeFrontBlock(SizeType32 windowSize)
+    {
+        for (auto& beamBlockIds : mCacheBlockIds.at(windowSize))
+        {
+            if (mNumFrontBlocksRemoved < static_cast<SizeType32>(beamBlockIds.size()))
+            {
+                // Doesn't actually remove from mCacheBlockIds like removeLastBlock,
+                // block id is set to -1 instead because we preserve the blocks
+                // for reuse when reuse is enabled.
+                beamBlockIds[mNumFrontBlocksRemoved] = KVCacheBlock::kInvalidBlockId;
+            }
+            else
+            {
+                TLLM_LOG_WARNING("RequestID %d: removeFrontBlock called but nothing to remove", mRequestId);
+            }
+        }
+        ++mNumFrontBlocksRemoved;
     }
 
     void removeLastBlock(SizeType32 windowSize)
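
As the comment in removeFrontBlock above notes, the front block is marked invalid rather than erased. A standalone sketch of why that matters (hypothetical values, not GenerationRequest itself): keeping the slot in place leaves the indices of the remaining blocks stable, so later blocks are still found at their original positions.

#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    constexpr int32_t kInvalidBlockId = -2;
    std::vector<int32_t> beamBlockIds = {10, 11, 12, 13}; // block ids for one beam
    int32_t numFrontBlocksRemoved = 0;

    // Detach the two oldest (out-of-window) blocks by invalidating them in place.
    for (int i = 0; i < 2; ++i)
    {
        beamBlockIds[numFrontBlocksRemoved] = kInvalidBlockId;
        ++numFrontBlocksRemoved;
    }

    // Block 12 is still at index 2; erasing the front entries would have shifted it to index 0.
    std::cout << "index 2 still holds block " << beamBlockIds[2] << '\n';
    return 0;
}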
@@ -442,14 +465,6 @@ class GenerationRequest
         return mKvCacheRetentionConfig.getDecodeDurationMs();
     }
 
-    // @brief Check whether the sequence uses cyclic KV cache.
-    // @return `true` if we have begun overwriting the beginning of the sequence's KV cache.
-    // @details If `true`, we cannot store the sequence's KV cache for reuse.
-    [[nodiscard]] bool isCyclic() const
-    {
-        return mNumTokens >= mCyclicThreshold;
-    }
-
 private:
     // Request id of the sequence
     LlmRequest::RequestIdType mRequestId;
@@ -463,9 +478,8 @@ class GenerationRequest
     std::unordered_map<SizeType32, runtime::ITensor::SharedPtr> mCacheBlockIndices;
     // The retention priority to assign to decode blocks
     executor::KvCacheRetentionConfig mKvCacheRetentionConfig;
-
-    // Number of tokens at which the KV Cache begins sliding [for the minimum attention window]
-    SizeType32 mCyclicThreshold;
+    // Number of front blocks removed from the sequence
+    SizeType32 mNumFrontBlocksRemoved;
 };
 
 // attach metadata to a pool pointer
@@ -533,7 +547,7 @@ class WindowBlockManager
 
     explicit WindowBlockManager(nvinfer1::DataType dtype, SizeType32 windowSize,
         std::vector<SizeType32> const& managedLayers, std::vector<SizeType32> const& numKvHeadsPerLayer,
-        SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool,
+        SizeType32 sizePerHead, SizeType32 tokensPerBlock, bool isSWA, SizeType32 blocksInPrimaryPool,
         SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
         bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
         std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse);
@@ -567,14 +581,26 @@ class WindowBlockManager
     void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
 
     //! \brief Release blocks of the sequence.
-    void releaseBlocks(GenerationRequest& sequence);
+    //! \details When llmRequest is provided and reuse is enabled, blocks will be stored.
+    void releaseBlocks(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest = std::nullopt);
 
     //! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
     void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);
 
+    //! \brief Update cache offsets for last block
+    void updateLastCacheBlockOffsets(GenerationRequest& seq);
+
     //! \brief Release last block in the sequence
     void releaseLastBlock(GenerationRequest& sequence);
 
+    //! \brief Detach block from the sequence
+    void detachBlock(GenerationRequest& sequence, bool isEnableBlockReuse);
+
+    //! \brief Check and add a block to the sequence if needed.
+    //! \details Out-of-window blocks will be detached. If reuse is enabled,
+    //! the detached block will be stored via offload.
+    void addBlockIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse);
+
     [[nodiscard]] SizeType32 getWindowSize() const noexcept
     {
         return mWindowSize;
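
The detachBlock/addBlockIfNeeded declarations above suggest a per-step flow of "grow the sequence, detach the out-of-window block, offload it when reuse is enabled". The following is an assumed sketch of that flow with hypothetical types, not the WindowBlockManager implementation:

#include <deque>
#include <iostream>

struct Block
{
    int id;
    bool onHost;
};

int main()
{
    bool const enableBlockReuse = true;
    std::deque<Block> inWindow = {{0, false}, {1, false}}; // blocks currently covered by the attention window
    int nextId = 2;

    for (int step = 0; step < 3; ++step)
    {
        inWindow.push_back({nextId++, false}); // addBlockIfNeeded: the sequence grew past its last block
        Block oow = inWindow.front();          // the oldest block just slid out of the window
        inWindow.pop_front();                  // detachBlock: no longer addressed by attention
        if (enableBlockReuse)
        {
            oow.onHost = true;                 // offload primary (GPU) to secondary (host) so it can still be reused
        }
        std::cout << "step " << step << ": detached block " << oow.id
                  << (oow.onHost ? " (offloaded to host)" : " (freed)") << '\n';
    }
    return 0;
}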
@@ -585,7 +611,7 @@ class WindowBlockManager
         return mLogPrefix;
     }
 
-    [[nodiscard]] SizeType32 getNumFreeBlocks() const noexcept;
+    [[nodiscard]] SizeType32 getNumFreeBlocks(SizeType32 cacheLevel = kPrimaryLevel) const noexcept;
 
     [[nodiscard]] SizeType32 getNumAllocTotalBlocks() const
     {
@@ -715,7 +741,8 @@ class WindowBlockManager
     //! \brief Store blocks in cached blocks.
     //! \param blockKeys Key of each block.
     //! \param blockIds Id of each block.
-    void storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);
+    //! \return Number of actual blocks stored.
+    SizeType32 storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);
 
     [[nodiscard]] bool verifyQueueIntegrity();
 
@@ -796,6 +823,8 @@ class WindowBlockManager
     SizeType32 mSchedulingNumFreeBlocks;
     // Number of tokens per one block
     SizeType32 mTokensPerBlock;
+    // Whether this window is SWA
+    bool mIsSWA;
     // List of all blocks by idx
     std::vector<BlockPtr> mAllBlocksById;
     // Dummy block acting as root for BlockToken searches
@@ -917,19 +946,20 @@ class BlockManager
 
     void startScheduling();
 
-    [[nodiscard]] std::map<SizeType32, SizeType32> getNumFreeBlocksPerWindowSize() const
+    [[nodiscard]] std::map<SizeType32, SizeType32> getNumFreeBlocksPerWindowSize(
+        SizeType32 cacheLevel = kPrimaryLevel) const
     {
         std::map<SizeType32, SizeType32> numFreeBlocksPerWindowSize;
         for (auto const& [windowSize, manager] : mWindowBlockManagers)
         {
-            numFreeBlocksPerWindowSize[windowSize] = manager.getNumFreeBlocks();
+            numFreeBlocksPerWindowSize[windowSize] = manager.getNumFreeBlocks(cacheLevel);
         }
         return numFreeBlocksPerWindowSize;
     }
 
-    [[nodiscard]] SizeType32 getNumFreeBlocks() const
+    [[nodiscard]] SizeType32 getNumFreeBlocks(SizeType32 cacheLevel = kPrimaryLevel) const
     {
-        return sumWindows([](auto const& manager) { return manager.getNumFreeBlocks(); });
+        return sumWindows([cacheLevel](auto const& manager) { return manager.getNumFreeBlocks(cacheLevel); });
     }
 
     [[nodiscard]] bool schedulingHasFreeBlocks(SizeType32 numRequired, SizeType32 windowSize) const
@@ -1088,14 +1118,6 @@ class BlockManager
     //! \brief Store newest block for reuse
     void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
 
-    [[nodiscard]] static bool isUseOneMoreBlock(
-        SizeType32 windowSize, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth)
-    {
-        bool const isCyclicWindowSize = maxSequenceLength.has_value() && maxSequenceLength.value() > windowSize;
-        bool const isBeamSearch = maxBeamWidth > 1;
-        return isCyclicWindowSize && isBeamSearch;
-    }
-
     //! \brief Perform per-request bookkeeping
     void refreshBlocks();
 
@@ -1114,12 +1136,12 @@ class BlockManager
     //! \brief Update cache offsets for blocks initiated from sequence
     void updateSequenceCacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);
 
-    //! \brief Update cache offsets for last block
-    void updateLastCacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);
-
     //! \brief Update cache offsets for block at index
     void updateCacheBlockOffsetsAtIdx(GenerationRequest& seq, SizeType32 windowSize, SizeType32 blockIdx);
 
+    //! \brief Add block to the sequence if needed
+    void addBlockIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse);
+
 private:
     [[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const
     {
