Skip to content

Commit 299ca54

Browse files
committed
[kv cache manager] No functional change intended, remove onboard block switch
Dead code elimination. The secondary block pool is derived when kv_cache_config::host_cache_size is specified. Whether we onboard/offload a kv cache block can be implicated from whether the manager has secondary block or not. The `onboardBlocks` toggle itself only adds complication. This commit removes it. Signed-off-by: eopXD <[email protected]>
1 parent b3ba3d9 commit 299ca54

File tree

15 files changed

+77
-172
lines changed

15 files changed

+77
-172
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,7 @@ class WindowBlockManager
536536
std::vector<SizeType32> const& managedLayers, std::vector<SizeType32> const& numKvHeadsPerLayer,
537537
SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool,
538538
SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
539-
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
539+
CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
540540
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
541541
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager);
542542

@@ -798,8 +798,6 @@ class WindowBlockManager
798798
// getPoolLayerIdx
799799
std::unordered_map<SizeType32, SizeType32> mLayerToIndexWithinPool;
800800

801-
// Whether offloaded blocks should be onboarded before reuse.
802-
bool mOnboardBlocks;
803801
// Buffer manager
804802
runtime::BufferManager mBufferManager;
805803

@@ -860,7 +858,7 @@ class BlockManager
860858
CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth,
861859
std::vector<SizeType32> const& maxAttentionWindowVec,
862860
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
863-
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF,
861+
SizeType32 sinkBubbleLength, CacheType cacheType = CacheType::kSELF,
864862
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
865863
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
866864
bool copyOnPartialReuse = true,
@@ -1385,7 +1383,7 @@ class KVCacheManager : public BaseKVCacheManager
13851383
std::vector<SizeType32> const& maxAttentionWindowVec,
13861384
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
13871385
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
1388-
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
1386+
bool enableBlockReuse = false, CacheType cacheType = CacheType::kSELF,
13891387
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
13901388
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
13911389
bool copyOnpartialReuse = true,
@@ -1396,7 +1394,7 @@ class KVCacheManager : public BaseKVCacheManager
13961394
std::vector<SizeType32> const& maxAttentionWindowVec,
13971395
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
13981396
SizeType32 sinkTokenLength, int64_t stream, std::optional<SizeType32> maxSequenceLength,
1399-
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
1397+
bool enableBlockReuse = false, CacheType cacheType = CacheType::kSELF,
14001398
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
14011399
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
14021400
bool copyOnpartialReuse = true,
@@ -1407,7 +1405,7 @@ class KVCacheManager : public BaseKVCacheManager
14071405
std::vector<SizeType32> const& maxAttentionWindowVec,
14081406
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
14091407
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
1410-
bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
1408+
bool enableBlockReuse = true, CacheType cacheType = CacheType::kSELF,
14111409
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
14121410
std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
14131411
bool copyOnpartialReuse = true,
@@ -1418,8 +1416,8 @@ class KVCacheManager : public BaseKVCacheManager
14181416
std::vector<SizeType32> const& maxAttentionWindowVec,
14191417
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
14201418
SizeType32 sinkTokenLength, int64_t stream, std::optional<SizeType32> maxSequenceLength,
1421-
bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
1422-
bool enablePartialReuse = true, bool copyOnpartialReuse = true);
1419+
bool enableBlockReuse = false, CacheType cacheType = CacheType::kSELF, bool enablePartialReuse = true,
1420+
bool copyOnpartialReuse = true);
14231421

14241422
~KVCacheManager() override = default;
14251423

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,7 +1001,7 @@ class KvCacheConfig
10011001
std::optional<std::vector<SizeType32>> const& maxAttentionWindowVec = std::nullopt,
10021002
std::optional<SizeType32> const& sinkTokenLength = std::nullopt,
10031003
std::optional<FloatType> const& freeGpuMemoryFraction = std::nullopt,
1004-
std::optional<size_t> const& hostCacheSize = std::nullopt, bool onboardBlocks = true,
1004+
std::optional<size_t> const& hostCacheSize = std::nullopt,
10051005
std::optional<FloatType> const& crossKvCacheFraction = std::nullopt,
10061006
std::optional<RetentionPriority> secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0,
10071007
bool enablePartialReuse = true, bool copyOnPartialReuse = true, bool useUvm = false,
@@ -1018,7 +1018,6 @@ class KvCacheConfig
10181018
[[nodiscard]] std::optional<FloatType> getFreeGpuMemoryFraction() const;
10191019
[[nodiscard]] std::optional<FloatType> getCrossKvCacheFraction() const;
10201020
[[nodiscard]] std::optional<size_t> getHostCacheSize() const;
1021-
[[nodiscard]] bool getOnboardBlocks() const;
10221021
[[nodiscard]] std::optional<RetentionPriority> getSecondaryOffloadMinPriority() const;
10231022
[[nodiscard]] size_t getEventBufferMaxSize() const;
10241023
[[nodiscard]] bool getUseUvm() const;
@@ -1034,7 +1033,6 @@ class KvCacheConfig
10341033
void setFreeGpuMemoryFraction(FloatType freeGpuMemoryFraction);
10351034
void setCrossKvCacheFraction(FloatType crossKvCacheFraction);
10361035
void setHostCacheSize(size_t hostCacheSize);
1037-
void setOnboardBlocks(bool onboardBlocks);
10381036
void setSecondaryOffloadMinPriority(std::optional<RetentionPriority> secondaryOffloadMinPriority);
10391037
void setEventBufferMaxSize(size_t eventBufferMaxSize);
10401038
void setUseUvm(bool useUvm);
@@ -1078,9 +1076,6 @@ class KvCacheConfig
10781076
/// Having a secondary memory pool increases KV cache block reuse potential.
10791077
std::optional<size_t> mHostCacheSize;
10801078

