Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ class WindowBlockManager
std::vector<SizeType32> const& managedLayers, std::vector<SizeType32> const& numKvHeadsPerLayer,
SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool,
SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager);

Expand Down Expand Up @@ -798,8 +798,6 @@ class WindowBlockManager
// getPoolLayerIdx
std::unordered_map<SizeType32, SizeType32> mLayerToIndexWithinPool;

// Whether offloaded blocks should be onboarded before reuse.
bool mOnboardBlocks;
// Buffer manager
runtime::BufferManager mBufferManager;

Expand Down Expand Up @@ -860,7 +858,7 @@ class BlockManager
CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth,
std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF,
SizeType32 sinkBubbleLength, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnPartialReuse = true,
Expand Down Expand Up @@ -1385,7 +1383,7 @@ class KVCacheManager : public BaseKVCacheManager
std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
bool enableBlockReuse = false, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnpartialReuse = true,
Expand All @@ -1396,7 +1394,7 @@ class KVCacheManager : public BaseKVCacheManager
std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, int64_t stream, std::optional<SizeType32> maxSequenceLength,
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
bool enableBlockReuse = false, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnpartialReuse = true,
Expand All @@ -1407,7 +1405,7 @@ class KVCacheManager : public BaseKVCacheManager
std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
bool enableBlockReuse = true, CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
bool copyOnpartialReuse = true,
Expand All @@ -1418,8 +1416,8 @@ class KVCacheManager : public BaseKVCacheManager
std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, int64_t stream, std::optional<SizeType32> maxSequenceLength,
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
bool enablePartialReuse = true, bool copyOnpartialReuse = true);
bool enableBlockReuse = false, CacheType cacheType = CacheType::kSELF, bool enablePartialReuse = true,
bool copyOnpartialReuse = true);

~KVCacheManager() override = default;

Expand Down
7 changes: 1 addition & 6 deletions cpp/include/tensorrt_llm/executor/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,7 @@ class KvCacheConfig
std::optional<std::vector<SizeType32>> const& maxAttentionWindowVec = std::nullopt,
std::optional<SizeType32> const& sinkTokenLength = std::nullopt,
std::optional<FloatType> const& freeGpuMemoryFraction = std::nullopt,
std::optional<size_t> const& hostCacheSize = std::nullopt, bool onboardBlocks = true,
std::optional<size_t> const& hostCacheSize = std::nullopt,
std::optional<FloatType> const& crossKvCacheFraction = std::nullopt,
std::optional<RetentionPriority> secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0,
bool enablePartialReuse = true, bool copyOnPartialReuse = true, bool useUvm = false,
Expand All @@ -1018,7 +1018,6 @@ class KvCacheConfig
[[nodiscard]] std::optional<FloatType> getFreeGpuMemoryFraction() const;
[[nodiscard]] std::optional<FloatType> getCrossKvCacheFraction() const;
[[nodiscard]] std::optional<size_t> getHostCacheSize() const;
[[nodiscard]] bool getOnboardBlocks() const;
[[nodiscard]] std::optional<RetentionPriority> getSecondaryOffloadMinPriority() const;
[[nodiscard]] size_t getEventBufferMaxSize() const;
[[nodiscard]] bool getUseUvm() const;
Expand All @@ -1034,7 +1033,6 @@ class KvCacheConfig
void setFreeGpuMemoryFraction(FloatType freeGpuMemoryFraction);
void setCrossKvCacheFraction(FloatType crossKvCacheFraction);
void setHostCacheSize(size_t hostCacheSize);
void setOnboardBlocks(bool onboardBlocks);
void setSecondaryOffloadMinPriority(std::optional<RetentionPriority> secondaryOffloadMinPriority);
void setEventBufferMaxSize(size_t eventBufferMaxSize);
void setUseUvm(bool useUvm);
Expand Down Expand Up @@ -1078,9 +1076,6 @@ class KvCacheConfig
/// Having a secondary memory pool increases KV cache block reuse potential.
std::optional<size_t> mHostCacheSize;

