Commit e658165

feat: Add max_free_gpu_memory_size support for KV cache configuration
- Introduced max_free_gpu_memory_size to manage GPU memory allocation for the KV cache.
- Updated KvCacheConfig and related methods to handle the new parameter.
- Modified estimation logic in KvCacheCreator to utilize max_free_gpu_memory_size for VSWA cases.
- Adjusted resource management to ensure compatibility with the new memory allocation strategy.

Signed-off-by: qixiang-99 <[email protected]>
1 parent: c191b38
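For orientation before the per-file diffs, here is a minimal usage sketch of the new knob through the Python LLM API. The `LLM` entry point and the model path are illustrative assumptions; only the `KvCacheConfig` fields come from this commit.

```python
from tensorrt_llm import LLM  # assumed entry point; adjust to your install
from tensorrt_llm.llmapi import KvCacheConfig

# Cap the KV cache at 8 GiB for a VSWA (variable sliding window attention) model.
# If free_gpu_memory_fraction is also set, the smaller resulting budget wins.
kv_cache_config = KvCacheConfig(
    free_gpu_memory_fraction=0.9,
    max_free_gpu_memory_size=8 * 1024**3,  # bytes
)

llm = LLM(model="<path-to-vswa-model>", kv_cache_config=kv_cache_config)
```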

File tree: 7 files changed, +71 −16 lines

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 10 additions & 1 deletion
```diff
@@ -1001,7 +1001,8 @@ class KvCacheConfig
         std::optional<FloatType> const& crossKvCacheFraction = std::nullopt,
         std::optional<RetentionPriority> secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0,
         bool enablePartialReuse = true, bool copyOnPartialReuse = true, bool useUvm = false,
-        std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults = std::nullopt);
+        std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults = std::nullopt,
+        std::optional<uint64_t> const& maxFreeGpuMemorySize = std::nullopt);
 
     [[nodiscard]] bool getEnableBlockReuse() const;
     [[nodiscard]] bool getEnablePartialReuse() const;
@@ -1016,6 +1017,7 @@ class KvCacheConfig
     [[nodiscard]] std::optional<RetentionPriority> getSecondaryOffloadMinPriority() const;
     [[nodiscard]] size_t getEventBufferMaxSize() const;
     [[nodiscard]] bool getUseUvm() const;
+    [[nodiscard]] std::optional<uint64_t> getMaxFreeGpuMemorySize() const;
 
     void setEnableBlockReuse(bool enableBlockReuse);
     void setEnablePartialReuse(bool enablePartialReuse);
@@ -1030,6 +1032,7 @@ class KvCacheConfig
     void setSecondaryOffloadMinPriority(std::optional<RetentionPriority> secondaryOffloadMinPriority);
     void setEventBufferMaxSize(size_t eventBufferMaxSize);
     void setUseUvm(bool useUvm);
+    void setMaxFreeGpuMemorySize(uint64_t maxFreeGpuMemorySize);
 
     void fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults);
 
@@ -1085,6 +1088,12 @@ class KvCacheConfig
 
     /// @brief Whether to use UVM for the KV cache.
     bool mUseUvm;
+
+    /// @brief The maximum size in bytes of GPU memory that can be allocated for the KV cache.
+    /// This is only used for the VSWA case for now, as an alternative to mMaxTokens.
+    /// If both mMaxFreeGpuMemorySize and mFreeGpuMemoryFraction are specified, memory corresponding to the minimum
+    /// will be allocated.
+    std::optional<uint64_t> mMaxFreeGpuMemorySize;
 };
 
 /// @brief Configuration class for the runtime perf knobs
```
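The new doc comment fixes the composition rule: when both `mMaxFreeGpuMemorySize` and `mFreeGpuMemoryFraction` are set, the smaller resulting allocation wins. A toy sketch of that rule in Python (illustrative arithmetic only, not the executor's actual code path):

```python
from typing import Optional

def effective_kv_cache_bytes(free_mem: int,
                             free_gpu_memory_fraction: Optional[float],
                             max_free_gpu_memory_size: Optional[int]) -> int:
    """Return the KV cache budget under the documented minimum-of-both rule."""
    candidates = [free_mem]
    if free_gpu_memory_fraction is not None:
        candidates.append(int(free_mem * free_gpu_memory_fraction))
    if max_free_gpu_memory_size is not None:
        candidates.append(max_free_gpu_memory_size)
    return min(candidates)

# 40 GiB free: fraction 0.9 would allow 36 GiB, but the 8 GiB cap is smaller.
assert effective_kv_cache_bytes(40 * 1024**3, 0.9, 8 * 1024**3) == 8 * 1024**3
```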

cpp/tensorrt_llm/executor/kvCacheConfig.cpp

Lines changed: 16 additions & 1 deletion
```diff
@@ -27,7 +27,8 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional<SizeType32> co
     std::optional<size_t> const& hostCacheSize, bool onboardBlocks,
     std::optional<FloatType> const& crossKvCacheFraction, std::optional<RetentionPriority> secondaryOffloadMinPriority,
     size_t eventBufferMaxSize, bool enablePartialReuse, bool copyOnPartialReuse, bool useUvm,
-    std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults)
+    std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults,
+    std::optional<uint64_t> const& maxFreeGpuMemorySize)
     : mEnableBlockReuse(enableBlockReuse)
     , mHostCacheSize(hostCacheSize)
     , mOnboardBlocks(onboardBlocks)
@@ -61,6 +62,10 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional<SizeType32> co
     {
         fillEmptyFieldsFromRuntimeDefaults(runtimeDefaults.value());
     }
+    if (maxFreeGpuMemorySize)
+    {
+        setMaxFreeGpuMemorySize(maxFreeGpuMemorySize.value());
+    }
 }
 
 bool KvCacheConfig::getEnableBlockReuse() const
@@ -128,6 +133,11 @@ bool KvCacheConfig::getUseUvm() const
     return mUseUvm;
 }
 
+std::optional<uint64_t> KvCacheConfig::getMaxFreeGpuMemorySize() const
+{
+    return mMaxFreeGpuMemorySize;
+}
+
 void KvCacheConfig::setEnableBlockReuse(bool enableBlockReuse)
 {
     mEnableBlockReuse = enableBlockReuse;
@@ -207,6 +217,11 @@ void KvCacheConfig::setUseUvm(bool useUvm)
     mUseUvm = useUvm;
 }
 
+void KvCacheConfig::setMaxFreeGpuMemorySize(uint64_t maxFreeGpuMemorySize)
+{
+    mMaxFreeGpuMemorySize = maxFreeGpuMemorySize;
+}
+
 void KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults)
 {
     if (!mMaxAttentionWindowVec && runtimeDefaults.maxAttentionWindowVec)
```

cpp/tensorrt_llm/pybind/executor/executorConfig.cpp

Lines changed: 4 additions & 2 deletions
```diff
@@ -121,14 +121,14 @@ void initConfigBindings(pybind11::module_& m)
         .def(py::init<bool, std::optional<SizeType32> const&, std::optional<std::vector<SizeType32>> const&,
                  std::optional<SizeType32> const&, std::optional<float> const&, std::optional<size_t> const&, bool,
                  std::optional<float> const&, std::optional<tle::RetentionPriority>, size_t const&, bool, bool, bool,
-                 std::optional<RuntimeDefaults> const&>(),
+                 std::optional<RuntimeDefaults> const&, std::optional<uint64_t> const&>(),
             py::arg("enable_block_reuse") = true, py::arg("max_tokens") = py::none(),
             py::arg("max_attention_window") = py::none(), py::arg("sink_token_length") = py::none(),
             py::arg("free_gpu_memory_fraction") = py::none(), py::arg("host_cache_size") = py::none(),
             py::arg("onboard_blocks") = true, py::arg("cross_kv_cache_fraction") = py::none(),
             py::arg("secondary_offload_min_priority") = py::none(), py::arg("event_buffer_max_size") = 0, py::kw_only(),
             py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true, py::arg("use_uvm") = false,
-            py::arg("runtime_defaults") = py::none())
+            py::arg("runtime_defaults") = py::none(), py::arg("max_free_gpu_memory_size") = py::none())
         .def_property(
             "enable_block_reuse", &tle::KvCacheConfig::getEnableBlockReuse, &tle::KvCacheConfig::setEnableBlockReuse)
         .def_property("max_tokens", &tle::KvCacheConfig::getMaxTokens, &tle::KvCacheConfig::setMaxTokens)
@@ -138,6 +138,8 @@ void initConfigBindings(pybind11::module_& m)
             "sink_token_length", &tle::KvCacheConfig::getSinkTokenLength, &tle::KvCacheConfig::setSinkTokenLength)
         .def_property("free_gpu_memory_fraction", &tle::KvCacheConfig::getFreeGpuMemoryFraction,
             &tle::KvCacheConfig::setFreeGpuMemoryFraction)
+        .def_property("max_free_gpu_memory_size", &tle::KvCacheConfig::getMaxFreeGpuMemorySize,
+            &tle::KvCacheConfig::setMaxFreeGpuMemorySize)
         .def_property("host_cache_size", &tle::KvCacheConfig::getHostCacheSize, &tle::KvCacheConfig::setHostCacheSize)
         .def_property("onboard_blocks", &tle::KvCacheConfig::getOnboardBlocks, &tle::KvCacheConfig::setOnboardBlocks)
         .def_property("cross_kv_cache_fraction", &tle::KvCacheConfig::getCrossKvCacheFraction,
```

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 26 additions & 7 deletions
```diff
@@ -94,8 +94,13 @@ def _get_free_gpu_memory_fraction(self) -> float:
             fraction = 0.9
         return fraction
 
-    def _cal_max_tokens(self, peak_memory, total_gpu_memory, fraction,
-                        alloc_kv_tokens: int) -> int:
+    def _cal_max_tokens_and_memory(self, peak_memory, total_gpu_memory,
+                                   fraction,
+                                   alloc_kv_tokens: int) -> tuple[int, int]:
+        """
+        Calculate the max KV cache capacity as both max_tokens and max_free_gpu_memory_size.
+        For the VSWA case, we use max_free_gpu_memory_size instead of max_tokens.
+        """
         model_config = self._model_engine.model.model_config
         mapping = self._mapping
         kv_size_per_token = self._get_cache_size_per_token(
@@ -115,7 +120,7 @@ def _cal_max_tokens(self, peak_memory, total_gpu_memory, fraction,
         )
         max_tokens = int((available_kv_mem) // kv_size_per_token)
         max_tokens = max(max_tokens, 0)
-        return max_tokens
+        return max_tokens, int(available_kv_mem)
 
     def _create_dummy_context_requests(
             self, input_seq_len: int) -> List[trtllm.Request]:
@@ -185,8 +190,9 @@ def try_prepare_estimation(self) -> bool:
         )
         return estimating_kv_cache
 
-    def estimate_max_tokens(self, py_executor: PyExecutor) -> None:
+    def estimate_max_tokens_or_memory(self, py_executor: PyExecutor) -> None:
         """Perform KV cache capacity estimation.
+        NOTE: for the VSWA case, we calculate and set KV cache memory instead of max_tokens in kv_cache_config.
 
         This updates `kv_cache_config`.
         """
@@ -255,16 +261,29 @@ def estimate_max_tokens(self, py_executor: PyExecutor) -> None:
         kv_stats = py_executor.resource_manager.resource_managers.get(
             ResourceManagerType.KV_CACHE_MANAGER).get_kv_cache_stats()
 
-        kv_cache_max_tokens = self._cal_max_tokens(
+        kv_cache_max_tokens, kv_cache_max_memory = self._cal_max_tokens_and_memory(
             peak_memory, total_gpu_memory, fraction,
             kv_stats.max_num_blocks * kv_stats.tokens_per_block)
 
         if self._max_kv_tokens_in is not None:
             kv_cache_max_tokens = min(kv_cache_max_tokens,
                                       self._max_kv_tokens_in)
 
-        logger.info(f"Estimated max tokens in KV cache : {kv_cache_max_tokens}")
-        executor_config.kv_cache_config.max_tokens = kv_cache_max_tokens
+        if executor_config.kv_cache_config.max_attention_window is not None:
+            # NOTE: for the VSWA case, calculate and set KV cache memory instead of max_tokens in kv_cache_config.
+            assert kv_cache_max_memory is not None, "kv_cache_max_memory should not be None for VSWA case"
+            executor_config.kv_cache_config.max_free_gpu_memory_size = int(
+                kv_cache_max_memory)
+            logger.debug(
+                "For the VSWA case, max_free_gpu_memory_size is set instead of max_tokens in kv_cache_config."
+            )
+            logger.info(
+                f"Estimated max memory in KV cache : {kv_cache_max_memory / (GB):.2f} GiB"
+            )
+        else:
+            logger.info(
+                f"Estimated max tokens in KV cache : {kv_cache_max_tokens}")
+            executor_config.kv_cache_config.max_tokens = kv_cache_max_tokens
 
     def _create_kv_cache_manager(
             self, model_engine: PyTorchModelEngine) -> KVCacheManager:
```
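The estimator now returns the same budget in two units, tokens and bytes, and the final step picks one per model type. A condensed, standalone sketch of that decision (stand-in values; the real method derives the budget from profiling):

```python
GB = 1024**3  # assumed to match the module's GB constant

def apply_estimate(kv_cache_config, available_kv_mem: int,
                   kv_size_per_token: int) -> None:
    """The token and byte budgets are two views of available_kv_mem; VSWA
    models (those with max_attention_window set) receive the byte form."""
    if kv_cache_config.max_attention_window is not None:
        kv_cache_config.max_free_gpu_memory_size = int(available_kv_mem)
    else:
        kv_cache_config.max_tokens = max(
            int(available_kv_mem // kv_size_per_token), 0)

# Worked example: an 8 GiB budget at 128 KiB per token caps a non-VSWA model
# at 8 * GB // (128 * 1024) == 65536 tokens.
```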

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -393,7 +393,7 @@ def create_py_executor(
         assert kv_cache_creator is not None
         with mem_monitor.observe_creation_stage(
                 _ExecutorCreationStage.MODEL_EXTRA):
-            kv_cache_creator.estimate_max_tokens(py_executor)
+            kv_cache_creator.estimate_max_tokens_or_memory(py_executor)
         kv_cache_creator.teardown_managers(resources)
         del py_executor  # free before constructing new
```

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -207,8 +207,6 @@ def __init__(
             kv_cache_config, KvCacheConfigCpp
         ), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfigCpp"
 
-        # overwrite max_tokens in VSWA case
-        kv_cache_config.max_tokens = None
         blocks_per_window = self.calculate_max_num_blocks_from_cpp(
             kv_cache_config=kv_cache_config,
             model_config=model_config,
@@ -636,7 +634,13 @@ def calculate_max_num_blocks_from_cpp(
         logger.debug(f"window_size_to_layers: {window_size_to_layers}")
 
         free_mem, total_mem = torch.cuda.mem_get_info()
-        primary_pool_memory_bytes = int(free_mem * 0.9)
+        primary_pool_memory_bytes = int(free_mem)
+        if kv_cache_config.max_free_gpu_memory_size is not None:
+            # overwrite max_tokens in VSWA case, use max_free_gpu_memory_size instead
+            kv_cache_config.max_tokens = None
+            primary_pool_memory_bytes = min(
+                kv_cache_config.max_free_gpu_memory_size,
+                primary_pool_memory_bytes)
         secondary_pool_memory_bytes = 0
         logger.debug(
             f"primary_pool_memory_bytes is set to {primary_pool_memory_bytes/1024**3}GB, \nsecondary_pool_memory_bytes is set to {secondary_pool_memory_bytes/1024**3}GB"
```

tensorrt_llm/llmapi/llm_args.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -796,6 +796,11 @@ class KvCacheConfig(BaseModel, PybindMirror):
     )
     use_uvm: bool = Field(default=False,
                           description="Whether to use UVM for the KV cache.")
+    max_free_gpu_memory_size: Optional[int] = Field(
+        default=None,
+        description=
+        "The maximum size in bytes of GPU memory that can be allocated for the KV cache. This is only used for the VSWA case for now, as an alternative to `max_tokens`. If both `max_free_gpu_memory_size` and `free_gpu_memory_fraction` are specified, memory corresponding to the minimum will be allocated."
+    )
 
     def _to_pybind(self):
         return _KvCacheConfig(
@@ -811,7 +816,8 @@ def _to_pybind(self):
             event_buffer_max_size=self.event_buffer_max_size,
             enable_partial_reuse=self.enable_partial_reuse,
             copy_on_partial_reuse=self.copy_on_partial_reuse,
-            use_uvm=self.use_uvm)
+            use_uvm=self.use_uvm,
+            max_free_gpu_memory_size=self.max_free_gpu_memory_size)
 
 
 @PybindMirror.mirror_pybind_fields(_ExtendedRuntimePerfKnobConfig)
```
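The pydantic field mirrors straight onto the pybind config via a one-line pass-through. A quick round-trip check, assuming `KvCacheConfig` is exported from `tensorrt_llm.llmapi` (`_to_pybind` is internal, used here only to show the mirroring):

```python
from tensorrt_llm.llmapi import KvCacheConfig  # assumed export location

cfg = KvCacheConfig(max_free_gpu_memory_size=8 * 1024**3)
pybind_cfg = cfg._to_pybind()  # internal: copies each field onto the C++ config
assert pybind_cfg.max_free_gpu_memory_size == cfg.max_free_gpu_memory_size
```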
