Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 65 additions & 47 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ static constexpr SizeType32 kPrimaryLevel = 0;

static constexpr SizeType32 kSecondaryLevel = 1;

// Extra block buffer allocated for SWA to be able to always keep "window size"
// tokens held in the blocks.
static constexpr SizeType32 kSWAExtraBlock = 1;

class KVCacheBlock;
class BlockManager;
class KVCacheManager;
Expand Down Expand Up @@ -93,8 +97,8 @@ struct WindowSizeMetadata
SizeType32 allottedSecondaryBlocks; // Number of secondary blocks allotted to the windowSize
SizeType32 absolutePoolsOffset; // cumulative number of pools up to manager
SizeType32 numPools; // number of managed pools
SizeType32 maxTokenNum; // Maximum token length (including bubble)
SizeType32 maxBlocksPerSeq;
SizeType32 maxTokenNum; // Maximum token length per sequence (TODO: account for streamLLM)
SizeType32 maxBlocksPerSeq; // Maximum number of blocks per sequence
SizeType32 maxNumBlocks; // Number of primary+secondary blocks allotted to the windowSize
SizeType32 temporaryAttentionWindow; // Temporary kv cache length per sequence.
// Only needed when chunked context + sliding window attention are used
Expand Down Expand Up @@ -344,14 +348,7 @@ class GenerationRequest
, mNumTokens(numTokens)
, mBeamWidth(beamWidth)
, mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
// min window size + sink bubble length
// Why use the minimum window size:
// Chunked Prefill + Reuse calls `setPrepopulatedPromptLen()` which sets
// `mContextCurrentPosition` - this cannot be done for some windows sizes and
// not for others, the state needs to remain identical for all window sizes. So
// we currently resort to strictly disabling the reuse code path for all window
// sizes at once or enable it for all window sizes at once.
, mCyclicThreshold(windowSizeToMetadata.cbegin()->second.maxTokenNum)
, mNumFrontBlocksRemoved(0)
{
auto const numWindowSizes = windowSizeToMetadata.size();
mCacheBlockIds.reserve(numWindowSizes);
Expand Down Expand Up @@ -394,6 +391,11 @@ class GenerationRequest
return mNumTokens;
}

[[nodiscard]] SizeType32 getNumFrontBlocksRemoved() const
{
return mNumFrontBlocksRemoved;
}

[[nodiscard]] SizeType32 getBeamWidth() const
{
return mBeamWidth;
Expand Down Expand Up @@ -431,6 +433,12 @@ class GenerationRequest
{
beamBlockIds.clear();
}
mNumFrontBlocksRemoved = 0;
}

void removeFrontBlock(SizeType32 windowSize)
{
++mNumFrontBlocksRemoved;
}

void removeLastBlock(SizeType32 windowSize)
Expand Down Expand Up @@ -461,14 +469,6 @@ class GenerationRequest
return mKvCacheRetentionConfig.getDirectory();
}

// @brief Check whether the sequence uses cyclic KV cache.
// @return `true` if we have begun overwriting the beginning of the sequence's KV cache.
// @details If `true`, we cannot store the sequence's KV cache for reuse.
[[nodiscard]] bool isCyclic() const
{
return mNumTokens >= mCyclicThreshold;
}

private:
// Request id of the sequence
LlmRequest::RequestIdType mRequestId;
Expand All @@ -482,9 +482,8 @@ class GenerationRequest
std::unordered_map<SizeType32, runtime::ITensor::SharedPtr> mCacheBlockIndices;
// The retention priority to assign to decode blocks
executor::KvCacheRetentionConfig mKvCacheRetentionConfig;

// Number of tokens at which the KV Cache begins sliding [for the minimum attention window]
SizeType32 mCyclicThreshold;
// Number of front blocks removed from the sequence
SizeType32 mNumFrontBlocksRemoved;
};

// attach metadata to a pool pointer
Expand Down Expand Up @@ -550,7 +549,7 @@ class WindowBlockManager

explicit WindowBlockManager(nvinfer1::DataType dtype, SizeType32 windowSize,
std::vector<SizeType32> const& managedLayers, std::vector<SizeType32> const& numKvHeadsPerLayer,
SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool,
SizeType32 sizePerHead, SizeType32 tokensPerBlock, bool isSWA, SizeType32 blocksInPrimaryPool,
SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
Expand Down Expand Up @@ -581,19 +580,32 @@ class WindowBlockManager
//! \brief Get the ids of all newly allocated (not reused) blocks for the sequence.
std::vector<KVCacheBlock::IdType> getNewlyAllocatedBlockIds(GenerationRequest const& sequence) const;

void storeBlocksForReuse(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);

void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);

//! \brief Release blocks of the sequence.
void releaseBlocks(GenerationRequest& sequence);
//! \details When llmRequest is provided and reuse is enabled, blocks will be stored.
void releaseBlocks(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);

//! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
void schedulingReleaseBlocks(LlmRequest::RequestIdType requestId);

//! \brief Update cache offsets for last block
void updateLastCacheBlockOffsets(GenerationRequest& seq);