/// @brief Controls whether offloaded blocks should be onboarded back into primary memory before being reused.
bool mOnboardBlocks;

/// @brief Only blocks with priority > mSecondaryOfflineMinPriority can be offloaded to secondary memory.
std::optional<RetentionPriority> mSecondaryOffloadMinPriority;

Expand Down
44 changes: 18 additions & 26 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
std::shared_ptr<runtime::CudaStream> stream, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth,
std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType,
SizeType32 sinkBubbleLength, CacheType cacheType,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
Expand Down Expand Up @@ -534,8 +534,8 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
TLLM_CHECK(allottedPrimaryBlocks > 0); // You can't have a model with negative primary blocks...
mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer,
sizePerHead, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream,
onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse,
copyOnPartialReuse, kvCacheConnectorManager);
cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse, copyOnPartialReuse,
kvCacheConnectorManager);
}

auto const numAllPools = getNumPools();
Expand Down Expand Up @@ -575,15 +575,14 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 windowSize,
std::vector<SizeType32> const& managedLayers, std::vector<SizeType32> const& numKvHeadsPerLayer,
SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks, CacheType cacheType,
SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, CacheType cacheType,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
: mDataType{dtype}
, mWindowSize{windowSize}
, mNumPrimaryBlocks{blocksInPrimaryPool}
, mNumSecondaryBlocks{blocksInSecondaryPool}
, mOnboardBlocks(onboardBlocks)
, mBufferManager{std::move(stream)}
, mSchedulingNumFreeBlocks{0}
, mTokensPerBlock{tokensPerBlock}
Expand Down Expand Up @@ -869,9 +868,7 @@ BlockPtr WindowBlockManager::getFreeBlock(
// 1. Block contains state (evidenced by presence of tokens)
// 2. Eviction policy indicated block can be offloaded
// 3. At least one free block in secondary memory
// 4. Onboarding is enabled (allowing block to be brought back into primary)
if (!block->getUniqueTokens().empty() && canOffload && mEvictionPolicy->getNumFreeBlocks(kSecondaryLevel) > 0
&& mOnboardBlocks)
if (!block->getUniqueTokens().empty() && canOffload && mEvictionPolicy->getNumFreeBlocks(kSecondaryLevel) > 0)
{
// If we're swapping a block to secondary memory, maintain the prior priority values.
mEvictionPolicy->claimBlock(block);
Expand Down Expand Up @@ -936,7 +933,7 @@ void BlockManager::onboardBlock(BlockPtr const& offloadBlock, SizeType32 windowS

void WindowBlockManager::onboardBlock(BlockPtr const& offloadBlock)
{
if (mOnboardBlocks && !offloadBlock->isPrimary())
if (!offloadBlock->isPrimary())
{
auto block = getFreeBlock();
mTransferManager->onboard(offloadBlock, block, mPools);
Expand All @@ -961,7 +958,7 @@ void BlockManager::offloadBlock(BlockPtr const& block, SizeType32 windowSize)

void WindowBlockManager::offloadBlock(BlockPtr const& block)
{
if (mOnboardBlocks && block->isPrimary())
if (block->isPrimary())
{
// Offload block in primary memory before repurposing
auto offloadBlock = std::get<0>(mEvictionPolicy->getFreeBlock(kSecondaryLevel));
Expand Down Expand Up @@ -1631,11 +1628,11 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, int64_t stream, std::optional<runtime::SizeType32> maxSequenceLength,
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, bool enablePartialReuse, bool copyOnPartialReuse)
bool enableBlockReuse, CacheType cacheType, bool enablePartialReuse, bool copyOnPartialReuse)
: KVCacheManager(std::vector<SizeType32>(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
std::make_shared<runtime::CudaStream>(reinterpret_cast<cudaStream_t>(stream)), maxSequenceLength,
enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, enablePartialReuse, copyOnPartialReuse)
enableBlockReuse, cacheType, std::nullopt, nullptr, enablePartialReuse, copyOnPartialReuse)
{
}

Expand All @@ -1644,15 +1641,14 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, int64_t stream, std::optional<runtime::SizeType32> maxSequenceLength,
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
bool enableBlockReuse, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
: KVCacheManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth,
maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
std::make_shared<runtime::CudaStream>(reinterpret_cast<cudaStream_t>(stream)), maxSequenceLength,
enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, eventManager, enablePartialReuse,
copyOnPartialReuse, kvCacheConnectorManager)
enableBlockReuse, cacheType, secondaryOffloadMinPriority, eventManager, enablePartialReuse, copyOnPartialReuse,
kvCacheConnectorManager)
{
}

Expand All @@ -1661,8 +1657,7 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
bool enableBlockReuse, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
: mMaxBeamWidth(maxBeamWidth)
Expand All @@ -1673,8 +1668,8 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
, mSinkBlockTokenLength(mSinkBubbleLength + sinkTokenLength)
, mBlockManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype,
mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager),
enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager))
mSinkBubbleLength, cacheType, secondaryOffloadMinPriority, std::move(eventManager), enablePartialReuse,
copyOnPartialReuse, std::move(kvCacheConnectorManager))
// disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case
, mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse}
{
Expand All @@ -1696,13 +1691,12 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
bool enableBlockReuse, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
: KVCacheManager(std::vector<SizeType32>(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority,
std::move(stream), maxSequenceLength, enableBlockReuse, cacheType, secondaryOffloadMinPriority,
std::move(eventManager), enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager))
{
}
Expand Down Expand Up @@ -2272,9 +2266,7 @@ BlocksPerWindow BaseKVCacheManager::calculateMaxNumBlocks(executor::KvCacheConfi
= static_cast<SizeType32>(allottedSecondaryMemBytes * windowSizeShare / cacheSizeBytesPerToken);
SizeType32 const blocksInSecondaryPool = std::max(0, maxTokensSecondary / tokensPerBlock);
TLLM_LOG_DEBUG(
"Number of blocks in KV cache secondary pool for windowSize %d: %d, onboard blocks to primary memory "
"before reuse: %s",
windowSize, blocksInSecondaryPool, config.getOnboardBlocks() ? "true" : "false");
"Number of blocks in KV cache secondary pool for windowSize %d: %d", windowSize, blocksInSecondaryPool);
return blocksInSecondaryPool;
};

