From 299ca542dbd6b48ffadd76f60f380408ec734e04 Mon Sep 17 00:00:00 2001 From: eopXD Date: Tue, 26 Aug 2025 07:04:41 -0700 Subject: [PATCH] [kv cache manager] No functional change intended, remove onboard block switch Dead code elimination. The secondary block pool is derived when kv_cache_config::host_cache_size is specified. Whether we onboard/offload a kv cache block can be implicated from whether the manager has secondary block or not. The `onboardBlocks` toggle itself only adds complication. This commit removes it. Signed-off-by: eopXD --- .../batch_manager/kvCacheManager.h | 16 +-- cpp/include/tensorrt_llm/executor/executor.h | 7 +- .../batch_manager/kvCacheManager.cpp | 44 +++---- .../trtGptModelInflightBatching.cpp | 10 +- cpp/tensorrt_llm/executor/kvCacheConfig.cpp | 18 +-- cpp/tensorrt_llm/executor/serialization.cpp | 7 +- .../nanobind/executor/executorConfig.cpp | 5 +- .../pybind/batch_manager/kvCacheManager.cpp | 4 +- .../pybind/executor/executorConfig.cpp | 5 +- .../batch_manager/kvCacheManagerTest.cpp | 121 ++++++------------ .../executor/serializeUtilsTest.cpp | 1 - tests/unittest/bindings/test_bindings_ut.py | 2 - .../bindings/test_executor_bindings.py | 6 - tests/unittest/llmapi/test_llm_args.py | 2 - .../llmapi/test_llm_kv_cache_events.py | 1 - 15 files changed, 77 insertions(+), 172 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 8940b160a15..4a1b19fe858 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -536,7 +536,7 @@ class WindowBlockManager std::vector const& managedLayers, std::vector const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr stream, - bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, + CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, std::shared_ptr kvCacheConnectorManager); @@ -798,8 +798,6 @@ class WindowBlockManager // getPoolLayerIdx std::unordered_map mLayerToIndexWithinPool; - // Whether offloaded blocks should be onboarded before reuse. - bool mOnboardBlocks; // Buffer manager runtime::BufferManager mBufferManager; @@ -860,7 +858,7 @@ class BlockManager CudaStreamPtr stream, std::optional maxSequenceLength, SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, - SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF, + SizeType32 sinkBubbleLength, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, bool copyOnPartialReuse = true, @@ -1385,7 +1383,7 @@ class KVCacheManager : public BaseKVCacheManager std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional maxSequenceLength, - bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, + bool enableBlockReuse = false, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, bool copyOnpartialReuse = true, @@ -1396,7 +1394,7 @@ class KVCacheManager : public BaseKVCacheManager std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkTokenLength, int64_t stream, std::optional maxSequenceLength, - bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, + bool enableBlockReuse = false, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, bool copyOnpartialReuse = true, @@ -1407,7 +1405,7 @@ class KVCacheManager : public BaseKVCacheManager std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional maxSequenceLength, - bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, + bool enableBlockReuse = true, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, bool copyOnpartialReuse = true, @@ -1418,8 +1416,8 @@ class KVCacheManager : public BaseKVCacheManager std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkTokenLength, int64_t stream, std::optional maxSequenceLength, - bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, - bool enablePartialReuse = true, bool copyOnpartialReuse = true); + bool enableBlockReuse = false, CacheType cacheType = CacheType::kSELF, bool enablePartialReuse = true, + bool copyOnpartialReuse = true); ~KVCacheManager() override = default; diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index 28c69074a3c..d8ea54f0968 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -1001,7 +1001,7 @@ class KvCacheConfig std::optional> const& maxAttentionWindowVec = std::nullopt, std::optional const& sinkTokenLength = std::nullopt, std::optional const& freeGpuMemoryFraction = std::nullopt, - std::optional const& hostCacheSize = std::nullopt, bool onboardBlocks = true, + std::optional const& hostCacheSize = std::nullopt, std::optional const& crossKvCacheFraction = std::nullopt, std::optional secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0, bool enablePartialReuse = true, bool copyOnPartialReuse = true, bool useUvm = false, @@ -1018,7 +1018,6 @@ class KvCacheConfig [[nodiscard]] std::optional getFreeGpuMemoryFraction() const; [[nodiscard]] std::optional getCrossKvCacheFraction() const; [[nodiscard]] std::optional getHostCacheSize() const; - [[nodiscard]] bool getOnboardBlocks() const; [[nodiscard]] std::optional getSecondaryOffloadMinPriority() const; [[nodiscard]] size_t getEventBufferMaxSize() const; [[nodiscard]] bool getUseUvm() const; @@ -1034,7 +1033,6 @@ class KvCacheConfig void setFreeGpuMemoryFraction(FloatType freeGpuMemoryFraction); void setCrossKvCacheFraction(FloatType crossKvCacheFraction); void setHostCacheSize(size_t hostCacheSize); - void setOnboardBlocks(bool onboardBlocks); void setSecondaryOffloadMinPriority(std::optional secondaryOffloadMinPriority); void setEventBufferMaxSize(size_t eventBufferMaxSize); void setUseUvm(bool useUvm); @@ -1078,9 +1076,6 @@ class KvCacheConfig /// Having a secondary memory pool increases KV cache block reuse potential. std::optional mHostCacheSize; - /// @brief Controls whether offloaded blocks should be onboarded back into primary memory before being reused. - bool mOnboardBlocks; - /// @brief Only blocks with priority > mSecondaryOfflineMinPriority can be offloaded to secondary memory. std::optional mSecondaryOffloadMinPriority; diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 748fcbbe09d..0aac5f6a8fa 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -502,7 +502,7 @@ BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, Si std::shared_ptr stream, std::optional maxSequenceLength, SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, - SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType, + SizeType32 sinkBubbleLength, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, std::shared_ptr kvCacheConnectorManager) @@ -534,8 +534,8 @@ BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, Si TLLM_CHECK(allottedPrimaryBlocks > 0); // You can't have a model with negative primary blocks... mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer, sizePerHead, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream, - onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse, - copyOnPartialReuse, kvCacheConnectorManager); + cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse, copyOnPartialReuse, + kvCacheConnectorManager); } auto const numAllPools = getNumPools(); @@ -575,7 +575,7 @@ BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, Si WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 windowSize, std::vector const& managedLayers, std::vector const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, - SizeType32 maxNumSequences, std::shared_ptr stream, bool onboardBlocks, CacheType cacheType, + SizeType32 maxNumSequences, std::shared_ptr stream, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, std::shared_ptr kvCacheConnectorManager) @@ -583,7 +583,6 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind , mWindowSize{windowSize} , mNumPrimaryBlocks{blocksInPrimaryPool} , mNumSecondaryBlocks{blocksInSecondaryPool} - , mOnboardBlocks(onboardBlocks) , mBufferManager{std::move(stream)} , mSchedulingNumFreeBlocks{0} , mTokensPerBlock{tokensPerBlock} @@ -869,9 +868,7 @@ BlockPtr WindowBlockManager::getFreeBlock( // 1. Block contains state (evidenced by presence of tokens) // 2. Eviction policy indicated block can be offloaded // 3. At least one free block in secondary memory - // 4. Onboarding is enabled (allowing block to be brought back into primary) - if (!block->getUniqueTokens().empty() && canOffload && mEvictionPolicy->getNumFreeBlocks(kSecondaryLevel) > 0 - && mOnboardBlocks) + if (!block->getUniqueTokens().empty() && canOffload && mEvictionPolicy->getNumFreeBlocks(kSecondaryLevel) > 0) { // If we're swapping a block to secondary memory, maintain the prior priority values. mEvictionPolicy->claimBlock(block); @@ -936,7 +933,7 @@ void BlockManager::onboardBlock(BlockPtr const& offloadBlock, SizeType32 windowS void WindowBlockManager::onboardBlock(BlockPtr const& offloadBlock) { - if (mOnboardBlocks && !offloadBlock->isPrimary()) + if (!offloadBlock->isPrimary()) { auto block = getFreeBlock(); mTransferManager->onboard(offloadBlock, block, mPools); @@ -961,7 +958,7 @@ void BlockManager::offloadBlock(BlockPtr const& block, SizeType32 windowSize) void WindowBlockManager::offloadBlock(BlockPtr const& block) { - if (mOnboardBlocks && block->isPrimary()) + if (block->isPrimary()) { // Offload block in primary memory before repurposing auto offloadBlock = std::get<0>(mEvictionPolicy->getFreeBlock(kSecondaryLevel)); @@ -1631,11 +1628,11 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkTokenLength, int64_t stream, std::optional maxSequenceLength, - bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, bool enablePartialReuse, bool copyOnPartialReuse) + bool enableBlockReuse, CacheType cacheType, bool enablePartialReuse, bool copyOnPartialReuse) : KVCacheManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, std::make_shared(reinterpret_cast(stream)), maxSequenceLength, - enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, enablePartialReuse, copyOnPartialReuse) + enableBlockReuse, cacheType, std::nullopt, nullptr, enablePartialReuse, copyOnPartialReuse) { } @@ -1644,15 +1641,14 @@ KVCacheManager::KVCacheManager(std::vector const& numKvHeadsPerLayer SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkTokenLength, int64_t stream, std::optional maxSequenceLength, - bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, - std::optional secondaryOffloadMinPriority, + bool enableBlockReuse, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, std::shared_ptr kvCacheConnectorManager) : KVCacheManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, std::make_shared(reinterpret_cast(stream)), maxSequenceLength, - enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, eventManager, enablePartialReuse, - copyOnPartialReuse, kvCacheConnectorManager) + enableBlockReuse, cacheType, secondaryOffloadMinPriority, eventManager, enablePartialReuse, copyOnPartialReuse, + kvCacheConnectorManager) { } @@ -1661,8 +1657,7 @@ KVCacheManager::KVCacheManager(std::vector const& numKvHeadsPerLayer SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional maxSequenceLength, - bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, - std::optional secondaryOffloadMinPriority, + bool enableBlockReuse, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, std::shared_ptr kvCacheConnectorManager) : mMaxBeamWidth(maxBeamWidth) @@ -1673,8 +1668,8 @@ KVCacheManager::KVCacheManager(std::vector const& numKvHeadsPerLayer , mSinkBlockTokenLength(mSinkBubbleLength + sinkTokenLength) , mBlockManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, - mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager), - enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager)) + mSinkBubbleLength, cacheType, secondaryOffloadMinPriority, std::move(eventManager), enablePartialReuse, + copyOnPartialReuse, std::move(kvCacheConnectorManager)) // disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case , mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse} { @@ -1696,13 +1691,12 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size SizeType32 maxBeamWidth, std::vector const& maxAttentionWindowVec, std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional maxSequenceLength, - bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, - std::optional secondaryOffloadMinPriority, + bool enableBlockReuse, CacheType cacheType, std::optional secondaryOffloadMinPriority, std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse, std::shared_ptr kvCacheConnectorManager) : KVCacheManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, - std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, + std::move(stream), maxSequenceLength, enableBlockReuse, cacheType, secondaryOffloadMinPriority, std::move(eventManager), enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager)) { } @@ -2272,9 +2266,7 @@ BlocksPerWindow BaseKVCacheManager::calculateMaxNumBlocks(executor::KvCacheConfi = static_cast(allottedSecondaryMemBytes * windowSizeShare / cacheSizeBytesPerToken); SizeType32 const blocksInSecondaryPool = std::max(0, maxTokensSecondary / tokensPerBlock); TLLM_LOG_DEBUG( - "Number of blocks in KV cache secondary pool for windowSize %d: %d, onboard blocks to primary memory " - "before reuse: %s", - windowSize, blocksInSecondaryPool, config.getOnboardBlocks() ? "true" : "false"); + "Number of blocks in KV cache secondary pool for windowSize %d: %d", windowSize, blocksInSecondaryPool); return blocksInSecondaryPool; }; diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index 08cb4d407c1..57751401050 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -182,12 +182,6 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr TrtGptModelInflightBatching::c auto kvCacheManager = std::make_unique(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, getMaxNumSequences(), getMaxBeamWidth(), maxAttentionWindowVec, tempAttentionWindowInputs, - kvDtype, getSinkTokenLen(), mRuntime->getStreamPtr(), std::nullopt, enableBlockReuse, - kvCacheConfig.getOnboardBlocks(), kvCacheType, kvCacheConfig.getSecondaryOffloadMinPriority(), + kvDtype, getSinkTokenLen(), mRuntime->getStreamPtr(), std::nullopt, enableBlockReuse, kvCacheType, + kvCacheConfig.getSecondaryOffloadMinPriority(), kvCacheConfig.getEventBufferMaxSize() > 0 ? std::make_unique(kvCacheConfig.getEventBufferMaxSize()) : nullptr, diff --git a/cpp/tensorrt_llm/executor/kvCacheConfig.cpp b/cpp/tensorrt_llm/executor/kvCacheConfig.cpp index 1e83ba4b3a6..6a281f939ef 100644 --- a/cpp/tensorrt_llm/executor/kvCacheConfig.cpp +++ b/cpp/tensorrt_llm/executor/kvCacheConfig.cpp @@ -24,14 +24,12 @@ namespace tensorrt_llm::executor KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional const& maxTokens, std::optional> const& maxAttentionWindowVec, std::optional const& sinkTokenLength, std::optional const& freeGpuMemoryFraction, - std::optional const& hostCacheSize, bool onboardBlocks, - std::optional const& crossKvCacheFraction, std::optional secondaryOffloadMinPriority, - size_t eventBufferMaxSize, bool enablePartialReuse, bool copyOnPartialReuse, bool useUvm, - SizeType32 attentionDpEventsGatherPeriodMs, + std::optional const& hostCacheSize, std::optional const& crossKvCacheFraction, + std::optional secondaryOffloadMinPriority, size_t eventBufferMaxSize, bool enablePartialReuse, + bool copyOnPartialReuse, bool useUvm, SizeType32 attentionDpEventsGatherPeriodMs, std::optional const& runtimeDefaults, uint64_t const& maxGpuTotalBytes) : mEnableBlockReuse(enableBlockReuse) , mHostCacheSize(hostCacheSize) - , mOnboardBlocks(onboardBlocks) , mSecondaryOffloadMinPriority(secondaryOffloadMinPriority) , mEventBufferMaxSize{eventBufferMaxSize} , mEnablePartialReuse{enablePartialReuse} @@ -117,11 +115,6 @@ std::optional KvCacheConfig::getHostCacheSize() const return mHostCacheSize; } -bool KvCacheConfig::getOnboardBlocks() const -{ - return mOnboardBlocks; -} - std::optional KvCacheConfig::getSecondaryOffloadMinPriority() const { return mSecondaryOffloadMinPriority; @@ -206,11 +199,6 @@ void KvCacheConfig::setHostCacheSize(size_t hostCacheSize) mHostCacheSize = hostCacheSize; } -void KvCacheConfig::setOnboardBlocks(bool onboardBlocks) -{ - mOnboardBlocks = onboardBlocks; -} - void KvCacheConfig::setSecondaryOffloadMinPriority(std::optional secondaryOffloadMinPriority) { mSecondaryOffloadMinPriority = secondaryOffloadMinPriority; diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index bba8d19e2f6..abf3fae303c 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -1164,7 +1164,6 @@ KvCacheConfig Serialization::deserializeKvCacheConfig(std::istream& is) auto sinkTokenLength = su::deserialize>(is); auto freeGpuMemoryFraction = su::deserialize>(is); auto hostCacheSize = su::deserialize>(is); - auto onboardBlocks = su::deserialize(is); auto crossKvCacheFraction = su::deserialize>(is); auto secondaryOffloadMinPriority = su::deserialize>(is); auto eventBufferMaxSize = su::deserialize(is); @@ -1172,8 +1171,8 @@ KvCacheConfig Serialization::deserializeKvCacheConfig(std::istream& is) auto attentionDpEventsGatherPeriodMs = su::deserialize(is); return KvCacheConfig{enableBlockReuse, maxTokens, maxAttentionWindowVec, sinkTokenLength, freeGpuMemoryFraction, - hostCacheSize, onboardBlocks, crossKvCacheFraction, secondaryOffloadMinPriority, eventBufferMaxSize, - enablePartialReuse, copyOnPartialReuse, useUvm, attentionDpEventsGatherPeriodMs}; + hostCacheSize, crossKvCacheFraction, secondaryOffloadMinPriority, eventBufferMaxSize, enablePartialReuse, + copyOnPartialReuse, useUvm, attentionDpEventsGatherPeriodMs}; } void Serialization::serialize(KvCacheConfig const& kvCacheConfig, std::ostream& os) @@ -1186,7 +1185,6 @@ void Serialization::serialize(KvCacheConfig const& kvCacheConfig, std::ostream& su::serialize(kvCacheConfig.getSinkTokenLength(), os); su::serialize(kvCacheConfig.getFreeGpuMemoryFraction(), os); su::serialize(kvCacheConfig.getHostCacheSize(), os); - su::serialize(kvCacheConfig.getOnboardBlocks(), os); su::serialize(kvCacheConfig.getCrossKvCacheFraction(), os); su::serialize(kvCacheConfig.getSecondaryOffloadMinPriority(), os); su::serialize(kvCacheConfig.getEventBufferMaxSize(), os); @@ -1206,7 +1204,6 @@ size_t Serialization::serializedSize(KvCacheConfig const& kvCacheConfig) totalSize += su::serializedSize(kvCacheConfig.getSinkTokenLength()); totalSize += su::serializedSize(kvCacheConfig.getFreeGpuMemoryFraction()); totalSize += su::serializedSize(kvCacheConfig.getHostCacheSize()); - totalSize += su::serializedSize(kvCacheConfig.getOnboardBlocks()); totalSize += su::serializedSize(kvCacheConfig.getCrossKvCacheFraction()); totalSize += su::serializedSize(kvCacheConfig.getSecondaryOffloadMinPriority()); totalSize += su::serializedSize(kvCacheConfig.getEventBufferMaxSize()); diff --git a/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp b/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp index 0334eb14f6a..d809b904e55 100644 --- a/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp @@ -110,8 +110,8 @@ void initConfigBindings(nb::module_& m) { return nb::make_tuple(self.getEnableBlockReuse(), self.getMaxTokens(), self.getMaxAttentionWindowVec(), self.getSinkTokenLength(), self.getFreeGpuMemoryFraction(), self.getHostCacheSize(), - self.getOnboardBlocks(), self.getCrossKvCacheFraction(), self.getSecondaryOffloadMinPriority(), - self.getEventBufferMaxSize(), self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm(), + self.getCrossKvCacheFraction(), self.getSecondaryOffloadMinPriority(), self.getEventBufferMaxSize(), + self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm(), self.getAttentionDpEventsGatherPeriodMs(), self.getMaxGpuTotalBytes()); }; auto kvCacheConfigSetstate = [](tle::KvCacheConfig& self, nb::tuple const& state) @@ -151,7 +151,6 @@ void initConfigBindings(nb::module_& m) .def_prop_rw("free_gpu_memory_fraction", &tle::KvCacheConfig::getFreeGpuMemoryFraction, &tle::KvCacheConfig::setFreeGpuMemoryFraction) .def_prop_rw("host_cache_size", &tle::KvCacheConfig::getHostCacheSize, &tle::KvCacheConfig::setHostCacheSize) - .def_prop_rw("onboard_blocks", &tle::KvCacheConfig::getOnboardBlocks, &tle::KvCacheConfig::setOnboardBlocks) .def_prop_rw("cross_kv_cache_fraction", &tle::KvCacheConfig::getCrossKvCacheFraction, &tle::KvCacheConfig::setCrossKvCacheFraction) .def_prop_rw("secondary_offload_min_priority", &tle::KvCacheConfig::getSecondaryOffloadMinPriority, diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 320659a1d09..ddba77f8a2d 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -478,14 +478,14 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) .def(py::init const&, SizeType32, SizeType32, std::map> const&, SizeType32, SizeType32, std::vector const&, std::optional const&, - nvinfer1::DataType, SizeType32, bool, int64_t, bool, bool, tbk::CacheType, + nvinfer1::DataType, SizeType32, bool, int64_t, bool, tbk::CacheType, std::optional, std::shared_ptr, bool, bool, std::shared_ptr>(), py::arg("num_kv_heads_per_layer"), py::arg("size_per_head"), py::arg("tokens_per_block"), py::arg("blocks_per_window"), py::arg("max_num_sequences"), py::arg("max_beam_width"), py::arg("max_attention_window_vec"), py::arg("temp_attention_window_inputs"), py::arg("dtype"), py::arg("sink_token_length"), py::arg("stream"), py::arg("max_sequence_length"), - py::arg("enable_block_reuse") = false, py::arg("onboard_blocks") = true, + py::arg("enable_block_reuse") = false, py::arg_v("cache_type", tbk::CacheType::kSELF, "bindings.internal.batch_manager.CacheType.SELF"), py::arg("secondary_offload_min_priority") = std::nullopt, py::arg("event_manager") = nullptr, py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true, diff --git a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp index 74e2fe56c16..82ea3e4321c 100644 --- a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp +++ b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp @@ -102,8 +102,8 @@ void initConfigBindings(pybind11::module_& m) { return py::make_tuple(self.getEnableBlockReuse(), self.getMaxTokens(), self.getMaxAttentionWindowVec(), self.getSinkTokenLength(), self.getFreeGpuMemoryFraction(), self.getHostCacheSize(), - self.getOnboardBlocks(), self.getCrossKvCacheFraction(), self.getSecondaryOffloadMinPriority(), - self.getEventBufferMaxSize(), self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm(), + self.getCrossKvCacheFraction(), self.getSecondaryOffloadMinPriority(), self.getEventBufferMaxSize(), + self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm(), self.getAttentionDpEventsGatherPeriodMs(), self.getMaxGpuTotalBytes()); }; auto kvCacheConfigSetstate = [](py::tuple const& state) @@ -144,7 +144,6 @@ void initConfigBindings(pybind11::module_& m) .def_property( "max_gpu_total_bytes", &tle::KvCacheConfig::getMaxGpuTotalBytes, &tle::KvCacheConfig::setMaxGpuTotalBytes) .def_property("host_cache_size", &tle::KvCacheConfig::getHostCacheSize, &tle::KvCacheConfig::setHostCacheSize) - .def_property("onboard_blocks", &tle::KvCacheConfig::getOnboardBlocks, &tle::KvCacheConfig::setOnboardBlocks) .def_property("cross_kv_cache_fraction", &tle::KvCacheConfig::getCrossKvCacheFraction, &tle::KvCacheConfig::setCrossKvCacheFraction) .def_property("secondary_offload_min_priority", &tle::KvCacheConfig::getSecondaryOffloadMinPriority, diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index 0a52ae84852..48df8ef180c 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -115,7 +115,6 @@ TEST_F(KVCacheManagerTest, BlockManagerTest) auto constexpr blocksInSecondaryPool = 0; auto constexpr maxNumSequences = 8; auto const stream = std::make_shared(); - auto constexpr onboardBlocks = true; auto constexpr beamWidth = 8; auto constexpr numBlocksPerBeam = blocksInPrimaryPool / beamWidth; @@ -126,8 +125,7 @@ TEST_F(KVCacheManagerTest, BlockManagerTest) BlockManager blockManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, stream, maxAttentionWindow, beamWidth, - std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0, - onboardBlocks); + std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0); blockManager.allocatePools(false); EXPECT_EQ(blockManager.getTokensPerBlock(), tokensPerBlock); @@ -193,7 +191,6 @@ void runPartialCopyTest() auto constexpr blocksInSecondaryPool = 4; auto constexpr maxNumSequences = 8; auto const stream = std::make_shared(); - auto constexpr onboardBlocks = true; auto constexpr batchSize = 1; auto constexpr maxBlocksPerSeq = 10; @@ -213,7 +210,7 @@ void runPartialCopyTest() BlockManager blockManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, stream, maxAttentionWindow, beamWidth, - std::vector{maxAttentionWindow}, std::nullopt, type, 0, onboardBlocks); + std::vector{maxAttentionWindow}, std::nullopt, type, 0); blockManager.allocatePools(false); auto oneLayerBlockSize = blockManager.getBlockSize(0); @@ -541,7 +538,6 @@ TEST_F(KVCacheManagerTest, FP4BlockScaleManagementTest) auto constexpr blocksInSecondaryPool = 16; auto constexpr numFp4EltsPerContainer = 2; auto constexpr vectorSize = 16; - auto constexpr onboardBlocks = true; auto const stream = std::make_shared(); auto constexpr beamWidth = 1; @@ -551,7 +547,7 @@ TEST_F(KVCacheManagerTest, FP4BlockScaleManagementTest) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kFP4, - false, stream, true, onboardBlocks); + false, stream, true); kvCacheManager.allocatePools(/*useUvm=*/false); @@ -579,7 +575,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) auto constexpr blocksInSecondaryPool = 0; auto constexpr maxNumSequences = 8; auto const stream = std::make_shared(); - auto constexpr onboardBlocks = true; auto constexpr maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq; auto constexpr beamWidth = 1; @@ -588,8 +583,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseTest) BlockManager blockManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, stream, maxAttentionWindow, beamWidth, - std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0, - onboardBlocks); + std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0); blockManager.allocatePools(false); EXPECT_EQ(blockManager.getTokensPerBlock(), tokensPerBlock); @@ -857,7 +851,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) auto constexpr blocksInSecondaryPool = 0; auto constexpr maxNumSequences = 8; auto const stream = std::make_shared(); - auto constexpr onboardBlocks = true; auto constexpr numReturnSequences = 1; auto constexpr maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq; @@ -867,8 +860,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) BlockManager blockManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, stream, maxAttentionWindow, beamWidth, - std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0, - onboardBlocks); + std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0); blockManager.allocatePools(false); EXPECT_EQ(blockManager.getTokensPerBlock(), tokensPerBlock); @@ -1056,7 +1048,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) auto constexpr blocksInSecondaryPool = 0; auto constexpr maxNumSequences = 8; auto const stream = std::make_shared(); - auto constexpr onboardBlocks = true; auto constexpr numReturnSequences = 1; auto constexpr maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq; auto constexpr beamWidth = 1; @@ -1065,8 +1056,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) BlockManager blockManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, stream, maxAttentionWindow, beamWidth, - std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0, - onboardBlocks); + std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0); blockManager.allocatePools(false); EXPECT_EQ(blockManager.getTokensPerBlock(), tokensPerBlock); @@ -1233,7 +1223,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) auto constexpr blocksInSecondaryPool = 0; auto constexpr maxNumSequences = 8; auto const stream = std::make_shared(); - auto constexpr onboardBlocks = true; auto constexpr maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq; auto constexpr beamWidth = 1; @@ -1242,8 +1231,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) BlockManager blockManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, stream, maxAttentionWindow, beamWidth, - std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0, - onboardBlocks); + std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0); blockManager.allocatePools(false); EXPECT_EQ(blockManager.getTokensPerBlock(), tokensPerBlock); @@ -1475,7 +1463,6 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) auto constexpr blocksInSecondaryPool = 0; auto constexpr maxNumSequences = 8; auto const stream = std::make_shared(); - auto constexpr onboardBlocks = true; auto constexpr maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq; auto constexpr beamWidth = 1; @@ -1484,8 +1471,7 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) BlockManager blockManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, stream, maxAttentionWindow, beamWidth, - std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0, - onboardBlocks); + std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0); blockManager.allocatePools(false); EXPECT_EQ(blockManager.getTokensPerBlock(), tokensPerBlock); @@ -1696,7 +1682,6 @@ TEST_F(KVCacheManagerTest, KVCacheManagerPerRequestStatsTest) auto constexpr maxNumSequences = 8; auto constexpr blocksInPrimaryPool = 16; auto constexpr blocksInSecondaryPool = 0; - auto constexpr onboardBlocks = true; auto const stream = std::make_shared(); auto constexpr beamWidth = 1; @@ -1710,7 +1695,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerPerRequestStatsTest) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - 0, stream, std::nullopt, true, onboardBlocks); + 0, stream, std::nullopt, true); kvCacheManager.allocatePools(false); auto inputTokens = std::make_shared(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 8}); @@ -1756,7 +1741,6 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) auto constexpr blocksInSecondaryPool = 0; auto constexpr maxNumSequences = 4; auto const stream = std::make_shared(); - auto constexpr onboardBlocks = true; auto constexpr maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq; auto constexpr beamWidth = 1; @@ -1765,8 +1749,7 @@ TEST_F(KVCacheManagerTest, BlockManagerBlockPriorityTest) BlockManager blockManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, stream, maxAttentionWindow, beamWidth, - std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0, - onboardBlocks); + std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0); blockManager.allocatePools(false); EXPECT_EQ(blockManager.getTokensPerBlock(), tokensPerBlock); @@ -1856,7 +1839,6 @@ TEST_F(KVCacheManagerTest, KVCacheManagerDecodeBlockPriorityTest) auto constexpr maxNumSequences = 8; auto constexpr blocksInPrimaryPool = 8; auto constexpr blocksInSecondaryPool = 0; - auto constexpr onboardBlocks = true; auto const stream = std::make_shared(); auto constexpr beamWidth = 1; @@ -1870,7 +1852,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerDecodeBlockPriorityTest) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - 0, stream, std::nullopt, true, onboardBlocks); + 0, stream, std::nullopt, true); kvCacheManager.allocatePools(false); auto const& blockManager = kvCacheManager.getBlockManager(); @@ -1962,7 +1944,6 @@ TEST_F(KVCacheManagerTest, KVCacheManagerTimedEvictionTest) auto constexpr maxNumSequences = 8; auto constexpr blocksInPrimaryPool = 8; auto constexpr blocksInSecondaryPool = 0; - auto constexpr onboardBlocks = true; auto const stream = std::make_shared(); auto constexpr beamWidth = 1; @@ -1976,7 +1957,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerTimedEvictionTest) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - 0, stream, std::nullopt, true, onboardBlocks); + 0, stream, std::nullopt, true); kvCacheManager.allocatePools(false); auto inputTokens0 = std::make_shared(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); @@ -2030,7 +2011,6 @@ TEST_F(KVCacheManagerTest, KVCacheManagerDecodeTimedEvictionTest) auto constexpr maxNumSequences = 8; auto constexpr blocksInPrimaryPool = 8; auto constexpr blocksInSecondaryPool = 0; - auto constexpr onboardBlocks = true; auto const stream = std::make_shared(); auto constexpr beamWidth = 1; @@ -2044,7 +2024,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerDecodeTimedEvictionTest) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - 0, stream, std::nullopt, true, onboardBlocks); + 0, stream, std::nullopt, true); kvCacheManager.allocatePools(false); { auto inputTokens0 = std::make_shared(VecTokens{1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); @@ -2107,7 +2087,6 @@ TEST_F(KVCacheManagerTest, KVCacheManagerSecondaryBlockPrimaryChildTest) auto constexpr maxNumSequences = 8; auto constexpr blocksInPrimaryPool = 4; auto constexpr blocksInSecondaryPool = 4; - auto constexpr onboardBlocks = true; auto const stream = std::make_shared(); auto constexpr beamWidth = 1; @@ -2121,7 +2100,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerSecondaryBlockPrimaryChildTest) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - false, stream, true, onboardBlocks); + false, stream, true); kvCacheManager.allocatePools(false); auto inputTokens0 = std::make_shared(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); @@ -2183,7 +2162,6 @@ TEST_F(KVCacheManagerTest, KVCacheManagerLeafBlockTest) auto constexpr maxNumSequences = 8; auto constexpr blocksInPrimaryPool = 4; auto constexpr blocksInSecondaryPool = 0; - auto constexpr onboardBlocks = true; auto const stream = std::make_shared(); auto constexpr beamWidth = 1; @@ -2196,7 +2174,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerLeafBlockTest) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - false, stream, true, onboardBlocks); + false, stream, true); kvCacheManager.allocatePools(false); auto inputTokens0 = std::make_shared(VecTokens{0, 1, 2, 3}); @@ -2260,7 +2238,6 @@ TEST_F(KVCacheManagerTest, KVCacheManagerLeafBlockWithDependentTest) auto constexpr maxNumSequences = 8; auto constexpr blocksInPrimaryPool = 4; auto constexpr blocksInSecondaryPool = 1; - auto constexpr onboardBlocks = true; auto const stream = std::make_shared(); auto constexpr beamWidth = 1; @@ -2275,7 +2252,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerLeafBlockWithDependentTest) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - false, stream, true, onboardBlocks); + false, stream, true); kvCacheManager.allocatePools(false); // Create sequence with one block worth of context tokens @@ -2384,7 +2361,6 @@ TEST_P(KVCacheManagerTest, DISABLED_KVCacheManagerAllocationTest) auto constexpr enableBlockReuse = false; auto constexpr useUvm = false; - auto constexpr onboardBlocks = true; auto const homogeneousLayers = GetParam(); auto const granularity = tensorrt_llm::common::getAllocationGranularity(); @@ -2394,11 +2370,10 @@ TEST_P(KVCacheManagerTest, DISABLED_KVCacheManagerAllocationTest) KVCacheManager kvCacheManager = homogeneousLayers ? KVCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, - nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks) + nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse) : KVCacheManager(std::vector(numLayers, numHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, - std::nullopt, nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse, - onboardBlocks); + std::nullopt, nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse); auto const& blockManager = kvCacheManager.getBlockManager(); auto const& bufferManager = blockManager.getBufferManager(theOnlyWindowSize(kvCacheManager)); @@ -2446,7 +2421,6 @@ TEST_P(KVCacheManagerTest, KVCacheManagerTest) auto constexpr blocksInSecondaryPool = 0; auto constexpr enableBlockReuse = false; - auto constexpr onboardBlocks = true; auto const homogeneousLayers = GetParam(); auto const expectedNumPools = homogeneousLayers ? 1 : static_cast(expectedHeadsPerPool.size()); @@ -2455,10 +2429,10 @@ TEST_P(KVCacheManagerTest, KVCacheManagerTest) KVCacheManager kvCacheManager = homogeneousLayers ? KVCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, - nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks) + nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse) : KVCacheManager(numHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks); + sinkTokenLength, stream, std::nullopt, enableBlockReuse); kvCacheManager.allocatePools(false); EXPECT_EQ(kvCacheManager.getOffsetTableDimensions().maxBlocksPerSeq, maxBlocksPerSeq); @@ -2595,7 +2569,6 @@ TEST_P(KVCacheManagerTest, KVCacheManagerRewindTokensTest) auto constexpr blocksInSecondaryPool = 0; auto constexpr enableBlockReuse = false; - auto constexpr onboardBlocks = true; auto const homogeneousLayers = GetParam(); auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {totalNumBlocks, blocksInSecondaryPool}}}; @@ -2603,11 +2576,10 @@ TEST_P(KVCacheManagerTest, KVCacheManagerRewindTokensTest) KVCacheManager kvCacheManager = homogeneousLayers ? KVCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, - nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks) + nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse) : KVCacheManager(std::vector(numLayers, numHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, - std::nullopt, nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse, - onboardBlocks); + std::nullopt, nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse); kvCacheManager.allocatePools(false); EXPECT_EQ(kvCacheManager.getTokensPerBlock(), tokensPerBlock); @@ -2683,7 +2655,6 @@ TEST_P(KVCacheManagerTest, KVCacheManagerMaxAttentionWindowTest) auto constexpr blocksInSecondaryPool = 0; auto constexpr enableBlockReuse = false; - auto constexpr onboardBlocks = true; auto const homogeneousLayers = GetParam(); auto const expectedNumPools = homogeneousLayers ? 1 : static_cast(expectedHeadsPerPool.size()); @@ -2692,10 +2663,10 @@ TEST_P(KVCacheManagerTest, KVCacheManagerMaxAttentionWindowTest) KVCacheManager kvCacheManager = homogeneousLayers ? KVCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, - nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks) + nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse) : KVCacheManager(numHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks); + sinkTokenLength, stream, std::nullopt, enableBlockReuse); kvCacheManager.allocatePools(false); EXPECT_EQ(kvCacheManager.getOffsetTableDimensions().maxBlocksPerSeq, maxBlocksPerSeq); @@ -2804,13 +2775,12 @@ TEST_F(KVCacheManagerTest, KVCacheManagerMaxAttentionWindowWithReuseTest) auto constexpr blocksInSecondaryPool = 0; auto constexpr enableBlockReuse = true; - auto constexpr onboardBlocks = true; auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}}; KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, - nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks); + nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse); kvCacheManager.allocatePools(false); auto const& blockManager = kvCacheManager.getBlockManager(); @@ -2951,14 +2921,13 @@ TEST_F(KVCacheManagerTest, KVCacheManagerVariableWindowAttentionWithReuseTest) auto constexpr blocksInSecondaryPool = 0; auto constexpr enableBlockReuse = true; - auto constexpr onboardBlocks = true; auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}, {minAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}}; KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, std::nullopt, dtype, sinkTokenLength, stream, std::nullopt, - enableBlockReuse, onboardBlocks); + enableBlockReuse); kvCacheManager.allocatePools(false); auto const& blockManager = kvCacheManager.getBlockManager(); @@ -3073,7 +3042,6 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStream) auto constexpr maxNumSequences = 8; auto constexpr blocksInPrimaryPool = 8; auto constexpr blocksInSecondaryPool = 2; - auto constexpr onboardBlocks = true; auto constexpr dtype = nvinfer1::DataType::kHALF; auto const stream = std::make_shared(); @@ -3088,8 +3056,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStream) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, dtype, 0, stream, - std::nullopt, true, onboardBlocks, CacheType::kSELF, std::nullopt, - std::make_unique(1024)); + std::nullopt, true, CacheType::kSELF, std::nullopt, std::make_unique(1024)); kvCacheManager.allocatePools(false); auto events = getEvents(kvCacheManager); @@ -3229,7 +3196,6 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamOverflow) auto constexpr maxNumSequences = 8; auto constexpr blocksInPrimaryPool = 8; auto constexpr blocksInSecondaryPool = 2; - auto constexpr onboardBlocks = true; auto constexpr dtype = nvinfer1::DataType::kHALF; auto const stream = std::make_shared(); @@ -3244,8 +3210,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamOverflow) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, dtype, 0, stream, - std::nullopt, true, onboardBlocks, CacheType::kSELF, std::nullopt, - std::make_unique(1)); + std::nullopt, true, CacheType::kSELF, std::nullopt, std::make_unique(1)); kvCacheManager.allocatePools(false); auto inputTokens0 = std::make_shared(VecTokens{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); @@ -3287,7 +3252,6 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamPriority) auto constexpr maxNumSequences = 8; auto constexpr blocksInPrimaryPool = 8; auto constexpr blocksInSecondaryPool = 2; - auto constexpr onboardBlocks = true; auto constexpr dtype = nvinfer1::DataType::kHALF; auto const stream = std::make_shared(); @@ -3302,8 +3266,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamPriority) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, dtype, 0, stream, - std::nullopt, true, onboardBlocks, CacheType::kSELF, std::nullopt, - std::make_unique(1024)); + std::nullopt, true, CacheType::kSELF, std::nullopt, std::make_unique(1024)); kvCacheManager.allocatePools(false); auto inputTokens0 = std::make_shared(VecTokens{0, 1, 2, 3, 4, 5, 6, 7}); @@ -3362,7 +3325,6 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamBlocking) auto constexpr maxNumSequences = 8; auto constexpr blocksInPrimaryPool = 8; auto constexpr blocksInSecondaryPool = 2; - auto constexpr onboardBlocks = true; auto constexpr dtype = nvinfer1::DataType::kHALF; auto const stream = std::make_shared(); @@ -3377,13 +3339,13 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamBlocking) KVCacheManager kvCacheManagerTest(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, dtype, 0, - stream, std::nullopt, true, onboardBlocks, CacheType::kSELF, std::nullopt); + stream, std::nullopt, true, CacheType::kSELF, std::nullopt); EXPECT_EQ(getEvents(kvCacheManagerTest).size(), 0); KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - 0, stream, std::nullopt, true, onboardBlocks, CacheType::kSELF, std::nullopt, + 0, stream, std::nullopt, true, CacheType::kSELF, std::nullopt, std::make_unique(1024)); kvCacheManager.allocatePools(false); @@ -3415,7 +3377,6 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamWindowSize) auto constexpr maxNumSequences = 8; auto blocksInPool = std::vector{8, 2}; auto blocksInSlidingWindowPool = std::vector{4, 2}; - auto constexpr onboardBlocks = true; auto constexpr dtype = nvinfer1::DataType::kHALF; auto const stream = std::make_shared(); @@ -3432,8 +3393,7 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamWindowSize) KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, beamWidth, std::vector{maxAttentionWindow, slidingWindow}, std::nullopt, dtype, 0, - stream, std::nullopt, true, onboardBlocks, CacheType::kSELF, std::nullopt, - std::make_unique(1024)); + stream, std::nullopt, true, CacheType::kSELF, std::nullopt, std::make_unique(1024)); kvCacheManager.allocatePools(false); auto events = getEvents(kvCacheManager); @@ -3534,7 +3494,6 @@ TEST_P(KVCacheManagerTest, DISABLED_KVCacheManagerSinkTokenLengthTest) auto constexpr blocksInSecondaryPool = 0; auto constexpr enableBlockReuse = false; - auto constexpr onboardBlocks = true; auto const homogeneousLayers = GetParam(); auto const expectedNumPools = homogeneousLayers ? 1 : static_cast(expectedHeadsPerPool.size()); @@ -3544,10 +3503,10 @@ TEST_P(KVCacheManagerTest, DISABLED_KVCacheManagerSinkTokenLengthTest) KVCacheManager kvCacheManager = homogeneousLayers ? KVCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, - nvinfer1::DataType::kHALF, sinkTokenLength, stream, maxSequenceLength, enableBlockReuse, onboardBlocks) + nvinfer1::DataType::kHALF, sinkTokenLength, stream, maxSequenceLength, enableBlockReuse) : KVCacheManager(numHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - sinkTokenLength, stream, maxSequenceLength, enableBlockReuse, onboardBlocks); + sinkTokenLength, stream, maxSequenceLength, enableBlockReuse); kvCacheManager.allocatePools(false); EXPECT_EQ(kvCacheManager.getOffsetTableDimensions().maxBlocksPerSeq, maxBlocksPerSeq); @@ -3673,7 +3632,6 @@ TEST_P(KVCacheManagerTest, KVCacheManagerBatchTest) auto constexpr blocksInSecondaryPool = 0; auto constexpr enableBlockReuse = false; - auto constexpr onboardBlocks = true; auto const homogeneousLayers = GetParam(); auto const expectedNumPools = homogeneousLayers ? 1 : static_cast(expectedHeadsPerPool.size()); @@ -3682,10 +3640,10 @@ TEST_P(KVCacheManagerTest, KVCacheManagerBatchTest) KVCacheManager kvCacheManager = homogeneousLayers ? KVCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, - nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks) + nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, enableBlockReuse) : KVCacheManager(numHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, - sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks); + sinkTokenLength, stream, std::nullopt, enableBlockReuse); kvCacheManager.allocatePools(false); EXPECT_EQ(kvCacheManager.getOffsetTableDimensions().maxBlocksPerSeq, maxBlocksPerSeq); @@ -3812,19 +3770,16 @@ void testNeededBlocksOneStep(bool kv_cache_block_reuse, int beamWidth, int draft auto constexpr maxAttentionWindow = 46; auto constexpr totalNumBlocks = maxNumSequences * maxBlocksPerSeq; auto constexpr blocksInSecondaryPool = 0; - auto constexpr onboardBlocks = true; auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {totalNumBlocks, blocksInSecondaryPool}}}; KVCacheManager kvCacheManager = homogeneousLayers ? KVCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, - nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, kv_cache_block_reuse, - onboardBlocks) + nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, kv_cache_block_reuse) : KVCacheManager(numHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, - nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, kv_cache_block_reuse, - onboardBlocks); + nvinfer1::DataType::kHALF, sinkTokenLength, stream, std::nullopt, kv_cache_block_reuse); kvCacheManager.allocatePools(false); EXPECT_EQ(kvCacheManager.getOffsetTableDimensions().maxBlocksPerSeq, diff --git a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp index 1dad1fa2bbb..80982b47c57 100644 --- a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp +++ b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp @@ -485,7 +485,6 @@ TEST(SerializeUtilsTest, KvCacheConfig) EXPECT_EQ(kvCacheConfig.getSinkTokenLength(), kvCacheConfig2.getSinkTokenLength()); EXPECT_EQ(kvCacheConfig.getFreeGpuMemoryFraction(), kvCacheConfig2.getFreeGpuMemoryFraction()); EXPECT_EQ(kvCacheConfig.getHostCacheSize(), kvCacheConfig2.getHostCacheSize()); - EXPECT_EQ(kvCacheConfig.getOnboardBlocks(), kvCacheConfig2.getOnboardBlocks()); EXPECT_EQ(kvCacheConfig.getCrossKvCacheFraction(), kvCacheConfig2.getCrossKvCacheFraction()); EXPECT_EQ(kvCacheConfig.getSecondaryOffloadMinPriority(), kvCacheConfig2.getSecondaryOffloadMinPriority()); EXPECT_EQ(kvCacheConfig.getEventBufferMaxSize(), kvCacheConfig2.getEventBufferMaxSize()); diff --git a/tests/unittest/bindings/test_bindings_ut.py b/tests/unittest/bindings/test_bindings_ut.py index f049a4437cb..4a63abd172c 100644 --- a/tests/unittest/bindings/test_bindings_ut.py +++ b/tests/unittest/bindings/test_bindings_ut.py @@ -475,8 +475,6 @@ def test_KvCache_events_binding(): max_sequence_length, 'enable_block_reuse': True, - 'onboard_blocks': - False, 'cache_type': _tb.internal.batch_manager.CacheType.SELF, 'event_manager': diff --git a/tests/unittest/bindings/test_executor_bindings.py b/tests/unittest/bindings/test_executor_bindings.py index 8556cf54d69..6b4b60d925d 100644 --- a/tests/unittest/bindings/test_executor_bindings.py +++ b/tests/unittest/bindings/test_executor_bindings.py @@ -1308,7 +1308,6 @@ def test_kv_cache_config(): assert config.free_gpu_memory_fraction is None assert config.cross_kv_cache_fraction is None assert config.host_cache_size is None - assert config.onboard_blocks == True assert config.secondary_offload_min_priority is None assert config.event_buffer_max_size == 0 assert config.enable_partial_reuse == True @@ -1323,7 +1322,6 @@ def test_kv_cache_config(): config.free_gpu_memory_fraction = 0.5 config.cross_kv_cache_fraction = 0.5 config.host_cache_size = 4 - config.onboard_blocks = False config.secondary_offload_min_priority = 50 config.event_buffer_max_size = 1024 config.enable_partial_reuse = False @@ -1337,7 +1335,6 @@ def test_kv_cache_config(): assert config.free_gpu_memory_fraction == 0.5 assert config.cross_kv_cache_fraction == 0.5 assert config.host_cache_size == 4 - assert config.onboard_blocks == False assert config.secondary_offload_min_priority == 50 assert config.event_buffer_max_size == 1024 assert config.enable_partial_reuse == False @@ -1353,7 +1350,6 @@ def test_kv_cache_config(): "free_gpu_memory_fraction": 0.5, "cross_kv_cache_fraction": 0.5, "host_cache_size": 1024, - "onboard_blocks": False, "event_buffer_max_size": 2048, "enable_partial_reuse": True, "copy_on_partial_reuse": False, @@ -2400,7 +2396,6 @@ def test_kv_cache_config_pickle(): config.free_gpu_memory_fraction = 0.3 config.cross_kv_cache_fraction = 0.5 config.host_cache_size = 4 - config.onboard_blocks = False config.secondary_offload_min_priority = 50 config.event_buffer_max_size = 1024 config.enable_partial_reuse = False @@ -2414,7 +2409,6 @@ def test_kv_cache_config_pickle(): assert config.free_gpu_memory_fraction == config_copy.free_gpu_memory_fraction assert config.cross_kv_cache_fraction == config_copy.cross_kv_cache_fraction assert config.host_cache_size == config_copy.host_cache_size - assert config.onboard_blocks == config_copy.onboard_blocks assert config.secondary_offload_min_priority == config_copy.secondary_offload_min_priority assert config.event_buffer_max_size == config_copy.event_buffer_max_size assert config.enable_partial_reuse == config_copy.enable_partial_reuse diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py index a01d7f591f3..b8af08413f2 100644 --- a/tests/unittest/llmapi/test_llm_args.py +++ b/tests/unittest/llmapi/test_llm_args.py @@ -157,7 +157,6 @@ def test_KvCacheConfig_declaration(): sink_token_length=32, free_gpu_memory_fraction=0.5, host_cache_size=1024, - onboard_blocks=True, cross_kv_cache_fraction=0.5, secondary_offload_min_priority=1, event_buffer_max_size=0, @@ -172,7 +171,6 @@ def test_KvCacheConfig_declaration(): assert pybind_config.sink_token_length == 32 assert pybind_config.free_gpu_memory_fraction == 0.5 assert pybind_config.host_cache_size == 1024 - assert pybind_config.onboard_blocks == True assert pybind_config.cross_kv_cache_fraction == 0.5 assert pybind_config.secondary_offload_min_priority == 1 assert pybind_config.event_buffer_max_size == 0 diff --git a/tests/unittest/llmapi/test_llm_kv_cache_events.py b/tests/unittest/llmapi/test_llm_kv_cache_events.py index db90a34413e..590091d29f6 100644 --- a/tests/unittest/llmapi/test_llm_kv_cache_events.py +++ b/tests/unittest/llmapi/test_llm_kv_cache_events.py @@ -21,7 +21,6 @@ global_kvcache_config = KvCacheConfig(free_gpu_memory_fraction=0.4, event_buffer_max_size=1024, enable_block_reuse=True, - onboard_blocks=True, max_tokens=256)