//! \brief Release last block in the sequence
void releaseLastBlock(GenerationRequest& sequence);

//! \brief Detach front block from the sequence
void detachFrontBlock(GenerationRequest& sequence, bool isEnableBlockReuse);

//! \brief Add/detach block(s) to/from the sequence if needed
//! \details When we need a new block, we add it. For sliding window
//! attention (SWA), when a block goes out-of-window (OOW), we detach it
//! and store it if reuse is enabled. If this called in the first step of
//! the generation phase, we may detach more than a single block since
//! there may be more than one context block that goes OOW.
void adjustBlocksIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse);

[[nodiscard]] SizeType32 getWindowSize() const noexcept
{
return mWindowSize;
Expand Down Expand Up @@ -745,7 +757,8 @@ class WindowBlockManager
//! \brief Store blocks in cached blocks.
//! \param blockKeys Key of each block.
//! \param blockIds Id of each block.
void storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);
//! \return Number of actual blocks stored.
SizeType32 storeBlocks(std::vector<BlockKey> const& blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);

[[nodiscard]] bool verifyQueueIntegrity();

Expand All @@ -767,6 +780,12 @@ class WindowBlockManager
return 0;
}

//! \brief Return whether this window is SWA.
[[nodiscard]] bool isSWA() const
{
return mIsSWA;
}

private:
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
Expand Down Expand Up @@ -828,6 +847,8 @@ class WindowBlockManager
SizeType32 mSchedulingNumFreeBlocks;
// Number of tokens per one block
SizeType32 mTokensPerBlock;
// Whether this window is sliding window attention/full attention
bool mIsSWA;
// List of all blocks by idx
std::vector<BlockPtr> mAllBlocksById;
// Dummy block acting as root for BlockToken searches
Expand Down Expand Up @@ -880,7 +901,7 @@ class BlockManager

explicit BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead,
SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences,
CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth,
CudaStreamPtr stream, SizeType32 maxSequenceLength, SizeType32 maxBeamWidth,
std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF,
Expand Down Expand Up @@ -1128,14 +1149,6 @@ class BlockManager
//! \brief Store newest block for reuse
void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);

[[nodiscard]] static bool isUseOneMoreBlock(
SizeType32 windowSize, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth)
{
bool const isCyclicWindowSize = maxSequenceLength.has_value() && maxSequenceLength.value() > windowSize;
bool const isBeamSearch = maxBeamWidth > 1;
return isCyclicWindowSize && isBeamSearch;
}

//! \brief Perform per-request bookkeeping
void refreshBlocks();

Expand All @@ -1154,12 +1167,17 @@ class BlockManager
//! \brief Update cache offsets for blocks initiated from sequence
void updateSequenceCacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);

//! \brief Update cache offsets for last block
void updateLastCacheBlockOffsets(GenerationRequest& seq, SizeType32 windowSize);

//! \brief Update cache offsets for block at index
void updateCacheBlockOffsetsAtIdx(GenerationRequest& seq, SizeType32 windowSize, SizeType32 blockIdx);

//! \brief Add/detach block(s) to/from the sequence if needed
//! \details When we need a new block, we add it. For sliding window
//! attention (SWA), when a block goes out-of-window (OOW), we detach it
//! and store it if reuse is enabled. If this called in the first step of
//! the generation phase, we may detach more than a single block since
//! there may be more than one context block that goes OOW.
void adjustBlocksIfNeeded(GenerationRequest& sequence, bool isEnableBlockReuse);

private:
[[nodiscard]] WindowBlockManager const& windowManagerByLayer(SizeType32 layerIdx) const
{
Expand Down Expand Up @@ -1411,8 +1429,8 @@ class KVCacheManager : public BaseKVCacheManager
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
SizeType32 sinkTokenLength, CudaStreamPtr stream, SizeType32 maxSequenceLength, bool enableBlockReuse = false,
bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnpartialReuse = true,
Expand All @@ -1422,8 +1440,8 @@ class KVCacheManager : public BaseKVCacheManager
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, int64_t stream, std::optional<SizeType32> maxSequenceLength,
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
SizeType32 sinkTokenLength, int64_t stream, SizeType32 maxSequenceLength, bool enableBlockReuse = false,
bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnpartialReuse = true,
Expand All @@ -1433,8 +1451,8 @@ class KVCacheManager : public BaseKVCacheManager
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
SizeType32 sinkTokenLength, CudaStreamPtr stream, SizeType32 maxSequenceLength, bool enableBlockReuse = true,
bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnpartialReuse = true,
Expand All @@ -1444,9 +1462,9 @@ class KVCacheManager : public BaseKVCacheManager
BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, int64_t stream, std::optional<SizeType32> maxSequenceLength,
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
bool enablePartialReuse = true, bool copyOnpartialReuse = true);
SizeType32 sinkTokenLength, int64_t stream, SizeType32 maxSequenceLength, bool enableBlockReuse = false,
bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, bool enablePartialReuse = true,
bool copyOnpartialReuse = true);

~KVCacheManager() override = default;

Expand Down
Loading