1081-
/// @brief Controls whether offloaded blocks should be onboarded back into primary memory before being reused.
1082-
bool mOnboardBlocks;
1083-
10841079
/// @brief Only blocks with priority > mSecondaryOfflineMinPriority can be offloaded to secondary memory.
10851080
std::optional<RetentionPriority> mSecondaryOffloadMinPriority;
10861081

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
502502
std::shared_ptr<runtime::CudaStream> stream, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth,
503503
std::vector<SizeType32> const& maxAttentionWindowVec,
504504
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
505-
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType,
505+
SizeType32 sinkBubbleLength, CacheType cacheType,
506506
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
507507
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
508508
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
@@ -534,8 +534,8 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
534534
TLLM_CHECK(allottedPrimaryBlocks > 0); // You can't have a model with negative primary blocks...
535535
mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer,
536536
sizePerHead, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream,
537-
onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse,
538-
copyOnPartialReuse, kvCacheConnectorManager);
537+
cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse, copyOnPartialReuse,
538+
kvCacheConnectorManager);
539539
}
540540

541541
auto const numAllPools = getNumPools();
@@ -575,15 +575,14 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
575575
WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 windowSize,
576576
std::vector<SizeType32> const& managedLayers, std::vector<SizeType32> const& numKvHeadsPerLayer,
577577
SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
578-
SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks, CacheType cacheType,
578+
SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, CacheType cacheType,
579579
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
580580
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
581581
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
582582
: mDataType{dtype}
583583
, mWindowSize{windowSize}
584584
, mNumPrimaryBlocks{blocksInPrimaryPool}
585585
, mNumSecondaryBlocks{blocksInSecondaryPool}
586-
, mOnboardBlocks(onboardBlocks)
587586
, mBufferManager{std::move(stream)}
588587
, mSchedulingNumFreeBlocks{0}
589588
, mTokensPerBlock{tokensPerBlock}
@@ -869,9 +868,7 @@ BlockPtr WindowBlockManager::getFreeBlock(
869868
// 1. Block contains state (evidenced by presence of tokens)
870869
// 2. Eviction policy indicated block can be offloaded
871870
// 3. At least one free block in secondary memory
872-
// 4. Onboarding is enabled (allowing block to be brought back into primary)
873-
if (!block->getUniqueTokens().empty() && canOffload && mEvictionPolicy->getNumFreeBlocks(kSecondaryLevel) > 0
874-
&& mOnboardBlocks)
871+
if (!block->getUniqueTokens().empty() && canOffload && mEvictionPolicy->getNumFreeBlocks(kSecondaryLevel) > 0)
875872
{
876873
// If we're swapping a block to secondary memory, maintain the prior priority values.
877874
mEvictionPolicy->claimBlock(block);
@@ -936,7 +933,7 @@ void BlockManager::onboardBlock(BlockPtr const& offloadBlock, SizeType32 windowS
936933

937934
void WindowBlockManager::onboardBlock(BlockPtr const& offloadBlock)
938935
{
939-
if (mOnboardBlocks && !offloadBlock->isPrimary())
936+
if (!offloadBlock->isPrimary())
940937
{
941938
auto block = getFreeBlock();
942939
mTransferManager->onboard(offloadBlock, block, mPools);
@@ -961,7 +958,7 @@ void BlockManager::offloadBlock(BlockPtr const& block, SizeType32 windowSize)
961958

962959
void WindowBlockManager::offloadBlock(BlockPtr const& block)
963960
{
964-
if (mOnboardBlocks && block->isPrimary())
961+
if (block->isPrimary())
965962
{
966963
// Offload block in primary memory before repurposing
967964
auto offloadBlock = std::get<0>(mEvictionPolicy->getFreeBlock(kSecondaryLevel));
@@ -1631,11 +1628,11 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size
16311628
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
16321629
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
16331630
SizeType32 sinkTokenLength, int64_t stream, std::optional<runtime::SizeType32> maxSequenceLength,
1634-
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, bool enablePartialReuse, bool copyOnPartialReuse)
1631+
bool enableBlockReuse, CacheType cacheType, bool enablePartialReuse, bool copyOnPartialReuse)
16351632
: KVCacheManager(std::vector<SizeType32>(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
16361633
maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
16371634
std::make_shared<runtime::CudaStream>(reinterpret_cast<cudaStream_t>(stream)), maxSequenceLength,
1638-
enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, enablePartialReuse, copyOnPartialReuse)
1635+
enableBlockReuse, cacheType, std::nullopt, nullptr, enablePartialReuse, copyOnPartialReuse)
16391636
{
16401637
}
16411638

@@ -1644,15 +1641,14 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
16441641
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
16451642
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
16461643
SizeType32 sinkTokenLength, int64_t stream, std::optional<runtime::SizeType32> maxSequenceLength,
1647-
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
1648-
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
1644+
bool enableBlockReuse, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
16491645
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
16501646
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
16511647
: KVCacheManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth,
16521648
maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
16531649
std::make_shared<runtime::CudaStream>(reinterpret_cast<cudaStream_t>(stream)), maxSequenceLength,
1654-
enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, eventManager, enablePartialReuse,
1655-
copyOnPartialReuse, kvCacheConnectorManager)
1650+
enableBlockReuse, cacheType, secondaryOffloadMinPriority, eventManager, enablePartialReuse, copyOnPartialReuse,
1651+
kvCacheConnectorManager)
16561652
{
16571653
}
16581654

@@ -1661,8 +1657,7 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
16611657
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
16621658
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
16631659
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
1664-
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
1665-
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
1660+
bool enableBlockReuse, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
16661661
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
16671662
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
16681663
: mMaxBeamWidth(maxBeamWidth)
@@ -1673,8 +1668,8 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
16731668
, mSinkBlockTokenLength(mSinkBubbleLength + sinkTokenLength)
16741669
, mBlockManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
16751670
std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype,
1676-
mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager),
1677-
enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager))
1671+
mSinkBubbleLength, cacheType, secondaryOffloadMinPriority, std::move(eventManager), enablePartialReuse,
1672+
copyOnPartialReuse, std::move(kvCacheConnectorManager))
16781673
// disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case
16791674
, mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse}
16801675
{
@@ -1696,13 +1691,12 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size
16961691
SizeType32 maxBeamWidth, std::vector<SizeType32> const& maxAttentionWindowVec,
16971692
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
16981693
SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
1699-
bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
1700-
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
1694+
bool enableBlockReuse, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
17011695
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
17021696
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
17031697
: KVCacheManager(std::vector<SizeType32>(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
17041698
maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
1705-
std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority,
1699+
std::move(stream), maxSequenceLength, enableBlockReuse, cacheType, secondaryOffloadMinPriority,
17061700
std::move(eventManager), enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager))
17071701
{
17081702
}
@@ -2272,9 +2266,7 @@ BlocksPerWindow BaseKVCacheManager::calculateMaxNumBlocks(executor::KvCacheConfi
22722266
= static_cast<SizeType32>(allottedSecondaryMemBytes * windowSizeShare / cacheSizeBytesPerToken);
22732267
SizeType32 const blocksInSecondaryPool = std::max(0, maxTokensSecondary / tokensPerBlock);
22742268
TLLM_LOG_DEBUG(
2275-
"Number of blocks in KV cache secondary pool for windowSize %d: %d, onboard blocks to primary memory "
2276-
"before reuse: %s",
2277-
windowSize, blocksInSecondaryPool, config.getOnboardBlocks() ? "true" : "false");
2269+
"Number of blocks in KV cache secondary pool for windowSize %d: %d", windowSize, blocksInSecondaryPool);
22782270
return blocksInSecondaryPool;
22792271
};
22802272

cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -182,12 +182,6 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer
182182

183183
auto const& kvCacheConfig = executorConfig.getKvCacheConfig();
184184

185-
if (!kvCacheConfig.getOnboardBlocks())
186-
{
187-
TLLM_CHECK_WITH_INFO(
188-
!mModelConfig.getPagedContextFMHA(), "KV cache blocks need to be onboarded if context FMHA.");
189-
}
190-
191185
if (mModelConfig.getSpeculativeDecodingMode().isDraftTokensExternal())
192186
{
193187
TLLM_CHECK_WITH_INFO(kvCacheConfig.getEnableBlockReuse(),
@@ -688,8 +682,8 @@ std::unique_ptr<kv_cache_manager::KVCacheManager> TrtGptModelInflightBatching::c
688682

689683
auto kvCacheManager = std::make_unique<KVCacheManager>(numKvHeadsPerLayer, sizePerHead, tokensPerBlock,
690684
blocksPerWindow, getMaxNumSequences(), getMaxBeamWidth(), maxAttentionWindowVec, tempAttentionWindowInputs,
691-
kvDtype, getSinkTokenLen(), mRuntime->getStreamPtr(), std::nullopt, enableBlockReuse,
692-
kvCacheConfig.getOnboardBlocks(), kvCacheType, kvCacheConfig.getSecondaryOffloadMinPriority(),
685+
kvDtype, getSinkTokenLen(), mRuntime->getStreamPtr(), std::nullopt, enableBlockReuse, kvCacheType,
686+
kvCacheConfig.getSecondaryOffloadMinPriority(),
693687
kvCacheConfig.getEventBufferMaxSize() > 0
694688
? std::make_unique<kv_cache_manager::KVCacheEventManager>(kvCacheConfig.getEventBufferMaxSize())
695689
: nullptr,

0 commit comments

Comments
 (0)