Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions cpp/include/tensorrt_llm/batch_manager/evictionPolicy.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,8 @@ class LRUEvictionPolicy : public BaseEvictionPolicy
bool verifyQueueIntegrity() override;

private:
// Check if the block should be added to mFreeQueues.
bool isReleasedLeafBlock(BlockPtr const& block);

// Queues of available leaf blocks, split by cache level and priority level
std::vector<std::vector<FreeBlocksQueue>> mFreeQueues;
// All blocks that have been released, along with the amount of released children
std::vector<std::unordered_set<SizeType32>> mReleasedBlocks;
// Iterators to block entries in mFreeQueues
std::vector<std::optional<FreeBlocksQueue::iterator>> mFreeBlockIterators;
// Amount of free blocks at each cache level
Expand Down
111 changes: 92 additions & 19 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,17 @@ struct WindowSizeMetadata
SizeType32 temporaryAttentionWindow; // Temporary kv cache length per sequence.
// Only needed when chunked context + sliding window attention are used
// together. And it should only be considered when allocating blocks.
SizeType32 windowSize;
bool isSWA;

std::string toString()
{
return tensorrt_llm::common::fmtstr(
"WindowSizeMetadata{ .allottedPrimaryBlocks=%d, .allottedSecondaryBlocks=%d, .absolutePoolsOffset=%d, "
".numPools=%d, .maxTokenNum=%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d }",
".numPools=%d, .maxTokenNum=%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d, "
".windowSize=%d, .isSWA=%d }",
allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq,
maxNumBlocks, temporaryAttentionWindow);
maxNumBlocks, temporaryAttentionWindow, windowSize, isSWA);
}
};

Expand Down Expand Up @@ -512,6 +515,8 @@ class GenerationRequest
executor::KvCacheRetentionConfig mKvCacheRetentionConfig;
// Number of front blocks removed from the sequence
SizeType32 mNumFrontBlocksRemoved;
// Set of used blocks by the sequence
std::set<KVCacheBlock::IdType> mUsedBlocks;
};

// attach metadata to a pool pointer
Expand Down Expand Up @@ -628,15 +633,15 @@ class WindowBlockManager
void releaseLastBlock(GenerationRequest& sequence);

//! \brief Detach front block from the sequence
void detachFrontBlock(GenerationRequest& sequence, bool isEnableBlockReuse);
void detachFrontBlock(GenerationRequest& sequence);

//! \brief Add/detach block(s) to/from the sequence if needed
//! \details When we need a new block, we add it. For sliding window
//! attention (SWA), when a block goes out-of-window (OOW), we detach it
//! and store it if reuse is enabled. If this called in the first step of
//! the generation phase, we may detach more than a single block since
//! there may be more than one context block that goes OOW.
void adjustBlocksIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse);
//! If this called in the first step of the generation phase, we may detach
//! more than a single block since there may be more than one context block
//! that goes OOW.
void adjustBlocksIfNeeded(GenerationRequest& sequence);

