
Commit 45d9760

[Feature] Estimate max-model-len using available KV cache memory
Signed-off-by: rongfu.leng <[email protected]>
1 parent 995e3d1 commit 45d9760

File tree

2 files changed: +106 -5 lines changed


tests/v1/core/test_kv_cache_utils.py

Lines changed: 45 additions & 1 deletion

@@ -3,14 +3,16 @@
 import pytest
 import torch

+from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
-from vllm.utils import sha256
+from vllm.utils import GiB_bytes, sha256
 # disable yapf here as it formats differently than isort such that both fail
 # yapf: disable
 from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
                                          FreeKVCacheBlockQueue, KVCacheBlock,
                                          PrefixCachingMetrics,
+                                         estimate_max_model_len,
                                          generate_block_hash_extra_keys,
                                          hash_block_tokens,
                                          hash_request_tokens,
@@ -426,3 +428,45 @@ def new_kv_cache_spec(block_size=16,
     ]
     with pytest.raises(AssertionError):
         unify_kv_cache_configs(diff_kv_cache_config)
+
+
+@pytest.mark.parametrize(
+    ("model_id", "max_model_len", "want_estimated_max_len"), [
+        ("Qwen/Qwen1.5-7B", 16385, 16384),
+        ("Qwen/Qwen1.5-7B", 16383, 16383),
+    ])
+def test_estimate_max_model_len(model_id, max_model_len,
+                                want_estimated_max_len):
+    # Create a VllmConfig
+    model_config = ModelConfig(
+        model_id,
+        task="generate",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        max_model_len=max_model_len,
+    )
+    scheduler_config = SchedulerConfig(max_num_batched_tokens=32768)
+
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        scheduler_config=scheduler_config,
+    )
+
+    # Create KV cache specs: 32 full-attention layers
+    kv_cache_spec = {}
+    for i in range(32):
+        layer_name = f"layer_{i}"
+        kv_cache_spec[layer_name] = FullAttentionSpec(
+            block_size=16,
+            num_kv_heads=32,
+            head_size=128,
+            dtype=torch.float16,
+            use_mla=False,
+        )
+    # Estimate the maximum model length; a model_len of 16384 needs 8 GiB
+    estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
+                                               8 * GiB_bytes)
+    assert estimated_max_len == want_estimated_max_len
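
For readers checking the parametrized expectations, here is a rough back-of-the-envelope sketch of the per-token KV cache footprint implied by the spec in the test (32 full-attention layers, 32 KV heads, head size 128, float16). It assumes the usual 2 (K and V) * layers * kv_heads * head_size * dtype_size bytes per token and ignores any block-size rounding that FullAttentionSpec.max_memory_usage_bytes may apply, so treat it as an approximation rather than the exact accounting vLLM performs:

    # Approximate K/V bytes per token for the test's spec.
    GiB = 1 << 30
    num_layers, num_kv_heads, head_size, dtype_size = 32, 32, 128, 2  # float16

    bytes_per_token = 2 * num_layers * num_kv_heads * head_size * dtype_size
    assert bytes_per_token == 512 * 1024          # 0.5 MiB per token
    assert 8 * GiB // bytes_per_token == 16384    # ~16384 tokens fit in 8 GiB

Under this approximation a requested max_model_len of 16385 overflows the 8 GiB budget and is clamped to 16384, while 16383 already fits and is returned unchanged, which is exactly what the two test cases assert.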

vllm/v1/core/kv_cache_utils.py

Lines changed: 61 additions & 4 deletions
@@ -8,7 +8,7 @@

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import sha256
+from vllm.utils import GiB_bytes, sha256
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheSpec,
                                         KVCacheTensor, SlidingWindowSpec)
@@ -459,6 +459,54 @@ def hash_request_tokens(hash_function: Any, block_size: int,
     return ret


+def estimate_max_model_len(vllm_config: VllmConfig,
+                           kv_cache_spec: dict[str, KVCacheSpec],
+                           available_memory: int) -> int:
+    """
+    Estimates the maximum model length that can fit in the available memory
+    using binary search.
+
+    Args:
+        vllm_config: The global VllmConfig
+        kv_cache_spec: The kv cache spec of each attention layer in the model
+        available_memory: Memory available for KV cache in bytes.
+
+    Returns:
+        The estimated maximum model length that can fit in the available memory.
+    """
+
+    # Define a function to check if a given model length fits in memory
+    def fits_in_memory(model_len: int) -> bool:
+        # Modify the max_model_len for this calculation
+        vllm_config.model_config.max_model_len = model_len
+        # Calculate memory needed for the given model length
+        memory_needed = sum(
+            (layer_spec.max_memory_usage_bytes(vllm_config)
+             for layer_spec in kv_cache_spec.values()),
+            start=0,
+        )
+        return memory_needed <= available_memory
+
+    # Binary search for the maximum model length
+    current_max = vllm_config.model_config.max_model_len
+    left, right = 1, current_max
+
+    # If even the smallest model length doesn't fit, return 0
+    if not fits_in_memory(left):
+        return 0
+
+    # Binary search for the maximum model length that fits
+    result = 1
+    while left <= right:
+        mid = (left + right) // 2
+        if fits_in_memory(mid):
+            result = mid
+            left = mid + 1
+        else:
+            right = mid - 1
+    return result
+
+
 def check_enough_kv_cache_memory(vllm_config: VllmConfig,
                                  kv_cache_spec: dict[str, KVCacheSpec],
                                  available_memory: int):
@@ -486,12 +534,21 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
         needed_memory += layer_spec.max_memory_usage_bytes(vllm_config)

     if needed_memory > available_memory:
+        # Estimate the maximum model length that can fit in the available memory
+        estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
+                                                   available_memory)
+        estimated_msg = ""
+        if estimated_max_len > 0:
+            estimated_msg = (" Based on the available memory, the estimated "
+                             f"maximum model length is {estimated_max_len}.")
+
         raise ValueError(
             f"To serve at least one request with the model's max seq len "
-            f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GiB KV "
+            f"({max_model_len}), ({needed_memory/GiB_bytes:.2f} GiB KV "
             f"cache is needed, which is larger than the available KV cache "
-            f"memory ({available_memory/1024/1024/1024:.2f} GiB). Try "
-            f"increasing `gpu_memory_utilization` or decreasing "
+            f"memory ({available_memory/GiB_bytes:.2f} GiB)."
+            f"{estimated_msg}"
+            f" Try increasing `gpu_memory_utilization` or decreasing "
             f"`max_model_len` when initializing the engine.")
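
To make the search invariant in estimate_max_model_len easy to verify in isolation, here is a minimal, self-contained sketch of the same pattern: find the largest value in [1, upper] for which a monotone predicate holds. The names largest_fitting_len and fits below are hypothetical stand-ins for the function above and its per-layer max_memory_usage_bytes accounting, using the 0.5 MiB-per-token figure from the test:

    def largest_fitting_len(upper: int, budget_bytes: int,
                            bytes_per_token: int) -> int:
        # Monotone stand-in for the real fits_in_memory() check.
        def fits(model_len: int) -> bool:
            return model_len * bytes_per_token <= budget_bytes

        if not fits(1):
            return 0          # even a single token does not fit
        left, right, result = 1, upper, 1
        while left <= right:
            mid = (left + right) // 2
            if fits(mid):
                result = mid  # mid fits; try something larger
                left = mid + 1
            else:
                right = mid - 1
        return result

    # 8 GiB budget, 0.5 MiB per token (the test's 32-layer float16 spec):
    print(largest_fitting_len(16385, 8 * (1 << 30), 512 * 1024))  # -> 16384

Because the search is capped at the configured max_model_len, the estimate never suggests a length larger than what the user asked for; it only reports how far the available KV cache memory would stretch.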