Skip to content

Commit 0dc173e

Browse files
committed
[kv cache manager] Implement the detach mechanism for SWA reuse
This MR is a continuation of NVIDIA#6768. In the previous merge request, OOW (out-of-window) blocks were only detached when reuse was not enabled. This MR enables the KV cache manager to detach OOW blocks when reuse is enabled as well. When reuse is enabled, an OOW block gets detached while the manager keeps track of whether the block has been overwritten by another sequence. If, at the end of the sequence, all of its blocks are still clean, then the sequence's blocks will be stored for reuse. Signed-off-by: eopXD <[email protected]>
1 parent 948b8b9 commit 0dc173e

File tree

3 files changed

+627
-59
lines changed

3 files changed

+627
-59
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,17 @@ struct WindowSizeMetadata
130130
SizeType32 temporaryAttentionWindow; // Temporary kv cache length per sequence.
131131
// Only needed when chunked context + sliding window attention are used
132132
// together. And it should only be considered when allocating blocks.
133+
SizeType32 windowSize;
134+
bool isSWA;
133135

134136
//! \brief Render the metadata fields as a single debug string via fmtstr.
//! \details Every field is printed with %d; isSWA is a bool and is therefore
//! shown as 0 or 1.
std::string toString()
135137
{
136138
return tensorrt_llm::common::fmtstr(
137139
"WindowSizeMetadata{ .allottedPrimaryBlocks=%d, .allottedSecondaryBlocks=%d, .absolutePoolsOffset=%d, "
138-
".numPools=%d, .maxTokenNum=%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d }",
140+
".numPools=%d, .maxTokenNum=%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d, "
141+
".windowSize=%d, .isSWA=%d }",
139142
allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq,
140-
maxNumBlocks, temporaryAttentionWindow);
143+
maxNumBlocks, temporaryAttentionWindow, windowSize, isSWA);
141144
}
142145
};
143146

@@ -512,6 +515,8 @@ class GenerationRequest
512515
executor::KvCacheRetentionConfig mKvCacheRetentionConfig;
513516
// Number of front blocks removed from the sequence
514517
SizeType32 mNumFrontBlocksRemoved;
518+
// Set of blocks used by the sequence
519+
std::set<KVCacheBlock::IdType> mUsedBlocks;
515520
};
516521

517522
// attach metadata to a pool pointer
@@ -763,7 +768,7 @@ class WindowBlockManager
763768

