From 0827ca87e3effab46b5ffa2d3a43cd20c4c9c19c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 11 Jan 2025 05:38:08 -0800 Subject: [PATCH 01/20] can run Signed-off-by: Chen Zhang --- vllm/attention/layer.py | 2 + vllm/v1/core/kv_cache_utils.py | 75 ++++++++++++++++++++++++++ vllm/v1/engine/core.py | 22 ++++---- vllm/v1/executor/abstract.py | 11 ++-- vllm/v1/executor/multiproc_executor.py | 24 ++++----- vllm/v1/executor/ray_executor.py | 40 +++++++------- vllm/v1/executor/uniproc_executor.py | 20 +++---- vllm/v1/kv_cache_interface.py | 66 +++++++++++++++++++++++ vllm/v1/utils.py | 12 +++++ vllm/v1/worker/gpu_model_runner.py | 69 +++++++++++++++++++----- vllm/v1/worker/gpu_worker.py | 39 ++++---------- 11 files changed, 281 insertions(+), 99 deletions(-) create mode 100644 vllm/v1/kv_cache_interface.py diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 55e4e14027f7..2ac1eb5180ac 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -101,7 +101,9 @@ def __init__( self.num_heads = num_heads self.head_size = head_size self.num_kv_heads = num_kv_heads + self.sliding_window = sliding_window self.backend = backend_name_to_enum(attn_backend.get_name()) + self.dtype = dtype # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how # torch.compile works by registering the attention as one giant diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 22a5d2fb08a4..2366afc36f36 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -3,7 +3,10 @@ from dataclasses import dataclass from typing import Any, List, NamedTuple, Optional, Tuple +from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec, + KVCacheTensor) from vllm.v1.request import Request logger = init_logger(__name__) @@ -305,3 +308,75 @@ def hash_request_tokens(block_size: int, ret.append(block_hash) parent_block_hash_value = block_hash.hash_value return ret + + +def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, + available_memory: int) -> Tuple[KVCacheConfig, int]: + check_enough_memory(vllm_config, kv_cache_spec, available_memory) + if is_same_key(kv_cache_spec): + # kv_cache of all layers are the same + return _get_kv_cache_config_same_key(vllm_config, kv_cache_spec, + available_memory) + else: + raise NotImplementedError + + +def _get_kv_cache_config_same_key( + vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, + available_memory: int) -> Tuple[KVCacheConfig, int]: + page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} + assert len(page_sizes) == 1 + page_size = page_sizes.pop() + + num_gpu_blocks = int(available_memory // page_size // len(kv_cache_spec)) + num_gpu_blocks = max(num_gpu_blocks, 0) + + logger.info("# GPU blocks: %d", num_gpu_blocks) + + if vllm_config.cache_config.num_gpu_blocks_override is not None: + num_gpu_blocks_override = \ + vllm_config.cache_config.num_gpu_blocks_override + logger.info( + "Overriding num_gpu_blocks=%d with " + "num_gpu_blocks_override=%d", num_gpu_blocks, + num_gpu_blocks_override) + num_gpu_blocks = num_gpu_blocks_override + + per_layer_size = page_size * num_gpu_blocks + + kv_cache_config = KVCacheConfig( + tensors={ + layer_name: KVCacheTensor(size=per_layer_size) + for layer_name in kv_cache_spec + }, + groups=[[layer_name for layer_name in kv_cache_spec]], + kv_cache_spec=kv_cache_spec) + return kv_cache_config, num_gpu_blocks + + +def is_same_key(kv_cache_spec: 
KVCacheSpec) -> bool: + layer_keys = set(layer.key for layer in kv_cache_spec.values()) + return len(layer_keys) == 1 + + +def check_enough_memory(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, + available_memory: int): + if available_memory <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + + max_model_len = vllm_config.model_config.max_model_len + needed_memory = 0 + for layer_spec in kv_cache_spec.values(): + needed_memory += layer_spec.bytes_for_tokens(max_model_len) + + if needed_memory > available_memory: + # TODO(Chen): need unit test + raise ValueError( + f"To serve at least one request with the models's max seq len " + f"({max_model_len}), ({needed_memory/1024/1024/1024} GB KV cache is" + f"needed, which is larger than the available KV Cache memory " + f"({available_memory/1024/1024/1024} GB). Try increasing " + f"`gpu_memory_utilization` or decreasing `max_model_len` when " + f"initializing the engine.") diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 975ce11fe8af..19a7c560551a 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -11,11 +11,12 @@ import zmq.asyncio from msgspec import msgpack -from vllm.config import CacheConfig, VllmConfig +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) from vllm.utils import get_exception_traceback, zmq_socket_ctx +from vllm.v1.core.kv_cache_utils import get_kv_cache_config from vllm.v1.core.scheduler import Scheduler from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, EngineCoreProfile, EngineCoreRequest, @@ -53,7 +54,7 @@ def __init__( # Setup KV Caches and update CacheConfig after profiling. 
num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches( - vllm_config.cache_config) + vllm_config) vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks @@ -68,21 +69,16 @@ def __init__( vllm_config.model_config) def _initialize_kv_caches(self, - cache_config: CacheConfig) -> Tuple[int, int]: + vllm_config: VllmConfig) -> Tuple[int, int]: start = time.time() - num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks( - ) + kv_cache_spec = self.model_executor.get_kv_cache_spec() + availble_gpu_memory = self.model_executor.get_available_memory() - if cache_config.num_gpu_blocks_override is not None: - num_gpu_blocks_override = cache_config.num_gpu_blocks_override - logger.info( - "Overriding num_gpu_blocks=%d with " - "num_gpu_blocks_override=%d", num_gpu_blocks, - num_gpu_blocks_override) - num_gpu_blocks = num_gpu_blocks_override + kv_cache_config, num_gpu_blocks = get_kv_cache_config( + vllm_config, kv_cache_spec, availble_gpu_memory) num_cpu_blocks = 0 - self.model_executor.initialize(num_gpu_blocks) + self.model_executor.initialize(kv_cache_config) elapsed = time.time() - start logger.info(("init engine (profile, create kv cache, " "warmup model) took %.2f seconds"), elapsed) diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 5d74d4b01f50..2248f629cdc6 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod -from typing import Tuple, Type +from typing import Type from vllm.config import VllmConfig +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput @@ -30,11 +31,15 @@ def __init__(self, vllm_config: VllmConfig) -> None: raise NotImplementedError @abstractmethod - def initialize(self, num_gpu_blocks: int) -> None: + def initialize(self, kv_cache_config: KVCacheConfig) -> None: raise NotImplementedError @abstractmethod - def determine_num_available_blocks(self) -> Tuple[int, int]: + def get_available_memory(self) -> int: # in bytes + raise NotImplementedError + + @abstractmethod + def get_kv_cache_spec(self) -> KVCacheSpec: raise NotImplementedError @abstractmethod diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 41e6abbd6795..7efdd78b16c0 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -23,6 +23,7 @@ from vllm.utils import (get_distributed_init_method, get_mp_context, get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx) from vllm.v1.executor.abstract import Executor +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.worker.worker_base import WorkerWrapperBase @@ -90,29 +91,26 @@ def sigusr1_handler(signum, frame): for w in self.workers: w.worker_response_mq.wait_until_ready() - def initialize(self, num_gpu_blocks: int) -> None: + def initialize(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the KV caches and begin the model execution loop of the underlying workers. """ - logger.info("# GPU blocks: %d", num_gpu_blocks) - self.collective_rpc("initialize_cache", args=(num_gpu_blocks, )) + self.collective_rpc("initialize_cache", args=(kv_cache_config, )) self.collective_rpc("compile_or_warm_up_model") - def determine_num_available_blocks(self) -> Tuple[int, int]: - """ - Determine the number of available KV blocks by invoking the - underlying worker. 
- """ - num_blocks = self.collective_rpc("determine_num_available_blocks") + def get_available_memory(self) -> int: + memory_sizes = self.collective_rpc("get_available_memory") # Since we use a shared centralized controller, we take the minimum - # number of blocks across all workers to make sure all the memory + # memory size across all workers to make sure all the memory # operators can be applied to all workers. - num_gpu_blocks = min(b[0] for b in num_blocks) - num_cpu_blocks = min(b[1] for b in num_blocks) + return min(memory_sizes) - return num_gpu_blocks, num_cpu_blocks + def get_kv_cache_spec(self) -> KVCacheSpec: + kv_cache_specs = self.collective_rpc("get_kv_cache_spec") + assert all(lc == kv_cache_specs[0] for lc in kv_cache_specs) + return kv_cache_specs[0] def collective_rpc(self, method: str, diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index 79acc60001c9..27ffb3abdc98 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -10,6 +10,7 @@ from vllm.v1.executor.abstract import Executor from vllm.v1.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster, ray) +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput if ray is not None: @@ -211,39 +212,40 @@ def _get_worker_kwargs( distributed_init_method=distributed_init_method, ) - def determine_num_available_blocks(self) -> Tuple[int, int]: + def get_available_memory(self) -> int: """ - Determine the number of available KV blocks. + Determine the available GPU memory in bytes. - This invokes `determine_num_available_blocks` on each worker and takes + This invokes `get_available_memory` on each worker and takes the min of the results, guaranteeing that the selected cache sizes are compatible with all workers. - - Returns: - - tuple[num_gpu_blocks, num_cpu_blocks] """ - # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_blocks = self._run_workers("determine_num_available_blocks") + + memory_sizes = self._run_workers("get_available_memory") # Since we use a shared centralized controller, we take the minimum - # number of blocks across all workers to make sure all the memory + # memory size across all workers to make sure all the memory # operators can be applied to all workers. - num_gpu_blocks = min(b[0] for b in num_blocks) - num_cpu_blocks = min(b[1] for b in num_blocks) + return min(memory_sizes) - return num_gpu_blocks, num_cpu_blocks - - def initialize(self, num_gpu_blocks: int) -> None: + def initialize(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the KV cache in all workers. """ - # NOTE: This is logged in the executor because there can be >1 worker - # with other executors. We could log in the engine level, but work - # remains to abstract away the device for non-GPU configurations. - logger.info("# GPU blocks: %d", num_gpu_blocks) - self._run_workers("initialize_cache", num_gpu_blocks) + self._run_workers("initialize_cache", kv_cache_config) self._run_workers("compile_or_warm_up_model") + def get_kv_cache_spec(self) -> KVCacheSpec: + """ + Get the KVCacheSpec of the model + + This invokes `get_kv_cache_spec` on each worker and asserts that + they are identical. The KVCacheSpec is then returned. 
+ """ + kv_cache_specs = self._run_workers("get_kv_cache_spec") + assert all(lc == kv_cache_specs[0] for lc in kv_cache_specs) + return kv_cache_specs[0] + def _run_workers( self, method: str, diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index be058318de58..e30ba99b38b7 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -1,10 +1,11 @@ import os -from typing import Optional, Tuple +from typing import Optional from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.v1.executor.abstract import Executor +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.gpu_worker import Worker @@ -49,20 +50,19 @@ def _create_worker( distributed_init_method=distributed_init_method, ) - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks by invoking the + def get_available_memory(self) -> int: + """Determine the available memory for KV cache by invoking the underlying worker. """ - return self.worker.determine_num_available_blocks() + return self.worker.get_available_memory() - def initialize(self, num_gpu_blocks: int) -> None: + def get_kv_cache_spec(self) -> KVCacheSpec: + return self.worker.get_kv_cache_spec() + + def initialize(self, kv_cache_config: KVCacheConfig) -> None: """Initialize the KV cache by invoking the underlying worker. """ - # NOTE: This is logged in the executor because there can be >1 worker - # with other executors. We could log in the engine level, but work - # remains to abstract away the device for non-GPU configurations. - logger.info("# GPU blocks: %d", num_gpu_blocks) - self.worker.initialize_cache(num_gpu_blocks) + self.worker.initialize_cache(kv_cache_config) self.worker.compile_or_warm_up_model() def execute_model( diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py new file mode 100644 index 000000000000..f7ef36325454 --- /dev/null +++ b/vllm/v1/kv_cache_interface.py @@ -0,0 +1,66 @@ +from dataclasses import dataclass +import math +from typing import Dict, List, Protocol, Tuple, runtime_checkable +import torch + +from vllm.config import ModelConfig, VllmConfig +from vllm.utils import get_dtype_size +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@dataclass +class KVCacheSpecBase: + block_size: int + + @property + def key(self) -> str: + raise NotImplementedError + + @property + def page_size_bytes(self) -> int: + raise NotImplementedError + + def bytes_for_tokens(self, num_tokens: int) -> int: + raise NotImplementedError + + +@dataclass +class FullAttentionSpec(KVCacheSpecBase): + num_kv_heads: int + head_size: int + dtype: torch.dtype + + @property + def key(self) -> str: + return f"full_attention_{self.block_size}_{self.bytes_for_tokens(1)}" + + @property + def page_size_bytes(self) -> int: + return 2 * self.block_size * self.num_kv_heads * self.head_size \ + * get_dtype_size(self.dtype) + + def bytes_for_tokens(self, num_tokens: int) -> int: + return math.ceil(num_tokens / self.block_size) * self.page_size_bytes + + +KVCacheSpec = Dict[str, KVCacheSpecBase] + + +@dataclass +class KVCacheTensor: + size: int # in bytes + + +@dataclass +class KVCacheConfig: + # layer_name -> the kv_cache tensor configuration for the layer + tensors: Dict[str, KVCacheTensor] + + # [group_id][layer_name in the group]. 
One group containing all + # layer_names if the Spec for kv_cache of all layers are the same + groups: List[List[str]] + + # the KVCacheSpec of the model + kv_cache_spec: KVCacheSpec diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index b0a7affbebb7..1f1691ad351a 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -5,6 +5,8 @@ from typing import (Any, Callable, Dict, Generic, List, Optional, TypeVar, Union, overload) +import torch + from vllm.logger import init_logger from vllm.utils import get_mp_context, kill_process_tree @@ -134,3 +136,13 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str): socket_file = ipc_socket.replace("ipc://", "") if os and os.path.exists(socket_file): os.remove(socket_file) + + +def bind_kv_cache( + ctx: Dict[str, Any], + kv_caches: Dict[str, torch.Tensor], +) -> None: + for layer_name, kv_cache in kv_caches.items(): + # TODO: change [kv_cache] to kv_cache when dropping v0 which uses + # virtual engine to support PP. + ctx[layer_name].kv_cache = [kv_cache] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index fb87dc5a8222..56f5ce56255b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -7,6 +7,8 @@ import torch.distributed import torch.nn as nn +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import set_forward_context @@ -16,13 +18,15 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.sampling_params import SamplingType from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - LayerBlockType, bind_kv_cache, cdiv, - is_pin_memory_available) + LayerBlockType, cdiv, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) from vllm.v1.engine.mm_input_mapper import MMInputMapperClient +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheSpec) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch if TYPE_CHECKING: @@ -852,15 +856,54 @@ def capture_model(self) -> None: logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) - def initialize_kv_cache(self, num_blocks: int) -> None: + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: assert len(self.kv_caches) == 0 - kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_kv_heads, self.head_size) - for _ in range(self.num_attn_layers): - self.kv_caches.append( - torch.zeros(kv_cache_shape, - dtype=self.kv_cache_dtype, - device=self.device)) - bind_kv_cache( - self.vllm_config.compilation_config.static_forward_context, - [self.kv_caches]) + if len(kv_cache_config.groups) > 1: + raise NotImplementedError("Multiple groups are not supported yet.") + + kv_caches: Dict[str, torch.Tensor] = {} + + for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items(): + tensor_config = kv_cache_config.tensors[layer_name] + assert tensor_config.size % layer_spec.page_size_bytes == 0 + num_blocks = tensor_config.size // layer_spec.page_size_bytes + if isinstance(layer_spec, FullAttentionSpec): + kv_cache_shape = 
FlashAttentionBackend.get_kv_cache_shape( + num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, + layer_spec.head_size) + dtype = layer_spec.dtype + kv_caches[layer_name] = torch.zeros(kv_cache_shape, + dtype=dtype, + device=self.device) + else: + raise NotImplementedError + + forward_ctx = self.vllm_config.compilation_config.static_forward_context + bind_kv_cache(forward_ctx, kv_caches) + + def get_kv_cache_spec(self) -> KVCacheSpec: + forward_ctx = self.vllm_config.compilation_config.static_forward_context + block_size = self.vllm_config.cache_config.block_size + kv_cache_spec: KVCacheSpec = {} + for layer_name, attn_module in forward_ctx.items(): + # TODO: Support other attention modules, e.g., sliding window, + # cross-attention, MLA. + assert isinstance(attn_module, Attention) + if attn_module.attn_type == AttentionType.DECODER: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=attn_module.dtype, + ) + elif attn_module.attn_type in (AttentionType.ENCODER, + AttentionType.ENCODER_ONLY): + # encoder-only attention does not need KV cache. + continue + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + raise NotImplementedError + else: + raise ValueError( + f"Unknown attention type: {attn_module.attn_type}") + + return kv_cache_spec diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index af438f7d5820..6e4b1a83a789 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -1,7 +1,7 @@ """A GPU worker class.""" import gc import os -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Optional import torch import torch.distributed @@ -16,6 +16,7 @@ from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size from vllm.v1.core.scheduler import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -111,7 +112,7 @@ def load_model(self) -> None: self.model_runner.load_model() @torch.inference_mode() - def determine_num_available_blocks(self) -> Tuple[int, int]: + def get_available_memory(self) -> int: """Profiles the peak memory usage of the model to determine how many KV blocks may be allocated without OOMs. @@ -122,6 +123,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: .. tip:: You may limit the usage of GPU memory by adjusting the `gpu_memory_utilization` parameter. + TODO (Chen): update comments """ # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. @@ -161,33 +163,14 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory) - # Calculate the number of blocks that can be allocated with the - # profiled peak memory. - cache_block_size = _get_cache_block_size(self.cache_config, - self.model_config, - self.parallel_config) - num_gpu_blocks = int(available_kv_cache_memory // cache_block_size) - num_gpu_blocks = max(num_gpu_blocks, 0) - return num_gpu_blocks, 0 - - def initialize_cache(self, num_gpu_blocks: int) -> None: - """Allocate GPU and CPU KV cache with the specified number of blocks.""" - if num_gpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. 
" - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") - - max_seq_len = self.cache_config.block_size * num_gpu_blocks - max_model_len = self.model_config.max_model_len - if max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). Try increasing " - "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") + return int(available_kv_cache_memory) + + def get_kv_cache_spec(self) -> KVCacheSpec: + return self.model_runner.get_kv_cache_spec() - self.model_runner.initialize_kv_cache(num_gpu_blocks) + def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None: + """Allocate GPU KV cache with the specified kv_cache_config.""" + self.model_runner.initialize_kv_cache(kv_cache_config) def compile_or_warm_up_model(self) -> None: if not self.model_config.enforce_eager: From 6025d5ea8ce275201cf6d2aaa9334fd854fda9a3 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 11 Jan 2025 06:07:30 -0800 Subject: [PATCH 02/20] update comment Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 8 +++++--- vllm/v1/engine/core.py | 10 +++++++++- vllm/v1/executor/multiproc_executor.py | 8 ++++++++ vllm/v1/executor/ray_executor.py | 2 +- vllm/v1/executor/uniproc_executor.py | 3 +++ vllm/v1/worker/gpu_model_runner.py | 5 +++-- vllm/v1/worker/gpu_worker.py | 10 +++------- 7 files changed, 32 insertions(+), 14 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 2366afc36f36..be876ba09aeb 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -312,9 +312,11 @@ def hash_request_tokens(block_size: int, def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, available_memory: int) -> Tuple[KVCacheConfig, int]: + """ returns the KV cache configuration and the number of GPU blocks + """ check_enough_memory(vllm_config, kv_cache_spec, available_memory) if is_same_key(kv_cache_spec): - # kv_cache of all layers are the same + # kv_cache of all layers are the same, which is true for most models return _get_kv_cache_config_same_key(vllm_config, kv_cache_spec, available_memory) else: @@ -331,8 +333,6 @@ def _get_kv_cache_config_same_key( num_gpu_blocks = int(available_memory // page_size // len(kv_cache_spec)) num_gpu_blocks = max(num_gpu_blocks, 0) - logger.info("# GPU blocks: %d", num_gpu_blocks) - if vllm_config.cache_config.num_gpu_blocks_override is not None: num_gpu_blocks_override = \ vllm_config.cache_config.num_gpu_blocks_override @@ -342,6 +342,8 @@ def _get_kv_cache_config_same_key( num_gpu_blocks_override) num_gpu_blocks = num_gpu_blocks_override + logger.info("# GPU blocks: %d", num_gpu_blocks) + per_layer_size = page_size * num_gpu_blocks kv_cache_config = KVCacheConfig( diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 19a7c560551a..d8bd7ed30338 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -71,14 +71,22 @@ def __init__( def _initialize_kv_caches(self, vllm_config: VllmConfig) -> Tuple[int, int]: start = time.time() + + # Get all kv cache tensor needed by the model kv_cache_spec = self.model_executor.get_kv_cache_spec() + + # Profiles the peak memory usage of the model to determine how much + # memory can be allocated for kv cache. 
availble_gpu_memory = self.model_executor.get_available_memory() + # Get the kv cache tensor size kv_cache_config, num_gpu_blocks = get_kv_cache_config( vllm_config, kv_cache_spec, availble_gpu_memory) - num_cpu_blocks = 0 + + # Initialize kv cache and warmup the execution self.model_executor.initialize(kv_cache_config) + elapsed = time.time() - start logger.info(("init engine (profile, create kv cache, " "warmup model) took %.2f seconds"), elapsed) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 7efdd78b16c0..3f6261bdd502 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -100,6 +100,10 @@ def initialize(self, kv_cache_config: KVCacheConfig) -> None: self.collective_rpc("compile_or_warm_up_model") def get_available_memory(self) -> int: + """ + Determine the available memory for KV cache by invoking the + underlying worker. + """ memory_sizes = self.collective_rpc("get_available_memory") # Since we use a shared centralized controller, we take the minimum @@ -108,6 +112,10 @@ def get_available_memory(self) -> int: return min(memory_sizes) def get_kv_cache_spec(self) -> KVCacheSpec: + """ + Get all kv cache tensor needed by the model by invoking the + underlying worker. + """ kv_cache_specs = self.collective_rpc("get_kv_cache_spec") assert all(lc == kv_cache_specs[0] for lc in kv_cache_specs) return kv_cache_specs[0] diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index 27ffb3abdc98..c81c4eaff4d2 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -237,7 +237,7 @@ def initialize(self, kv_cache_config: KVCacheConfig) -> None: def get_kv_cache_spec(self) -> KVCacheSpec: """ - Get the KVCacheSpec of the model + Get all kv cache tensor needed by the model This invokes `get_kv_cache_spec` on each worker and asserts that they are identical. The KVCacheSpec is then returned. diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index e30ba99b38b7..30d257d98459 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -57,6 +57,9 @@ def get_available_memory(self) -> int: return self.worker.get_available_memory() def get_kv_cache_spec(self) -> KVCacheSpec: + """Get all kv cache tensor needed by the model by invoking the + underlying worker. 
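+
+        Returns:
+            The minimum available KV cache memory, in bytes, across all
+            workers.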
+ """ return self.worker.get_kv_cache_spec() def initialize(self, kv_cache_config: KVCacheConfig) -> None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 56f5ce56255b..4133e5b731f3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -878,8 +878,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: else: raise NotImplementedError - forward_ctx = self.vllm_config.compilation_config.static_forward_context - bind_kv_cache(forward_ctx, kv_caches) + bind_kv_cache( + self.vllm_config.compilation_config.static_forward_context, + kv_caches) def get_kv_cache_spec(self) -> KVCacheSpec: forward_ctx = self.vllm_config.compilation_config.static_forward_context diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 6e4b1a83a789..7dfb7f37d5ed 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -113,20 +113,16 @@ def load_model(self) -> None: @torch.inference_mode() def get_available_memory(self) -> int: - """Profiles the peak memory usage of the model to determine how many - KV blocks may be allocated without OOMs. + """Profiles the peak memory usage of the model to determine how much + memory can be used for KV cache without OOMs. The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. + Then, it calculate the free memory that can be used for KV cache .. tip:: You may limit the usage of GPU memory by adjusting the `gpu_memory_utilization` parameter. - TODO (Chen): update comments """ - # Profile the memory usage of the model and get the maximum number of - # cache blocks that can be allocated with the remaining free memory. torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() From 6024290e487199fc842ab319979f05eaae4bf59d Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 11 Jan 2025 07:16:07 -0800 Subject: [PATCH 03/20] update comment Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index be876ba09aeb..ec3e7fc0189b 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -316,7 +316,8 @@ def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, """ check_enough_memory(vllm_config, kv_cache_spec, available_memory) if is_same_key(kv_cache_spec): - # kv_cache of all layers are the same, which is true for most models + # kv cache of all layers are the same, which is true for most models. + # Allocate the same amount of memory for each layer. 
return _get_kv_cache_config_same_key(vllm_config, kv_cache_spec, available_memory) else: @@ -374,7 +375,6 @@ def check_enough_memory(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, needed_memory += layer_spec.bytes_for_tokens(max_model_len) if needed_memory > available_memory: - # TODO(Chen): need unit test raise ValueError( f"To serve at least one request with the models's max seq len " f"({max_model_len}), ({needed_memory/1024/1024/1024} GB KV cache is" From 03130cd41d89d612e612b1bb8cb7806e16d586b3 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 11 Jan 2025 07:32:25 -0800 Subject: [PATCH 04/20] format Signed-off-by: Chen Zhang --- vllm/v1/kv_cache_interface.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index f7ef36325454..394ca04d56d5 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -1,9 +1,8 @@ from dataclasses import dataclass import math -from typing import Dict, List, Protocol, Tuple, runtime_checkable +from typing import Dict, List, runtime_checkable import torch -from vllm.config import ModelConfig, VllmConfig from vllm.utils import get_dtype_size from vllm.logger import init_logger From 1229600b61060e59ef5006b9ed0e4de620396297 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 11 Jan 2025 07:33:21 -0800 Subject: [PATCH 05/20] format Signed-off-by: Chen Zhang --- vllm/v1/kv_cache_interface.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 394ca04d56d5..0f42e3bb0a00 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -1,10 +1,11 @@ -from dataclasses import dataclass import math -from typing import Dict, List, runtime_checkable +from dataclasses import dataclass +from typing import Dict, List + import torch -from vllm.utils import get_dtype_size from vllm.logger import init_logger +from vllm.utils import get_dtype_size logger = init_logger(__name__) From e3764d47ee2d5bf4c48d3efb715eb4fe3444da90 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 11 Jan 2025 19:18:00 -0800 Subject: [PATCH 06/20] bind kv cache to model runner Signed-off-by: Chen Zhang --- tests/v1/test_utils.py | 93 ++++++++++++++++++++++++++++++ vllm/v1/utils.py | 16 +++++ vllm/v1/worker/gpu_model_runner.py | 3 +- 3 files changed, 110 insertions(+), 2 deletions(-) create mode 100644 tests/v1/test_utils.py diff --git a/tests/v1/test_utils.py b/tests/v1/test_utils.py new file mode 100644 index 000000000000..74f702d615f9 --- /dev/null +++ b/tests/v1/test_utils.py @@ -0,0 +1,93 @@ +from typing import List + +import torch + +from vllm.v1.utils import bind_kv_cache + + +def test_bind_kv_cache(): + from vllm.attention import Attention + + ctx = { + 'layers.0.self_attn': Attention(32, 128, 0.1), + 'layers.1.self_attn': Attention(32, 128, 0.1), + 'layers.2.self_attn': Attention(32, 128, 0.1), + 'layers.3.self_attn': Attention(32, 128, 0.1), + } + kv_cache = { + 'layers.0.self_attn': torch.zeros((1, )), + 'layers.1.self_attn': torch.zeros((1, )), + 'layers.2.self_attn': torch.zeros((1, )), + 'layers.3.self_attn': torch.zeros((1, )), + } + runner_kv_caches: List[torch.Tensor] = [] + bind_kv_cache(ctx, runner_kv_caches, kv_cache) + assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[ + 'layers.0.self_attn'] + assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[ + 'layers.1.self_attn'] + assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[ + 'layers.2.self_attn'] 
+ assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[ + 'layers.3.self_attn'] + + assert runner_kv_caches[0] is kv_cache['layers.0.self_attn'] + assert runner_kv_caches[1] is kv_cache['layers.1.self_attn'] + assert runner_kv_caches[2] is kv_cache['layers.2.self_attn'] + assert runner_kv_caches[3] is kv_cache['layers.3.self_attn'] + + +def test_bind_kv_cache_non_attention(): + from vllm.attention import Attention + + # example from Jamba PP=2 + ctx = { + 'model.layers.20.attn': Attention(32, 128, 0.1), + 'model.layers.28.attn': Attention(32, 128, 0.1), + } + kv_cache = { + 'model.layers.20.attn': torch.zeros((1, )), + 'model.layers.28.attn': torch.zeros((1, )), + } + + runner_kv_caches: List[torch.Tensor] = [] + bind_kv_cache(ctx, runner_kv_caches, kv_cache) + + assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[ + 'model.layers.20.attn'] + assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[ + 'model.layers.28.attn'] + + assert runner_kv_caches[0] is kv_cache['model.layers.20.attn'] + assert runner_kv_caches[1] is kv_cache['model.layers.28.attn'] + + +def test_bind_kv_cache_encoder_decoder(): + from vllm.attention import Attention, AttentionType + + # example from bart + ctx = { + 'encoder.layers.0.self_attn.attn': + Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER), + 'decoder.layers.0.encoder_attn.attn': + Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER), + 'decoder.layers.0.self_attn.attn': + Attention(32, 128, 0.1, attn_type=AttentionType.DECODER), + } + + kv_cache_tensor = torch.zeros((1, )) + kv_cache = { + 'decoder.layers.0.encoder_attn.attn': kv_cache_tensor, + 'decoder.layers.0.self_attn.attn': kv_cache_tensor, + } + encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache + + runner_kv_caches: List[torch.Tensor] = [] + bind_kv_cache(ctx, runner_kv_caches, kv_cache) + assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache + assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[ + 'decoder.layers.0.encoder_attn.attn'] + assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[ + 'decoder.layers.0.self_attn.attn'] + + assert runner_kv_caches[0] is kv_cache_tensor diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 1f1691ad351a..94f21e506542 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,6 +1,7 @@ import multiprocessing import os import weakref +from collections import defaultdict from collections.abc import Sequence from typing import (Any, Callable, Dict, Generic, List, Optional, TypeVar, Union, overload) @@ -8,6 +9,7 @@ import torch from vllm.logger import init_logger +from vllm.model_executor.models.utils import extract_layer_index from vllm.utils import get_mp_context, kill_process_tree logger = init_logger(__name__) @@ -140,8 +142,22 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str): def bind_kv_cache( ctx: Dict[str, Any], + runner_kv_caches: List[torch.Tensor], kv_caches: Dict[str, torch.Tensor], ) -> None: + # bind kv_caches to ModelRunner's kv_caches + assert len(runner_kv_caches) == 0 + index2name = defaultdict(list) + for layer_name in kv_caches: + index2name[extract_layer_index(layer_name)].append(layer_name) + + for layer_index in sorted(index2name.keys()): + layer_names = index2name[layer_index] + for layer_name in layer_names[1:]: + assert kv_caches[layer_name] is kv_caches[layer_names[0]] + runner_kv_caches.append(kv_caches[layer_names[0]]) + + # bind kv_caches to forward context for layer_name, kv_cache in 
kv_caches.items(): # TODO: change [kv_cache] to kv_cache when dropping v0 which uses # virtual engine to support PP. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4133e5b731f3..9a505527a1e8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -857,7 +857,6 @@ def capture_model(self) -> None: elapsed_time, cuda_graph_size / (1 << 30)) def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: - assert len(self.kv_caches) == 0 if len(kv_cache_config.groups) > 1: raise NotImplementedError("Multiple groups are not supported yet.") @@ -880,7 +879,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: bind_kv_cache( self.vllm_config.compilation_config.static_forward_context, - kv_caches) + self.kv_caches, kv_caches) def get_kv_cache_spec(self) -> KVCacheSpec: forward_ctx = self.vllm_config.compilation_config.static_forward_context From fec7d2d4b3ea6a68a845f4cffe630c3208f90e6e Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 13 Jan 2025 19:23:43 -0800 Subject: [PATCH 07/20] determine_available_memory Signed-off-by: Chen Zhang --- vllm/v1/engine/core.py | 2 +- vllm/v1/executor/abstract.py | 2 +- vllm/v1/executor/multiproc_executor.py | 4 ++-- vllm/v1/executor/ray_executor.py | 6 +++--- vllm/v1/executor/uniproc_executor.py | 4 ++-- vllm/v1/worker/gpu_worker.py | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index d8bd7ed30338..a9c53c9e362e 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -77,7 +77,7 @@ def _initialize_kv_caches(self, # Profiles the peak memory usage of the model to determine how much # memory can be allocated for kv cache. - availble_gpu_memory = self.model_executor.get_available_memory() + availble_gpu_memory = self.model_executor.determine_available_memory() # Get the kv cache tensor size kv_cache_config, num_gpu_blocks = get_kv_cache_config( diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 2248f629cdc6..9b7d87514064 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -35,7 +35,7 @@ def initialize(self, kv_cache_config: KVCacheConfig) -> None: raise NotImplementedError @abstractmethod - def get_available_memory(self) -> int: # in bytes + def determine_available_memory(self) -> int: # in bytes raise NotImplementedError @abstractmethod diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 3f6261bdd502..0fbc67bd8509 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -99,12 +99,12 @@ def initialize(self, kv_cache_config: KVCacheConfig) -> None: self.collective_rpc("initialize_cache", args=(kv_cache_config, )) self.collective_rpc("compile_or_warm_up_model") - def get_available_memory(self) -> int: + def determine_available_memory(self) -> int: """ Determine the available memory for KV cache by invoking the underlying worker. 
""" - memory_sizes = self.collective_rpc("get_available_memory") + memory_sizes = self.collective_rpc("determine_available_memory") # Since we use a shared centralized controller, we take the minimum # memory size across all workers to make sure all the memory diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index c81c4eaff4d2..ea348aea89c9 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -212,16 +212,16 @@ def _get_worker_kwargs( distributed_init_method=distributed_init_method, ) - def get_available_memory(self) -> int: + def determine_available_memory(self) -> int: """ Determine the available GPU memory in bytes. - This invokes `get_available_memory` on each worker and takes + This invokes `determine_available_memory` on each worker and takes the min of the results, guaranteeing that the selected cache sizes are compatible with all workers. """ - memory_sizes = self._run_workers("get_available_memory") + memory_sizes = self._run_workers("determine_available_memory") # Since we use a shared centralized controller, we take the minimum # memory size across all workers to make sure all the memory diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index 30d257d98459..e836ed6ce14a 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -50,11 +50,11 @@ def _create_worker( distributed_init_method=distributed_init_method, ) - def get_available_memory(self) -> int: + def determine_available_memory(self) -> int: """Determine the available memory for KV cache by invoking the underlying worker. """ - return self.worker.get_available_memory() + return self.worker.determine_available_memory() def get_kv_cache_spec(self) -> KVCacheSpec: """Get all kv cache tensor needed by the model by invoking the diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 7dfb7f37d5ed..81e9640b7b30 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -112,7 +112,7 @@ def load_model(self) -> None: self.model_runner.load_model() @torch.inference_mode() - def get_available_memory(self) -> int: + def determine_available_memory(self) -> int: """Profiles the peak memory usage of the model to determine how much memory can be used for KV cache without OOMs. From 4294435034600cad1d4f6f26fc0510e24db392b2 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 13 Jan 2025 23:25:35 -0800 Subject: [PATCH 08/20] update kv_cache_utils Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 77 +++++++++++++++++----------------- vllm/v1/kv_cache_interface.py | 5 ++- 2 files changed, 42 insertions(+), 40 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index ec3e7fc0189b..79614ef1840f 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -310,21 +310,35 @@ def hash_request_tokens(block_size: int, return ret -def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, - available_memory: int) -> Tuple[KVCacheConfig, int]: - """ returns the KV cache configuration and the number of GPU blocks - """ - check_enough_memory(vllm_config, kv_cache_spec, available_memory) - if is_same_key(kv_cache_spec): - # kv cache of all layers are the same, which is true for most models. - # Allocate the same amount of memory for each layer. 
- return _get_kv_cache_config_same_key(vllm_config, kv_cache_spec, - available_memory) - else: - raise NotImplementedError +def check_enough_kv_cache_memory(vllm_config: VllmConfig, + kv_cache_spec: KVCacheSpec, + available_memory: int): + if available_memory <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + + max_model_len = vllm_config.model_config.max_model_len + needed_memory = 0 + for layer_spec in kv_cache_spec.values(): + needed_memory += layer_spec.bytes_for_tokens(max_model_len) + + if needed_memory > available_memory: + raise ValueError( + f"To serve at least one request with the models's max seq len " + f"({max_model_len}), ({needed_memory/1024/1024/1024} GB KV cache is" + f"needed, which is larger than the available KV Cache memory " + f"({available_memory/1024/1024/1024} GB). Try increasing " + f"`gpu_memory_utilization` or decreasing `max_model_len` when " + f"initializing the engine.") + + +def is_same_type(kv_cache_spec: KVCacheSpec) -> bool: + layer_keys = set(layer.type_key for layer in kv_cache_spec.values()) + return len(layer_keys) == 1 -def _get_kv_cache_config_same_key( +def _get_kv_cache_config_same_type( vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, available_memory: int) -> Tuple[KVCacheConfig, int]: page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} @@ -357,28 +371,15 @@ def _get_kv_cache_config_same_key( return kv_cache_config, num_gpu_blocks -def is_same_key(kv_cache_spec: KVCacheSpec) -> bool: - layer_keys = set(layer.key for layer in kv_cache_spec.values()) - return len(layer_keys) == 1 - - -def check_enough_memory(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, - available_memory: int): - if available_memory <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") - - max_model_len = vllm_config.model_config.max_model_len - needed_memory = 0 - for layer_spec in kv_cache_spec.values(): - needed_memory += layer_spec.bytes_for_tokens(max_model_len) - - if needed_memory > available_memory: - raise ValueError( - f"To serve at least one request with the models's max seq len " - f"({max_model_len}), ({needed_memory/1024/1024/1024} GB KV cache is" - f"needed, which is larger than the available KV Cache memory " - f"({available_memory/1024/1024/1024} GB). Try increasing " - f"`gpu_memory_utilization` or decreasing `max_model_len` when " - f"initializing the engine.") +def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, + available_memory: int) -> Tuple[KVCacheConfig, int]: + """ returns the KV cache configuration and the number of GPU blocks + """ + check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) + if is_same_type(kv_cache_spec): + # kv cache of all layers are the same, which is true for most models. + # Allocate the same amount of memory for each layer. 
+ return _get_kv_cache_config_same_type(vllm_config, kv_cache_spec, + available_memory) + else: + raise NotImplementedError diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 0f42e3bb0a00..4d592c21ef36 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -15,7 +15,8 @@ class KVCacheSpecBase: block_size: int @property - def key(self) -> str: + def type_key(self) -> str: + # TODO: add docstring raise NotImplementedError @property @@ -33,7 +34,7 @@ class FullAttentionSpec(KVCacheSpecBase): dtype: torch.dtype @property - def key(self) -> str: + def type_key(self) -> str: return f"full_attention_{self.block_size}_{self.bytes_for_tokens(1)}" @property From f79dff269e417f7af3b6555312a1215111e69b66 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 14 Jan 2025 14:40:30 +0800 Subject: [PATCH 09/20] Update vllm/v1/utils.py Co-authored-by: Cody Yu Signed-off-by: Chen Zhang --- vllm/v1/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 94f21e506542..7e40698d4668 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -153,9 +153,9 @@ def bind_kv_cache( for layer_index in sorted(index2name.keys()): layer_names = index2name[layer_index] - for layer_name in layer_names[1:]: - assert kv_caches[layer_name] is kv_caches[layer_names[0]] - runner_kv_caches.append(kv_caches[layer_names[0]]) + layer_name = layer_names[0] + assert all(kv_caches[n] is kv_caches[layer_name] for n in layer_names[1:]) + runner_kv_caches.append(kv_caches[layer_name]) # bind kv_caches to forward context for layer_name, kv_cache in kv_caches.items(): From 97176dabefc2232139df3b90ef4c75e74e0b8a6a Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 14 Jan 2025 14:40:39 +0800 Subject: [PATCH 10/20] Update vllm/v1/worker/gpu_model_runner.py Co-authored-by: Cody Yu Signed-off-by: Chen Zhang --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9a505527a1e8..9e8182f59e54 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -858,7 +858,7 @@ def capture_model(self) -> None: def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: if len(kv_cache_config.groups) > 1: - raise NotImplementedError("Multiple groups are not supported yet.") + raise NotImplementedError("Hybrid kv-cache groups are not supported yet.") kv_caches: Dict[str, torch.Tensor] = {} From e6179a8de12cd43382f22ff2a01008fac9862ee4 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 14 Jan 2025 00:23:55 -0800 Subject: [PATCH 11/20] add some comments Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 42 +++++++++++++++++++++++++++++- vllm/v1/utils.py | 25 +++++++++++++----- vllm/v1/worker/gpu_model_runner.py | 18 ++++++++++++- 3 files changed, 77 insertions(+), 8 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 79614ef1840f..46c60005fab0 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -313,6 +313,17 @@ def hash_request_tokens(block_size: int, def check_enough_kv_cache_memory(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, available_memory: int): + """ + Checks if there is enough memory available for the KV cache of at least one + request with the model's max_model_len. 
+ Args: + vllm_config: The global VllmConfig + kv_cache_spec: The kv cache spec of the model + available_memory (int): Memory available for KV cache in bytes. + Raises: + ValueError: If there is not enough memory available for the KV cache. + """ + if available_memory <= 0: raise ValueError("No available memory for the cache blocks. " "Try increasing `gpu_memory_utilization` when " @@ -334,6 +345,14 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, def is_same_type(kv_cache_spec: KVCacheSpec) -> bool: + """ + Whether all layers in the given KVCacheSpec have the same type of KV cache. + Args: + kv_cache_spec (KVCacheSpec): The KVCacheSpec of the model + Returns: + True if all layers have the same type, False otherwise. + """ + layer_keys = set(layer.type_key for layer in kv_cache_spec.values()) return len(layer_keys) == 1 @@ -341,6 +360,18 @@ def is_same_type(kv_cache_spec: KVCacheSpec) -> bool: def _get_kv_cache_config_same_type( vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, available_memory: int) -> Tuple[KVCacheConfig, int]: + """ + Generates the KV cache configuration for a model with one type of KV cache. + Divide the available memory equally among all layers. + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The kv cache spec of the model + available_memory (int): Memory available for KV cache in bytes. + Returns: + Tuple[KVCacheConfig, int]: The generated KVCacheConfig and the number of + GPU blocks. + """ + page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} assert len(page_sizes) == 1 page_size = page_sizes.pop() @@ -373,7 +404,16 @@ def _get_kv_cache_config_same_type( def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, available_memory: int) -> Tuple[KVCacheConfig, int]: - """ returns the KV cache configuration and the number of GPU blocks + """ + Generates the KV cache configuration for a model + TODO: support hybrid models with more than one type of KV cache. + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The kv cache spec of the model + available_memory (int): Memory available for KV cache in bytes. + Returns: + Tuple[KVCacheConfig, int]: The generated KVCacheConfig and the number of + GPU blocks. """ check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) if is_same_type(kv_cache_spec): diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 7e40698d4668..f6a5fb2e399c 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -3,8 +3,8 @@ import weakref from collections import defaultdict from collections.abc import Sequence -from typing import (Any, Callable, Dict, Generic, List, Optional, TypeVar, - Union, overload) +from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, List, + Optional, TypeVar, Union, overload) import torch @@ -12,6 +12,9 @@ from vllm.model_executor.models.utils import extract_layer_index from vllm.utils import get_mp_context, kill_process_tree +if TYPE_CHECKING: + from vllm.attention.layer import Attention + logger = init_logger(__name__) T = TypeVar("T") @@ -141,10 +144,20 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str): def bind_kv_cache( - ctx: Dict[str, Any], + ctx: Dict[str, "Attention"], runner_kv_caches: List[torch.Tensor], kv_caches: Dict[str, torch.Tensor], ) -> None: + """ + Bind kv_caches to the forward context and model_runner's kv_cache. + Args: + ctx: The global forward context containing all Attention layers with + layer names as keys. + runner_kv_caches: The kv_cache declared by ModelRunner. 
+ kv_caches: The allocated kv_caches with layer names as keys. + Returns: + None + """ # bind kv_caches to ModelRunner's kv_caches assert len(runner_kv_caches) == 0 index2name = defaultdict(list) @@ -154,11 +167,11 @@ def bind_kv_cache( for layer_index in sorted(index2name.keys()): layer_names = index2name[layer_index] layer_name = layer_names[0] - assert all(kv_caches[n] is kv_caches[layer_name] for n in layer_names[1:]) + assert all(kv_caches[n] is kv_caches[layer_name] + for n in layer_names[1:]) runner_kv_caches.append(kv_caches[layer_name]) # bind kv_caches to forward context for layer_name, kv_cache in kv_caches.items(): - # TODO: change [kv_cache] to kv_cache when dropping v0 which uses - # virtual engine to support PP. + # NOTE: Use list because of v0 PP virtual engine. ctx[layer_name].kv_cache = [kv_cache] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9e8182f59e54..635cdfb5e135 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -857,8 +857,16 @@ def capture_model(self) -> None: elapsed_time, cuda_graph_size / (1 << 30)) def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + """ + Allocate the KV cache for the model based on the provided configuration. + Args: + kv_cache_config: Configuration for the KV cache, including the KV + cache size of each layer + """ if len(kv_cache_config.groups) > 1: - raise NotImplementedError("Hybrid kv-cache groups are not supported yet.") + raise NotImplementedError( + "Hybrid models with more than one KV cache type are not " + "supported yet.") kv_caches: Dict[str, torch.Tensor] = {} @@ -882,6 +890,14 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.kv_caches, kv_caches) def get_kv_cache_spec(self) -> KVCacheSpec: + """ + Generates the KVCacheSpec by parsing the kv cache format of each + Attention module in the static forward context. + Returns: + KVCacheSpec: A dictionary mapping layer names to their KV cache + format. Layers that do not need KV cache are not included. + """ + forward_ctx = self.vllm_config.compilation_config.static_forward_context block_size = self.vllm_config.cache_config.block_size kv_cache_spec: KVCacheSpec = {} From eb37f0caac2858dee92b9e4c7eb598bb1f64081d Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 14 Jan 2025 02:16:54 -0800 Subject: [PATCH 12/20] add more comments Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 2 +- vllm/v1/kv_cache_interface.py | 62 ++++++++++++++++++++++++++++------ 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 46c60005fab0..b4811c1f7992 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -353,7 +353,7 @@ def is_same_type(kv_cache_spec: KVCacheSpec) -> bool: True if all layers have the same type, False otherwise. """ - layer_keys = set(layer.type_key for layer in kv_cache_spec.values()) + layer_keys = set(layer.type_id for layer in kv_cache_spec.values()) return len(layer_keys) == 1 diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 4d592c21ef36..0be9b95192a9 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -12,18 +12,42 @@ @dataclass class KVCacheSpecBase: + """ + A base class for specifying the KV cache format of one layer. 
+ """ + + # number of tokens in a block block_size: int @property - def type_key(self) -> str: - # TODO: add docstring + def type_id(self) -> str: + """ + The type identifier of this KV cache. + Return different strings for layers with different KV cache type (e.g., + different number of tokens like full attention vs sliding window + attention, different KV cache size per token like layers with different + number of heads) + Returns: + The type identifier of this KV cache. + """ raise NotImplementedError @property def page_size_bytes(self) -> int: + """ + The size of a page with `block_size` tokens in bytes. + Returns: + The page size + """ raise NotImplementedError def bytes_for_tokens(self, num_tokens: int) -> int: + """ + The KV cache size for `num_tokens` tokens in bytes. Returns the real + memory size after padding `num_tokens` to `block_size`. + Returns: + The KV cache size + """ raise NotImplementedError @@ -34,8 +58,8 @@ class FullAttentionSpec(KVCacheSpecBase): dtype: torch.dtype @property - def type_key(self) -> str: - return f"full_attention_{self.block_size}_{self.bytes_for_tokens(1)}" + def type_id(self) -> str: + return f"full_attention_{self.block_size}_{self.page_size_bytes}" @property def page_size_bytes(self) -> int: @@ -51,17 +75,33 @@ def bytes_for_tokens(self, num_tokens: int) -> int: @dataclass class KVCacheTensor: - size: int # in bytes + """ + A dataclass for specifying how the workers should initialize the KV cache + for a layer. Only contains the size of KV cache for that layer for now. Will + be extended to support multiple layers sharing the same memory pool. + """ + size: int # The size of KV cache Tensor in bytes @dataclass class KVCacheConfig: - # layer_name -> the kv_cache tensor configuration for the layer + """ + The KV cache configuration of a model. + """ + """layer_name -> how to initialize KV cache for that layer""" tensors: Dict[str, KVCacheTensor] - - # [group_id][layer_name in the group]. One group containing all - # layer_names if the Spec for kv_cache of all layers are the same + """ + A list of kv-cache groups. Each group includes a set of layers with + the same kv-cache spec, and the total page_size of layers inside a group + is same across all groups (as the KVCacheManager only supports allocating + one page size). For example: + 1. A model only uses full attention: one group with all layers in the model. + 2. (not implemented yet) A model with the same number of full attention + layers and sliding window attention layers: two groups, one for full + attention layers and one for sliding window attention layers. + 3. (not implemented yet) A model with 2 full attention layers and 4 sliding + window attention layers: three groups, (full * 2), (sw * 2), (sw * 2). + """ groups: List[List[str]] - - # the KVCacheSpec of the model + """the KVCacheSpec of the model""" kv_cache_spec: KVCacheSpec From 88fd1b8991777b5131ba6415bde2b26e8d015052 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 14 Jan 2025 02:18:46 -0800 Subject: [PATCH 13/20] format Signed-off-by: Chen Zhang --- vllm/v1/kv_cache_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 0be9b95192a9..f760769de728 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -77,7 +77,7 @@ def bytes_for_tokens(self, num_tokens: int) -> int: class KVCacheTensor: """ A dataclass for specifying how the workers should initialize the KV cache - for a layer. 
Only contains the size of KV cache for that layer for now. Will + for a layer. Only contains the size of KV cache for that layer for now. Will be extended to support multiple layers sharing the same memory pool. """ size: int # The size of KV cache Tensor in bytes From 3493061e315690fa3118ce66a7c9c54e6667ad8c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 14 Jan 2025 02:25:47 -0800 Subject: [PATCH 14/20] update comment Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 2 +- vllm/v1/engine/core.py | 2 +- vllm/v1/executor/multiproc_executor.py | 3 +-- vllm/v1/executor/ray_executor.py | 2 +- vllm/v1/executor/uniproc_executor.py | 2 +- 5 files changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index b4811c1f7992..40ffce6f53cf 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -314,7 +314,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, available_memory: int): """ - Checks if there is enough memory available for the KV cache of at least one + Checks whether `available_memory` is enough for the KV cache of at least one request with the model's max_model_len. Args: vllm_config: The global VllmConfig diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index a9c53c9e362e..129b8df8df4f 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -72,7 +72,7 @@ def _initialize_kv_caches(self, vllm_config: VllmConfig) -> Tuple[int, int]: start = time.time() - # Get all kv cache tensor needed by the model + # Get all kv cache needed by the model kv_cache_spec = self.model_executor.get_kv_cache_spec() # Profiles the peak memory usage of the model to determine how much diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 0fbc67bd8509..659c16e3fc3d 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -113,8 +113,7 @@ def determine_available_memory(self) -> int: def get_kv_cache_spec(self) -> KVCacheSpec: """ - Get all kv cache tensor needed by the model by invoking the - underlying worker. + Get all kv cache needed by the model by invoking the underlying worker. """ kv_cache_specs = self.collective_rpc("get_kv_cache_spec") assert all(lc == kv_cache_specs[0] for lc in kv_cache_specs) diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index ea348aea89c9..a7f4ada363b8 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -237,7 +237,7 @@ def initialize(self, kv_cache_config: KVCacheConfig) -> None: def get_kv_cache_spec(self) -> KVCacheSpec: """ - Get all kv cache tensor needed by the model + Get all kv cache needed by the model This invokes `get_kv_cache_spec` on each worker and asserts that they are identical. The KVCacheSpec is then returned. diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index e836ed6ce14a..5f64ba5e00fc 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -57,7 +57,7 @@ def determine_available_memory(self) -> int: return self.worker.determine_available_memory() def get_kv_cache_spec(self) -> KVCacheSpec: - """Get all kv cache tensor needed by the model by invoking the + """Get all kv cache needed by the model by invoking the underlying worker. 
""" return self.worker.get_kv_cache_spec() From 105814a2032b98e5aa19df793f0d9d388e766535 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 14 Jan 2025 02:28:03 -0800 Subject: [PATCH 15/20] update comment Signed-off-by: Chen Zhang --- vllm/v1/kv_cache_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index f760769de728..0fab5754c823 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -44,7 +44,7 @@ def page_size_bytes(self) -> int: def bytes_for_tokens(self, num_tokens: int) -> int: """ The KV cache size for `num_tokens` tokens in bytes. Returns the real - memory size after padding `num_tokens` to `block_size`. + memory size after padding `num_tokens` to full blocks. Returns: The KV cache size """ From 9ff57d088d58e232a684ab1cba576b45a262ad36 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 14 Jan 2025 02:36:32 -0800 Subject: [PATCH 16/20] small updates Signed-off-by: Chen Zhang --- vllm/v1/kv_cache_interface.py | 7 +++---- vllm/v1/utils.py | 2 -- vllm/v1/worker/gpu_model_runner.py | 4 ++-- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 0fab5754c823..608c0a60a765 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -1,11 +1,10 @@ -import math from dataclasses import dataclass from typing import Dict, List import torch from vllm.logger import init_logger -from vllm.utils import get_dtype_size +from vllm.utils import cdiv, get_dtype_size logger = init_logger(__name__) @@ -67,7 +66,7 @@ def page_size_bytes(self) -> int: * get_dtype_size(self.dtype) def bytes_for_tokens(self, num_tokens: int) -> int: - return math.ceil(num_tokens / self.block_size) * self.page_size_bytes + return cdiv(num_tokens, self.block_size) * self.page_size_bytes KVCacheSpec = Dict[str, KVCacheSpecBase] @@ -94,7 +93,7 @@ class KVCacheConfig: A list of kv-cache groups. Each group includes a set of layers with the same kv-cache spec, and the total page_size of layers inside a group is same across all groups (as the KVCacheManager only supports allocating - one page size). For example: + pages of the same size). For example: 1. A model only uses full attention: one group with all layers in the model. 2. (not implemented yet) A model with the same number of full attention layers and sliding window attention layers: two groups, one for full diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index f6a5fb2e399c..a4201db7f56b 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -155,8 +155,6 @@ def bind_kv_cache( layer names as keys. runner_kv_caches: The kv_cache declared by ModelRunner. kv_caches: The allocated kv_caches with layer names as keys. - Returns: - None """ # bind kv_caches to ModelRunner's kv_caches assert len(runner_kv_caches) == 0 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 635cdfb5e135..21d23ca57d2b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -858,7 +858,7 @@ def capture_model(self) -> None: def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ - Allocate the KV cache for the model based on the provided configuration. + Initialize KV cache based on `kv_cache_config`. 
Args: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer @@ -891,7 +891,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: def get_kv_cache_spec(self) -> KVCacheSpec: """ - Generates the KVCacheSpec by parsing the kv cache format of each + Generates the KVCacheSpec by parsing the kv cache format from each Attention module in the static forward context. Returns: KVCacheSpec: A dictionary mapping layer names to their KV cache From 044876eba53a3ad5b9097bda818b11697da125d5 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 16 Jan 2025 08:45:29 -0800 Subject: [PATCH 17/20] update comments and function names Signed-off-by: Chen Zhang --- tests/v1/test_utils.py | 35 ++---------------- vllm/v1/core/kv_cache_utils.py | 49 +++++++++++++------------- vllm/v1/engine/core.py | 5 +-- vllm/v1/executor/multiproc_executor.py | 4 +-- vllm/v1/executor/ray_executor.py | 2 +- vllm/v1/executor/uniproc_executor.py | 4 +-- vllm/v1/kv_cache_interface.py | 2 ++ vllm/v1/utils.py | 33 +++++++++++------ vllm/v1/worker/gpu_model_runner.py | 3 +- vllm/v1/worker/gpu_worker.py | 3 +- 10 files changed, 64 insertions(+), 76 deletions(-) diff --git a/tests/v1/test_utils.py b/tests/v1/test_utils.py index 74f702d615f9..ac773b611f40 100644 --- a/tests/v1/test_utils.py +++ b/tests/v1/test_utils.py @@ -21,7 +21,7 @@ def test_bind_kv_cache(): 'layers.3.self_attn': torch.zeros((1, )), } runner_kv_caches: List[torch.Tensor] = [] - bind_kv_cache(ctx, runner_kv_caches, kv_cache) + bind_kv_cache(kv_cache, ctx, runner_kv_caches) assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[ 'layers.0.self_attn'] assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[ @@ -51,7 +51,7 @@ def test_bind_kv_cache_non_attention(): } runner_kv_caches: List[torch.Tensor] = [] - bind_kv_cache(ctx, runner_kv_caches, kv_cache) + bind_kv_cache(kv_cache, ctx, runner_kv_caches) assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[ 'model.layers.20.attn'] @@ -60,34 +60,3 @@ def test_bind_kv_cache_non_attention(): assert runner_kv_caches[0] is kv_cache['model.layers.20.attn'] assert runner_kv_caches[1] is kv_cache['model.layers.28.attn'] - - -def test_bind_kv_cache_encoder_decoder(): - from vllm.attention import Attention, AttentionType - - # example from bart - ctx = { - 'encoder.layers.0.self_attn.attn': - Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER), - 'decoder.layers.0.encoder_attn.attn': - Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER), - 'decoder.layers.0.self_attn.attn': - Attention(32, 128, 0.1, attn_type=AttentionType.DECODER), - } - - kv_cache_tensor = torch.zeros((1, )) - kv_cache = { - 'decoder.layers.0.encoder_attn.attn': kv_cache_tensor, - 'decoder.layers.0.self_attn.attn': kv_cache_tensor, - } - encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache - - runner_kv_caches: List[torch.Tensor] = [] - bind_kv_cache(ctx, runner_kv_caches, kv_cache) - assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache - assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[ - 'decoder.layers.0.encoder_attn.attn'] - assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[ - 'decoder.layers.0.self_attn.attn'] - - assert runner_kv_caches[0] is kv_cache_tensor diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 40ffce6f53cf..55e9799f7081 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -314,12 +314,14 @@ def 
check_enough_kv_cache_memory(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, available_memory: int): """ - Checks whether `available_memory` is enough for the KV cache of at least one - request with the model's max_model_len. + Checks whether `available_memory` is enough for the KV cache to hold at + least one request with the model's max_model_len. + Args: vllm_config: The global VllmConfig kv_cache_spec: The kv cache spec of the model available_memory (int): Memory available for KV cache in bytes. + Raises: ValueError: If there is not enough memory available for the KV cache. """ @@ -337,14 +339,14 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, if needed_memory > available_memory: raise ValueError( f"To serve at least one request with the models's max seq len " - f"({max_model_len}), ({needed_memory/1024/1024/1024} GB KV cache is" - f"needed, which is larger than the available KV Cache memory " - f"({available_memory/1024/1024/1024} GB). Try increasing " - f"`gpu_memory_utilization` or decreasing `max_model_len` when " - f"initializing the engine.") + f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GB KV " + f"cache is needed, which is larger than the available KV cache " + f"memory ({available_memory/1024/1024/1024:.2f} GB). Try " + f"increasing `gpu_memory_utilization` or decreasing " + f"`max_model_len` when initializing the engine.") -def is_same_type(kv_cache_spec: KVCacheSpec) -> bool: +def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool: """ Whether all layers in the given KVCacheSpec have the same type of KV cache. Args: @@ -357,7 +359,7 @@ def is_same_type(kv_cache_spec: KVCacheSpec) -> bool: return len(layer_keys) == 1 -def _get_kv_cache_config_same_type( +def _get_kv_cache_config_uniform_type( vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, available_memory: int) -> Tuple[KVCacheConfig, int]: """ @@ -376,34 +378,34 @@ def _get_kv_cache_config_same_type( assert len(page_sizes) == 1 page_size = page_sizes.pop() - num_gpu_blocks = int(available_memory // page_size // len(kv_cache_spec)) - num_gpu_blocks = max(num_gpu_blocks, 0) + num_blocks = int(available_memory // page_size // len(kv_cache_spec)) + num_blocks = max(num_blocks, 0) if vllm_config.cache_config.num_gpu_blocks_override is not None: num_gpu_blocks_override = \ vllm_config.cache_config.num_gpu_blocks_override logger.info( "Overriding num_gpu_blocks=%d with " - "num_gpu_blocks_override=%d", num_gpu_blocks, - num_gpu_blocks_override) - num_gpu_blocks = num_gpu_blocks_override + "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override) + num_blocks = num_gpu_blocks_override - logger.info("# GPU blocks: %d", num_gpu_blocks) + logger.info("# GPU blocks: %d", num_blocks) - per_layer_size = page_size * num_gpu_blocks + per_layer_size = page_size * num_blocks kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, tensors={ layer_name: KVCacheTensor(size=per_layer_size) for layer_name in kv_cache_spec }, groups=[[layer_name for layer_name in kv_cache_spec]], kv_cache_spec=kv_cache_spec) - return kv_cache_config, num_gpu_blocks + return kv_cache_config def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, - available_memory: int) -> Tuple[KVCacheConfig, int]: + available_memory: int) -> KVCacheConfig: """ Generates the KV cache configuration for a model TODO: support hybrid models with more than one type of KV cache. 
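As a standalone illustration of the sizing arithmetic that `_get_kv_cache_config_uniform_type` implements in the hunk above, the sketch below recomputes `num_blocks` for a made-up model. `SimpleFullAttentionSpec` and every concrete number in it are illustrative stand-ins, not vLLM's real `FullAttentionSpec` or its defaults.

# A minimal, self-contained sketch of the uniform-type sizing arithmetic.
# SimpleFullAttentionSpec and the concrete numbers are illustrative only.
from dataclasses import dataclass


@dataclass
class SimpleFullAttentionSpec:
    block_size: int    # tokens per KV cache block
    num_kv_heads: int
    head_size: int
    dtype_size: int    # bytes per element, e.g. 2 for fp16

    @property
    def page_size_bytes(self) -> int:
        # One page holds K and V for `block_size` tokens of one layer.
        return (2 * self.block_size * self.num_kv_heads * self.head_size
                * self.dtype_size)


# A hypothetical 32-layer model: 16-token blocks, 8 KV heads of size 128, fp16.
kv_cache_spec = {f"model.layers.{i}.self_attn.attn":
                 SimpleFullAttentionSpec(16, 8, 128, 2) for i in range(32)}
available_memory = 20 * 1024**3  # 20 GiB left for KV cache after profiling

page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
assert len(page_sizes) == 1  # uniform type -> a single page size
page_size = page_sizes.pop()

# Every layer gets the same number of blocks, so the budget is divided by
# the page size and by the number of layers.
num_blocks = int(available_memory // page_size // len(kv_cache_spec))
per_layer_size = page_size * num_blocks

print(f"page_size={page_size} B, num_blocks={num_blocks}, "
      f"per-layer tensor={per_layer_size / 1024**2:.0f} MiB")
# -> page_size=65536 B, num_blocks=10240, per-layer tensor=640 MiB

Because every layer receives the same number of blocks of the same page size, the block pool can be managed as a single set of 10240 blocks in this example, which is the property the single-group case above relies on.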
@@ -412,14 +414,13 @@ def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, kv_cache_spec: The kv cache spec of the model available_memory (int): Memory available for KV cache in bytes. Returns: - Tuple[KVCacheConfig, int]: The generated KVCacheConfig and the number of - GPU blocks. + The generated KVCacheConfig """ check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) - if is_same_type(kv_cache_spec): - # kv cache of all layers are the same, which is true for most models. + if is_kv_cache_type_uniform(kv_cache_spec): + # KV cache of all layers are the same, which is true for most models. # Allocate the same amount of memory for each layer. - return _get_kv_cache_config_same_type(vllm_config, kv_cache_spec, - available_memory) + return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, + available_memory) else: raise NotImplementedError diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 129b8df8df4f..9610480fb490 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -80,8 +80,9 @@ def _initialize_kv_caches(self, availble_gpu_memory = self.model_executor.determine_available_memory() # Get the kv cache tensor size - kv_cache_config, num_gpu_blocks = get_kv_cache_config( - vllm_config, kv_cache_spec, availble_gpu_memory) + kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, + availble_gpu_memory) + num_gpu_blocks = kv_cache_config.num_blocks num_cpu_blocks = 0 # Initialize kv cache and warmup the execution diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 659c16e3fc3d..a450a89f48e9 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -101,7 +101,7 @@ def initialize(self, kv_cache_config: KVCacheConfig) -> None: def determine_available_memory(self) -> int: """ - Determine the available memory for KV cache by invoking the + Determine the available memory (in bytes) for KV cacheby invoking the underlying worker. """ memory_sizes = self.collective_rpc("determine_available_memory") @@ -116,7 +116,7 @@ def get_kv_cache_spec(self) -> KVCacheSpec: Get all kv cache needed by the model by invoking the underlying worker. """ kv_cache_specs = self.collective_rpc("get_kv_cache_spec") - assert all(lc == kv_cache_specs[0] for lc in kv_cache_specs) + assert all(s == kv_cache_specs[0] for s in kv_cache_specs) return kv_cache_specs[0] def collective_rpc(self, diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index a7f4ada363b8..fd67fa223577 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -243,7 +243,7 @@ def get_kv_cache_spec(self) -> KVCacheSpec: they are identical. The KVCacheSpec is then returned. """ kv_cache_specs = self._run_workers("get_kv_cache_spec") - assert all(lc == kv_cache_specs[0] for lc in kv_cache_specs) + assert all(s == kv_cache_specs[0] for s in kv_cache_specs) return kv_cache_specs[0] def _run_workers( diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index 5f64ba5e00fc..83295e485489 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -51,8 +51,8 @@ def _create_worker( ) def determine_available_memory(self) -> int: - """Determine the available memory for KV cache by invoking the - underlying worker. + """Determine the available memory (in bytes) for KV cache by invoking + the underlying worker. 
""" return self.worker.determine_available_memory() diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 608c0a60a765..fbeef5186be6 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -87,6 +87,8 @@ class KVCacheConfig: """ The KV cache configuration of a model. """ + """The number of KV cache blocks""" + num_blocks: int """layer_name -> how to initialize KV cache for that layer""" tensors: Dict[str, KVCacheTensor] """ diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index a4201db7f56b..2d7b62c47244 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -144,32 +144,45 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str): def bind_kv_cache( - ctx: Dict[str, "Attention"], - runner_kv_caches: List[torch.Tensor], kv_caches: Dict[str, torch.Tensor], + forward_context: Dict[str, "Attention"], + runner_kv_caches: List[torch.Tensor], ) -> None: """ - Bind kv_caches to the forward context and model_runner's kv_cache. + Bind the allocated KV cache to both ModelRunner and forward context so + that the KV cache can be used in the forward pass. + + This function: + 1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with + kv_caches. + 2) Associates each attention layer in the `forward_context` with its + corresponding KV cache in kv_caches. + Args: - ctx: The global forward context containing all Attention layers with - layer names as keys. + forward_context: The global forward context containing all Attention + layers with layer names as keys. runner_kv_caches: The kv_cache declared by ModelRunner. kv_caches: The allocated kv_caches with layer names as keys. """ - # bind kv_caches to ModelRunner's kv_caches + # Bind kv_caches to ModelRunner assert len(runner_kv_caches) == 0 + + # Convert kv_caches dict to a list of tensors in the order of layer_index. index2name = defaultdict(list) for layer_name in kv_caches: index2name[extract_layer_index(layer_name)].append(layer_name) for layer_index in sorted(index2name.keys()): layer_names = index2name[layer_index] + if len(layer_names) > 1: + # One typical case is encoder-decoder model, e.g., bart. + # The cross attention and self attention in the same decoder layer + # has different layer_name but the same layer_index. + raise NotImplementedError layer_name = layer_names[0] - assert all(kv_caches[n] is kv_caches[layer_name] - for n in layer_names[1:]) runner_kv_caches.append(kv_caches[layer_name]) - # bind kv_caches to forward context + # Bind kv_caches to forward context for layer_name, kv_cache in kv_caches.items(): # NOTE: Use list because of v0 PP virtual engine. - ctx[layer_name].kv_cache = [kv_cache] + forward_context[layer_name].kv_cache = [kv_cache] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 21d23ca57d2b..82854e7706f5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -886,8 +886,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: raise NotImplementedError bind_kv_cache( + kv_caches, self.vllm_config.compilation_config.static_forward_context, - self.kv_caches, kv_caches) + self.kv_caches) def get_kv_cache_spec(self) -> KVCacheSpec: """ diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 81e9640b7b30..f477f7c384cb 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -117,7 +117,8 @@ def determine_available_memory(self) -> int: memory can be used for KV cache without OOMs. 
The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the free memory that can be used for KV cache + Then, it calculate the free memory that can be used for KV cache in + bytes. .. tip:: You may limit the usage of GPU memory From 62f2c0984430c37c85a3120afc814885bd12e900 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 16 Jan 2025 08:47:38 -0800 Subject: [PATCH 18/20] format Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 55e9799f7081..d7672f44f2f0 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -349,9 +349,11 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool: """ Whether all layers in the given KVCacheSpec have the same type of KV cache. + Args: kv_cache_spec (KVCacheSpec): The KVCacheSpec of the model - Returns: + + Returns: True if all layers have the same type, False otherwise. """ @@ -359,19 +361,20 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool: return len(layer_keys) == 1 -def _get_kv_cache_config_uniform_type( - vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, - available_memory: int) -> Tuple[KVCacheConfig, int]: +def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, + kv_cache_spec: KVCacheSpec, + available_memory: int) -> KVCacheConfig: """ Generates the KV cache configuration for a model with one type of KV cache. Divide the available memory equally among all layers. + Args: vllm_config: The global VllmConfig kv_cache_spec: The kv cache spec of the model available_memory (int): Memory available for KV cache in bytes. + Returns: - Tuple[KVCacheConfig, int]: The generated KVCacheConfig and the number of - GPU blocks. + The generated KVCacheConfig """ page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} @@ -409,10 +412,12 @@ def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, """ Generates the KV cache configuration for a model TODO: support hybrid models with more than one type of KV cache. + Args: vllm_config: The global VllmConfig kv_cache_spec: The kv cache spec of the model available_memory (int): Memory available for KV cache in bytes. + Returns: The generated KVCacheConfig """ From 138a4acdd83500429e0952b6a43c15cf5ec29648 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 16 Jan 2025 09:02:39 -0800 Subject: [PATCH 19/20] format Signed-off-by: Chen Zhang --- vllm/v1/core/kv_cache_utils.py | 10 +++++----- vllm/v1/executor/multiproc_executor.py | 2 +- vllm/v1/executor/uniproc_executor.py | 4 ++-- vllm/v1/kv_cache_interface.py | 3 +++ 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index d7672f44f2f0..bab99fe37cae 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -320,7 +320,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, Args: vllm_config: The global VllmConfig kv_cache_spec: The kv cache spec of the model - available_memory (int): Memory available for KV cache in bytes. + available_memory: Memory available for KV cache in bytes. Raises: ValueError: If there is not enough memory available for the KV cache. 
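The memory-sufficiency check documented in the hunk above boils down to the arithmetic below. This is a standalone sketch with made-up numbers, using `math.ceil` in place of vLLM's `cdiv` helper.

# Standalone sketch of check_enough_kv_cache_memory's test: can the budget
# hold one request of max_model_len tokens? All numbers are made up.
from math import ceil

block_size = 16                                   # tokens per block
page_size_bytes = 2 * block_size * 8 * 128 * 2    # K+V, 8 heads, dim 128, fp16
num_layers = 32
max_model_len = 32_768
available_memory = 20 * 1024**3                   # bytes left after profiling


def bytes_for_tokens(num_tokens: int) -> int:
    # Round up to whole blocks (what cdiv does), one page per block.
    return ceil(num_tokens / block_size) * page_size_bytes


needed_memory = num_layers * bytes_for_tokens(max_model_len)
print(f"needed={needed_memory / 1024**3:.2f} GiB, "
      f"available={available_memory / 1024**3:.2f} GiB")
# -> needed=4.00 GiB, available=20.00 GiB
if needed_memory > available_memory:
    raise ValueError("Not enough memory for a single max_model_len request; "
                     "raise gpu_memory_utilization or lower max_model_len.")

The real function sums `bytes_for_tokens` over the per-layer specs rather than multiplying by a layer count, which is what will matter once the specs stop being uniform across layers.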
@@ -351,9 +351,9 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool: Whether all layers in the given KVCacheSpec have the same type of KV cache. Args: - kv_cache_spec (KVCacheSpec): The KVCacheSpec of the model + kv_cache_spec: The KVCacheSpec of the model - Returns: + Returns: True if all layers have the same type, False otherwise. """ @@ -371,7 +371,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, Args: vllm_config: The global VllmConfig kv_cache_spec: The kv cache spec of the model - available_memory (int): Memory available for KV cache in bytes. + available_memory: Memory available for KV cache in bytes. Returns: The generated KVCacheConfig @@ -416,7 +416,7 @@ def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec, Args: vllm_config: The global VllmConfig kv_cache_spec: The kv cache spec of the model - available_memory (int): Memory available for KV cache in bytes. + available_memory: Memory available for KV cache in bytes. Returns: The generated KVCacheConfig diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index a450a89f48e9..c988469d772f 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -101,7 +101,7 @@ def initialize(self, kv_cache_config: KVCacheConfig) -> None: def determine_available_memory(self) -> int: """ - Determine the available memory (in bytes) for KV cacheby invoking the + Determine the available memory (in bytes) for KV cache by invoking the underlying worker. """ memory_sizes = self.collective_rpc("determine_available_memory") diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index 83295e485489..57c296f0d3b2 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -57,8 +57,8 @@ def determine_available_memory(self) -> int: return self.worker.determine_available_memory() def get_kv_cache_spec(self) -> KVCacheSpec: - """Get all kv cache needed by the model by invoking the - underlying worker. + """Get all kv cache needed by the model by invoking the underlying + worker. """ return self.worker.get_kv_cache_spec() diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index fbeef5186be6..6d5cc32ffc5b 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -26,6 +26,7 @@ def type_id(self) -> str: different number of tokens like full attention vs sliding window attention, different KV cache size per token like layers with different number of heads) + Returns: The type identifier of this KV cache. """ @@ -35,6 +36,7 @@ def type_id(self) -> str: def page_size_bytes(self) -> int: """ The size of a page with `block_size` tokens in bytes. + Returns: The page size """ @@ -44,6 +46,7 @@ def bytes_for_tokens(self, num_tokens: int) -> int: """ The KV cache size for `num_tokens` tokens in bytes. Returns the real memory size after padding `num_tokens` to full blocks. + Returns: The KV cache size """ From 2aa75096d347a9c5a57d4abe0427b89c65db1181 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 16 Jan 2025 09:19:20 -0800 Subject: [PATCH 20/20] update docstring Signed-off-by: Chen Zhang --- vllm/v1/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 2d7b62c47244..8dfcf2dd7860 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -159,10 +159,10 @@ def bind_kv_cache( corresponding KV cache in kv_caches. 
Args: + kv_caches: The allocated kv_caches with layer names as keys. forward_context: The global forward context containing all Attention layers with layer names as keys. runner_kv_caches: The kv_cache declared by ModelRunner. - kv_caches: The allocated kv_caches with layer names as keys. """ # Bind kv_caches to ModelRunner assert len(runner_kv_caches) == 0
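To close the loop on the final `bind_kv_cache` signature, here is a dependency-light sketch of its behaviour. The `Layer` stand-in and the regex-based `extract_layer_index` below are illustrative substitutes for vLLM's `Attention` module and layer-index utility, not the real implementations.

# Standalone sketch of what the final bind_kv_cache does.
import re
from collections import defaultdict
from typing import Dict, List

import torch


class Layer:
    """Stand-in for vllm.attention.layer.Attention."""
    kv_cache: List[torch.Tensor]


def extract_layer_index(layer_name: str) -> int:
    # "model.layers.3.self_attn.attn" -> 3
    return int(re.search(r"layers\.(\d+)\.", layer_name).group(1))


def bind_kv_cache_sketch(kv_caches: Dict[str, torch.Tensor],
                         forward_context: Dict[str, Layer],
                         runner_kv_caches: List[torch.Tensor]) -> None:
    assert len(runner_kv_caches) == 0
    # 1) Fill the runner's list in layer-index order.
    index2name = defaultdict(list)
    for layer_name in kv_caches:
        index2name[extract_layer_index(layer_name)].append(layer_name)
    for layer_index in sorted(index2name):
        layer_names = index2name[layer_index]
        if len(layer_names) > 1:
            # e.g. encoder-decoder models, where cross attention and self
            # attention in one decoder layer share a layer index.
            raise NotImplementedError
        runner_kv_caches.append(kv_caches[layer_names[0]])
    # 2) Attach each tensor to its attention layer in the forward context.
    for layer_name, kv_cache in kv_caches.items():
        # Wrapped in a list for compatibility with v0's PP virtual engines.
        forward_context[layer_name].kv_cache = [kv_cache]


forward_context = {f"model.layers.{i}.self_attn.attn": Layer()
                   for i in range(2)}
kv_caches = {name: torch.zeros(1) for name in forward_context}
runner_kv_caches: List[torch.Tensor] = []
bind_kv_cache_sketch(kv_caches, forward_context, runner_kv_caches)
assert runner_kv_caches[0] is kv_caches["model.layers.0.self_attn.attn"]
assert forward_context["model.layers.1.self_attn.attn"].kv_cache[0] \
    is kv_caches["model.layers.1.self_attn.attn"]

This mirrors how tests/v1/test_utils.py exercises the function, just with a stand-in layer object instead of constructing real Attention modules.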