[[nodiscard]] SizeType32 getWindowSize() const noexcept
{
Expand Down Expand Up @@ -763,7 +768,7 @@ class WindowBlockManager

//! \brief Bring offloaded block from secondary to primary memory.
//! \details Does nothing if block is already in primary memory.
void onboardBlock(BlockPtr const& offloadBlock,
void onboardBlock(GenerationRequest& sequence, BlockPtr const& offloadBlock,
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");

//! \brief Bring block from primary to secondary memory.
Expand Down Expand Up @@ -826,6 +831,23 @@ class WindowBlockManager
//! \brief Unpin blocks by starting from a block id and walking prev pointers.
void unpinBlocksById(KVCacheBlock::IdType blockId);

//! \brief Start tracking store-for-reuse validity for a sequence.
//! \details The sequence is recorded as valid. If the request id is already
//! tracked, its entry is overwritten with true.
void initializeSequenceStorageValidity(LlmRequest::RequestIdType requestId)
{
    mIsValidStoreForReuseSequence.insert_or_assign(requestId, true);
}

//! \brief Stop tracking store-for-reuse validity for a sequence.
//! \details Does nothing if the request id is not currently tracked.
void releaseSequenceStorageValidity(LlmRequest::RequestIdType requestId)
{
    auto const trackedEntry = mIsValidStoreForReuseSequence.find(requestId);
    if (trackedEntry != mIsValidStoreForReuseSequence.end())
    {
        mIsValidStoreForReuseSequence.erase(trackedEntry);
    }
}

//! \brief Return whether this sequence is valid for store for reuse
//! \details The request id is expected to have been registered through
//! initializeSequenceStorageValidity; the precondition is checked before the
//! stored flag is read.
[[nodiscard]] bool isSequenceValidForStoreForReuse(LlmRequest::RequestIdType requestId) const
{
    bool const isTracked
        = mIsValidStoreForReuseSequence.find(requestId) != mIsValidStoreForReuseSequence.end();
    TLLM_CHECK_WITH_INFO(isTracked, "Sequence should be bookkeeped");
    return mIsValidStoreForReuseSequence.at(requestId);
}

private:
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
Expand All @@ -842,18 +864,17 @@ class WindowBlockManager
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");

//! \brief Free block and all its descendants. This makes block a claimed leaf block.
void freeChildren(BlockPtr const& block, executor::RetentionPriority priority,
std::optional<std::chrono::milliseconds> durationMs);
void freeChildren(BlockPtr const& block);

//! \brief Find block least likely to be reused, free it if necessary and return.
[[nodiscard]] BlockPtr getFreeBlock(
//! \param sequence Sequence which the free block is allocated for
[[nodiscard]] BlockPtr getFreeBlock(GenerationRequest& sequence,
executor::RetentionPriority = executor::KvCacheRetentionConfig::kDefaultRetentionPriority,
std::optional<std::chrono::milliseconds> durationMs = std::nullopt,
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");

//! \brief Free block from previous block and claim it from free blocks list.
void claimLeafBlock(BlockPtr const& block, std::optional<executor::RetentionPriority> priority = std::nullopt,
std::optional<std::chrono::milliseconds> durationMs = std::nullopt);
//! \brief Calls KVCacheBlock::freeLeafBlock to remove block from search tree.
void freeLeafBlock(BlockPtr const& block);

//! \brief For FP4 quantization. Creates pool objects for FP4 block scalars.
void createBlockScalePools(SizeType32 blockSize);
Expand Down Expand Up @@ -933,6 +954,14 @@ class WindowBlockManager

// Mutex for the cached blocks root
std::mutex mCachedBlocksRootMutex;

// Record which sequence is using the block
std::map<KVCacheBlock::IdType, LlmRequest::RequestIdType> mBlockToSequence;
// Record whether all blocks held by a sequence are valid.
// The boolean value is set to true upon first encounter of a new sequence.
// It may be invalidated to false when one sequence acquires a block that
// is used by another sequence.
std::map<LlmRequest::RequestIdType, bool> mIsValidStoreForReuseSequence;
};

class BlockManager
Expand Down Expand Up @@ -1008,7 +1037,7 @@ class BlockManager

//! \brief Bring offloaded block from secondary to primary memory for window size.
//! \details Does nothing if block is already in primary memory.
void onboardBlock(BlockPtr const& offloadBlock, SizeType32 windowSize,
void onboardBlock(GenerationRequest& sequence, BlockPtr const& offloadBlock, SizeType32 windowSize,
executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = "");

//! \brief Bring block from primary to secondary memory for window size.
Expand Down Expand Up @@ -1239,10 +1268,52 @@ class BlockManager
//! \brief Add/detach block(s) to/from the sequence if needed
//! \details When we need a new block, we add it. For sliding window
//! attention (SWA), when a block goes out-of-window (OOW), we detach it
//! and store it if reuse is enabled. If this called in the first step of
//! the generation phase, we may detach more than a single block since
//! there may be more than one context block that goes OOW.
void adjustBlocksIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse);
//! If this called in the first step of the generation phase, we may
//! detach more than a single block since there may be more than one
//! context block that goes OOW.
void adjustBlocksIfNeeded(GenerationRequest& sequence);

//! \brief Tell whether the block manager currently manages the given sequence.
//! \return true if the request id is present in mManagedSequences.
[[nodiscard]] bool isSequenceHeld(LlmRequest::RequestIdType requestId) const
{
    return mManagedSequences.find(requestId) != mManagedSequences.end();
}

//! \brief Register a sequence with the block manager.
//! \details The request id is added to the set of managed sequences, then the
//! sequence's storage validity is initialized in the window block manager of
//! every window size listed in mWindowSizeToMetadata.
void holdSequence(LlmRequest::RequestIdType requestId)
{
    mManagedSequences.insert(requestId);
    for (auto const& [window, windowMetadata] : mWindowSizeToMetadata)
    {
        // Only the window size is needed here; the metadata itself is not consulted.
        auto& windowManager = mWindowBlockManagers.at(window);
        windowManager.initializeSequenceStorageValidity(requestId);
    }
}

//! \brief Unregister a sequence from the block manager.
//! \details The request id is removed from the set of managed sequences, then
//! the sequence's storage validity entry is released in the window block
//! manager of every window size listed in mWindowSizeToMetadata.
void releaseSequence(LlmRequest::RequestIdType requestId)
{
    mManagedSequences.erase(requestId);
    for (auto const& [window, windowMetadata] : mWindowSizeToMetadata)
    {
        // Only the window size is needed here; the metadata itself is not consulted.
        auto& windowManager = mWindowBlockManagers.at(window);
        windowManager.releaseSequenceStorageValidity(requestId);
    }
}

//! \brief Return whether the sequence is still valid for store-for-reuse
//! regarding the specific window size.
//! \details Currently this utility function is only used under
//! kvCacheManagerTest.cpp. Checking for store-for-reuse for each window
//! size is done in an iterating fashion under BlockManager::releaseBlocks.
//! \param requestId Request id of the sequence to query.
//! \param windowSize Window size whose window block manager is queried; it
//! must be a key of mWindowBlockManagers.
//! \return The validity reported by the window block manager for this sequence.
[[nodiscard]] bool isSequenceValidForStoreForReuse(LlmRequest::RequestIdType requestId, SizeType32 windowSize) const
{
    // Check the window size up front so an unknown size is reported with a descriptive message.
    TLLM_CHECK_WITH_INFO(
        mWindowBlockManagers.count(windowSize) > 0, "Querying window size is not found under mWindowBlockManager");
    return mWindowBlockManagers.at(windowSize).isSequenceValidForStoreForReuse(requestId);
}

private:
[[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const
Expand Down Expand Up @@ -1278,6 +1349,8 @@ class BlockManager
std::vector<SizeType32> mLayerToWindowSize;
std::vector<SizeType32> mAbsolutePoolToWindowSize;
std::vector<SizeType32> mAbsolutePoolToRelativePoolIndex;
// Record what sequences are currently managed by the block manager
std::set<LlmRequest::RequestIdType> mManagedSequences;
};

struct OffsetTableDimensions
Expand Down
72 changes: 7 additions & 65 deletions cpp/tensorrt_llm/batch_manager/evictionPolicy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,13 @@ void LRUEvictionPolicy::initialize(std::vector<BlockPtr>& mAllBlocksById, std::v
{
mFreeBlockIterators.reserve(mFreeBlockIterators.size() + sizes[cacheLevel]);
mFreeQueues.emplace_back(std::vector<FreeBlocksQueue>(kMaxPriority - kMinPriority + 1));
mReleasedBlocks.emplace_back(std::unordered_set<SizeType32>());

auto& freeQueue = mFreeQueues[cacheLevel][defaultPriorityIdx];

for (SizeType32 blockId = 0; blockId < sizes[cacheLevel]; blockId++)
{
// Initialize all blocks to be the default priority level
mFreeBlockIterators.emplace_back(freeQueue.insert(freeQueue.end(), mAllBlocksById[startIdx + blockId]));
mReleasedBlocks[cacheLevel].insert(startIdx + blockId);
}

startIdx += sizes[cacheLevel];
Expand Down Expand Up @@ -134,35 +132,15 @@ void LRUEvictionPolicy::releaseBlock(BlockPtr block, bool toFront)
SizeType32 const cacheLevel = getCacheLevel(block);
SizeType32 const id = block->getBlockId();

mReleasedBlocks[cacheLevel].insert(id);

// It's possible that this block is the child of a matched block that's in mFreeQueues. If this happens, we need to
// remove the parent from mFreeQueues, since it's no longer a released leaf block.
auto parent = block->getPrevBlock();
if (parent != nullptr)
// If there are no children, this is a leaf block. Insert into a queue.
auto& q = mFreeQueues[cacheLevel][getPriorityIdx(block->getPriority())];
if (toFront)
{
auto const parentId = parent->getBlockId();
if (parentId != KVCacheBlock::kCachedBlocksRootId && mFreeBlockIterators[parent->getBlockId()] != std::nullopt
&& !isReleasedLeafBlock(parent))
{
mFreeQueues[getCacheLevel(parent)][getPriorityIdx(parent->getPriority())].erase(
*mFreeBlockIterators[parentId]);
mFreeBlockIterators[parentId] = std::nullopt;
}
mFreeBlockIterators[id] = q.insert(q.begin(), block);
}

if (mFreeBlockIterators[block->getBlockId()] == std::nullopt && isReleasedLeafBlock(block))
else
{
// If there are no children, this is a leaf block. Insert into a queue.
auto& q = mFreeQueues[cacheLevel][getPriorityIdx(block->getPriority())];
if (toFront)
{
mFreeBlockIterators[id] = q.insert(q.begin(), block);
}
else
{
mFreeBlockIterators[id] = q.insert(q.end(), block);
}
mFreeBlockIterators[id] = q.insert(q.end(), block);
}

mNumFreeBlocksPerLevel[cacheLevel]++;
Expand Down Expand Up @@ -192,24 +170,10 @@ void LRUEvictionPolicy::claimBlock(BlockPtr block, std::optional<executor::Reten
SizeType32 const id = block->getBlockId();
SizeType32 const cacheLevel = getCacheLevel(block);

if (mReleasedBlocks[cacheLevel].find(id) != mReleasedBlocks[cacheLevel].end())
{
mNumFreeBlocksPerLevel[cacheLevel] -= 1;
mReleasedBlocks[cacheLevel].erase(id);
}

if (mFreeBlockIterators[id] != std::nullopt)
{
mFreeQueues[cacheLevel][getPriorityIdx(block->getPriority())].erase(*mFreeBlockIterators[id]);

BlockPtr const parent = block->getPrevBlock();

if (parent.get() != nullptr && parent->getBlockId() != KVCacheBlock::kCachedBlocksRootId
&& mFreeBlockIterators[parent->getBlockId()] == std::nullopt && isReleasedLeafBlock(parent))
{
auto& q = mFreeQueues[getCacheLevel(parent)][getPriorityIdx(parent->getPriority())];
mFreeBlockIterators[parent->getBlockId()] = q.insert(q.end(), parent);
}
mNumFreeBlocksPerLevel[cacheLevel] -= 1;
}

mFreeBlockIterators[id] = std::nullopt;
Expand All @@ -223,28 +187,6 @@ void LRUEvictionPolicy::claimBlock(BlockPtr block, std::optional<executor::Reten
block->setDurationMs(durationMs);
}

// Decide whether `block` counts as a released leaf: the block itself must be
// recorded in mReleasedBlocks, and none of its next blocks may be both
// released and located at the same or a lower cache level than the block.
bool LRUEvictionPolicy::isReleasedLeafBlock(BlockPtr const& block)
{
    SizeType32 const ownLevel = getCacheLevel(block);

    // A block that has not been released never qualifies.
    if (mReleasedBlocks[ownLevel].count(block->getBlockId()) == 0)
    {
        return false;
    }

    for (auto const& nextEntry : block->getNextBlocks())
    {
        auto const& child = nextEntry.second;
        SizeType32 const childLevel = getCacheLevel(child);
        bool const childIsReleased = mReleasedBlocks[childLevel].count(child->getBlockId()) > 0;
        if (childIsReleased && childLevel <= ownLevel)
        {
            // A released child at the same or a lower cache level disqualifies the block.
            return false;
        }
    }

    return true;
}

std::chrono::steady_clock::time_point::duration LRUEvictionPolicy::getTime() const
{
return std::chrono::steady_clock::now().time_since_epoch();
Expand Down
Loading