764769
//! \brief Bring offloaded block from secondary to primary memory.
765770
//! \details Does nothing if block is already in primary memory.
766-
void onboardBlock(BlockPtr const& offloadBlock,
771+
void onboardBlock(GenerationRequest& sequence, BlockPtr const& offloadBlock,
767772
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");
768773

769774
//! \brief Bring block from primary to secondary memory.
@@ -826,6 +831,23 @@ class WindowBlockManager
826831
//! \brief Unpin blocks by starting from a block id and walking prev pointers.
827832
void unpinBlocksById(KVCacheBlock::IdType blockId);
828833

834+
//! \brief Start tracking a sequence and mark it as valid for store-for-reuse.
//! \details Sets the entry for requestId in mIsValidStoreForReuseSequence to
//! true, creating the entry if it does not exist and overwriting any previous
//! value otherwise.
//! \param requestId Id of the request whose sequence is being tracked.
void initializeSequenceStorageValidity(LlmRequest::RequestIdType requestId)
835+
{
836+
mIsValidStoreForReuseSequence[requestId] = true;
837+
}
838+
839+
//! \brief Stop tracking the store-for-reuse validity of a sequence.
//! \details Erases the entry for requestId from mIsValidStoreForReuseSequence.
//! Erasing an id that has no entry is a no-op.
//! \param requestId Id of the request whose sequence is no longer tracked.
void releaseSequenceStorageValidity(LlmRequest::RequestIdType requestId)
840+
{
841+
mIsValidStoreForReuseSequence.erase(requestId);
842+
}
843+
844+
//! \brief Return whether this sequence is valid for store for reuse
//! \details The sequence must already have an entry in
//! mIsValidStoreForReuseSequence (see initializeSequenceStorageValidity);
//! querying an id without an entry fails the TLLM_CHECK_WITH_INFO below.
//! \param requestId Id of the request whose sequence is queried.
//! \return The validity flag currently recorded for requestId.
845+
[[nodiscard]] bool isSequenceValidForStoreForReuse(LlmRequest::RequestIdType requestId) const
846+
{
847+
TLLM_CHECK_WITH_INFO(mIsValidStoreForReuseSequence.count(requestId) > 0, "Sequence should be bookkeeped");
848+
return mIsValidStoreForReuseSequence.at(requestId);
849+
}
850+
829851
private:
830852
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
831853
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
@@ -846,7 +868,8 @@ class WindowBlockManager
846868
std::optional<std::chrono::milliseconds> durationMs);
847869

848870
//! \brief Find block least likely to be reused, free it if necessary and return.
849-
[[nodiscard]] BlockPtr getFreeBlock(
871+
//! \param sequence Sequence which the free block is allocated for
872+
[[nodiscard]] BlockPtr getFreeBlock(GenerationRequest& sequence,
850873
executor::RetentionPriority = executor::KvCacheRetentionConfig::kDefaultRetentionPriority,
851874
std::optional<std::chrono::milliseconds> durationMs = std::nullopt,
852875
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");
@@ -933,6 +956,14 @@ class WindowBlockManager
933956

934957
// Mutex for the cached blocks root
935958
std::mutex mCachedBlocksRootMutex;
959+
960+
// Record which sequence is using the block
961+
std::map<KVCacheBlock::IdType, LlmRequest::RequestIdType> mBlockToSequence;
962+
// Record whether a sequence has all blocks held valid.
963+
// The boolean value is set to true upon first encounter of a new sequence.
964+
// It may be set to false when another sequence acquires a block that
965+
// is used by this sequence.
966+
std::map<LlmRequest::RequestIdType, bool> mIsValidStoreForReuseSequence;
936967
};
937968

938969
class BlockManager
@@ -1008,7 +1039,7 @@ class BlockManager
10081039

10091040
//! \brief Bring block from secondary to primary memory for window size.
10101041
//! \details Does nothing if block is already in primary memory.
1011-
void onboardBlock(BlockPtr const& offloadBlock, SizeType32 windowSize,
1042+
void onboardBlock(GenerationRequest& sequence, BlockPtr const& offloadBlock, SizeType32 windowSize,
10121043
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");
10131044

10141045
//! \brief Bring block from primary to secondary memory for window size.
@@ -1244,6 +1275,48 @@ class BlockManager
12441275
//! there may be more than one context block that goes OOW.
12451276
void adjustBlocksIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse);
12461277

1278+
//! \brief Return whether the sequence is already managed by the block manager
//! \param requestId Id of the request whose sequence is queried.
//! \return true if requestId is present in mManagedSequences, false otherwise.
1279+
[[nodiscard]] bool isSequenceHeld(LlmRequest::RequestIdType requestId) const
1280+
{
1281+
return mManagedSequences.count(requestId) > 0;
1282+
}
1283+
1284+
//! \brief Add a sequence to the managed sequences
1285+
//! \details Take the sequence into account for the manager. Initialize
1286+
//! sequence storage validity under all window sizes.
//! \param requestId Id of the request whose sequence is to be managed.
1287+
void holdSequence(LlmRequest::RequestIdType requestId)
1288+
{
1289+
mManagedSequences.insert(requestId);
1290+
// Only the window-size key is used; each key in mWindowSizeToMetadata is
// looked up in mWindowBlockManagers with at(), so a matching manager is
// expected to exist for every window size.
for (auto const& [windowSize, metadata] : mWindowSizeToMetadata)
1291+
{
1292+
mWindowBlockManagers.at(windowSize).initializeSequenceStorageValidity(requestId);
1293+
}
1294+
}
1295+
1296+
//! \brief Remove a sequence from the managed sequences.
1297+
//! \details Remove sequence from the managed sequences and remove the sequence
1298+
//! storage validity record from the window block manager of every window size.
//! \param requestId Id of the request whose sequence is no longer managed.
1299+
void releaseSequence(LlmRequest::RequestIdType requestId)
1300+
{
1301+
mManagedSequences.erase(requestId);
1302+
// Only the window-size key is used; the metadata value is not read here.
for (auto const& [windowSize, metadata] : mWindowSizeToMetadata)
1303+
{
1304+
mWindowBlockManagers.at(windowSize).releaseSequenceStorageValidity(requestId);
1305+
}
1306+
}
1307+
1308+
//! \brief Return whether the sequence is still valid for store-for-reuse
1309+
//! regarding the specific window size.
1310+
//! \details Currently this utility function is only used under
1311+
//! kvCacheManagerTest.cpp. Checking for store-for-reuse for each window
1312+
//! size is done in an iterating fashion under BlockManager::releaseBlocks.
//! \param requestId Id of the request whose sequence is queried.
//! \param windowSize Window size selecting which window block manager to ask;
//! it must be a key of mWindowBlockManagers, otherwise the check below fails.
//! \return The result of the selected window block manager's
//! isSequenceValidForStoreForReuse(requestId).
1313+
bool isSequenceValidForStoreForReuse(LlmRequest::RequestIdType requestId, SizeType32 windowSize) const
1314+
{
1315+
TLLM_CHECK_WITH_INFO(
1316+
mWindowBlockManagers.count(windowSize) > 0, "Querying window size is not found under mWindowBlockManager");
1317+
return mWindowBlockManagers.at(windowSize).isSequenceValidForStoreForReuse(requestId);
1318+
}
1319+
12471320
private:
12481321
[[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const
12491322
{
@@ -1278,6 +1351,8 @@ class BlockManager
12781351
std::vector<SizeType32> mLayerToWindowSize;
12791352
std::vector<SizeType32> mAbsolutePoolToWindowSize;
12801353
std::vector<SizeType32> mAbsolutePoolToRelativePoolIndex;
1354+
// Record what sequences are currently managed by the block manager
1355+
std::set<LlmRequest::RequestIdType> mManagedSequences;
12811356
};
12821357

12831358
struct OffsetTableDimensions

0 commit comments

Comments
 (0)