Expand Down
10 changes: 2 additions & 8 deletions cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,6 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer

auto const& kvCacheConfig = executorConfig.getKvCacheConfig();

if (!kvCacheConfig.getOnboardBlocks())
{
TLLM_CHECK_WITH_INFO(
!mModelConfig.getPagedContextFMHA(), "KV cache blocks need to be onboarded if context FMHA.");
}

if (mModelConfig.getSpeculativeDecodingMode().isDraftTokensExternal())
{
TLLM_CHECK_WITH_INFO(kvCacheConfig.getEnableBlockReuse(),
Expand Down Expand Up @@ -688,8 +682,8 @@ std::unique_ptr<kv_cache_manager::KVCacheManager> TrtGptModelInflightBatching::c

auto kvCacheManager = std::make_unique<KVCacheManager>(numKvHeadsPerLayer, sizePerHead, tokensPerBlock,
blocksPerWindow, getMaxNumSequences(), getMaxBeamWidth(), maxAttentionWindowVec, tempAttentionWindowInputs,
kvDtype, getSinkTokenLen(), mRuntime->getStreamPtr(), std::nullopt, enableBlockReuse,
kvCacheConfig.getOnboardBlocks(), kvCacheType, kvCacheConfig.getSecondaryOffloadMinPriority(),
kvDtype, getSinkTokenLen(), mRuntime->getStreamPtr(), std::nullopt, enableBlockReuse, kvCacheType,
kvCacheConfig.getSecondaryOffloadMinPriority(),
kvCacheConfig.getEventBufferMaxSize() > 0
? std::make_unique<kv_cache_manager::KVCacheEventManager>(kvCacheConfig.getEventBufferMaxSize())
: nullptr,
Expand Down
Loading