diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 43a27da2dbe4..1e2767e2d198 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -19,8 +19,7 @@ hash_request_tokens, unify_kv_cache_configs) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheTensor, - SlidingWindowSpec) + KVCacheGroupSpec, KVCacheTensor) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -55,14 +54,12 @@ def new_kv_cache_spec(block_size=16, num_kv_heads=2, head_size=64, dtype=torch.float32, - use_mla=False, - sliding_window=None): + use_mla=False): return FullAttentionSpec(block_size=block_size, num_kv_heads=num_kv_heads, head_size=head_size, dtype=dtype, - use_mla=use_mla, - sliding_window=sliding_window) + use_mla=use_mla) def test_none_hash(monkeypatch): @@ -495,68 +492,6 @@ def test_unify_kv_cache_configs(): unify_kv_cache_configs(diff_kv_cache_config) -def test_merge_kv_cache_spec(): - same_layer_specs = [ - new_kv_cache_spec(num_kv_heads=32), - new_kv_cache_spec(num_kv_heads=32), - ] - merged_layer_spec = same_layer_specs[0].merge(same_layer_specs) - assert merged_layer_spec.block_size == 16 - assert merged_layer_spec.num_kv_heads == 32 - assert merged_layer_spec.head_size == 64 - assert merged_layer_spec.dtype == torch.float32 - assert merged_layer_spec.sliding_window is None - - different_layer_specs = [ - new_kv_cache_spec(num_kv_heads=32), - new_kv_cache_spec(num_kv_heads=16), - ] - with pytest.raises(AssertionError): - different_layer_specs[0].merge(different_layer_specs) - - full_spec = new_kv_cache_spec(num_kv_heads=32) - different_type_layer_specs = [ - full_spec, - SlidingWindowSpec( - block_size=full_spec.block_size, - num_kv_heads=full_spec.num_kv_heads, - head_size=full_spec.head_size, - dtype=full_spec.dtype, - use_mla=full_spec.use_mla, - sliding_window=1, - ), - ] - with pytest.raises(AssertionError): - different_type_layer_specs[0].merge(different_type_layer_specs) - with pytest.raises(AssertionError): - different_type_layer_specs[1].merge(different_type_layer_specs) - - different_sliding_window_layer_specs = [ - new_kv_cache_spec(num_kv_heads=32), - new_kv_cache_spec(num_kv_heads=32, sliding_window=1), - new_kv_cache_spec(num_kv_heads=32, sliding_window=2), - ] - with pytest.raises(ValueError): - different_sliding_window_layer_specs[0].merge( - different_sliding_window_layer_specs) - - same_sliding_window_layer_specs = [ - new_kv_cache_spec(num_kv_heads=32, sliding_window=1), - new_kv_cache_spec(num_kv_heads=32, sliding_window=1), - ] - merged_layer_spec = same_sliding_window_layer_specs[0].merge( - same_sliding_window_layer_specs) - assert merged_layer_spec.sliding_window == 1 - - same_sliding_window_layer_spec_with_none = [ - new_kv_cache_spec(num_kv_heads=32, sliding_window=1), - new_kv_cache_spec(num_kv_heads=32, sliding_window=None), - ] - merged_layer_spec = same_sliding_window_layer_spec_with_none[0].merge( - same_sliding_window_layer_spec_with_none) - assert merged_layer_spec.sliding_window == 1 - - @pytest.mark.parametrize( ("model_id", "max_model_len", "want_estimated_max_len"), [ ("Qwen/Qwen1.5-7B", 16385, 16384), diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 3da27786b1f2..2d7411381e16 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -84,7 +84,7 @@ def test_prefill(hash_algo): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == [1, 2, 3, 4] # Check full block metadata parent_block_hash = None @@ -107,13 +107,13 @@ def test_prefill(hash_algo): req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 - assert computed_blocks.get_block_ids() == [[1, 2, 3]] + assert computed_blocks.get_block_ids() == [1, 2, 3] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[5]] + assert blocks.get_block_ids() == [5] for block in computed_blocks.blocks: assert block.ref_cnt == 2 @@ -141,13 +141,13 @@ def test_prefill(hash_algo): req2 = make_request("2", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(manager.req_to_block_hashes[req2.request_id]) == 3 - assert computed_blocks.get_block_ids() == [[1, 2, 3]] + assert computed_blocks.get_block_ids() == [1, 2, 3] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req2, num_new_tokens, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[6]] + assert blocks.get_block_ids() == [6] # Although we only have 6 free blocks, we have 8 blocks in # the free block queue due to lazy removal. @@ -171,7 +171,7 @@ def test_prefill(hash_algo): len(computed_blocks.blocks) * 16, computed_blocks) # This block ID order also checks the eviction order. - assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]] + assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1] assert manager.block_pool.free_block_queue.num_free_blocks == 0 assert manager.block_pool.free_block_queue.free_list_head is None assert manager.block_pool.free_block_queue.free_list_tail is None @@ -208,7 +208,7 @@ def test_prefill_plp(): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == [1, 2, 3, 4] req0_block_hashes = [b.block_hash for b in blocks.blocks] # Check full block metadata @@ -233,13 +233,13 @@ def test_prefill_plp(): req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 - assert computed_blocks.get_block_ids() == [[1, 2, 3]] + assert computed_blocks.get_block_ids() == [1, 2, 3] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[5]] + assert blocks.get_block_ids() == [5] for block in computed_blocks.blocks: assert block.ref_cnt == 2 @@ -277,11 +277,11 @@ def test_prefill_plp(): block_ids = blocks.get_block_ids() # Duplicate cached blocks have different ids but same hashes vs request #0 assert [b.block_hash for b in blocks.blocks] == req0_block_hashes - assert block_ids != [[1, 2, 3, 4]] + assert block_ids != [1, 2, 3, 4] # Request #2 block hashes are valid since request #0 hashes are. # Check block reference counts. - for block_id in block_ids[0]: + for block_id in block_ids: assert manager.block_pool.blocks[block_id].ref_cnt == 1 manager.free(req2) @@ -307,7 +307,7 @@ def test_decode(): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == [1, 2, 3, 4] # Append slots without allocating a new block. req0.num_computed_tokens = 55 @@ -379,12 +379,12 @@ def test_evict(): # Touch the first 2 blocks. req2 = make_request("2", list(range(2 * 16 + 3))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert computed_blocks.get_block_ids() == [[1, 2]] + assert computed_blocks.get_block_ids() == [1, 2] assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots(req2, 3, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[10]] + assert blocks.get_block_ids() == [10] assert manager.block_pool.free_block_queue.num_free_blocks == 7 @@ -625,7 +625,7 @@ def test_mm_prefix_caching(): blocks = manager.allocate_slots(req0, 59, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == [1, 2, 3, 4] req0.num_computed_tokens = 59 # Append slots without allocating a new block. @@ -686,7 +686,7 @@ def test_cache_key_salting(): blocks = manager.allocate_slots(req0, 59, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == [1, 2, 3, 4] req0.num_computed_tokens = 59 # Append slots without allocating a new block. @@ -797,7 +797,7 @@ def test_reset_prefix_cache(): all_token_ids = full_block_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) blocks = manager.allocate_slots(req0, 55) - assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == [1, 2, 3, 4] unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids @@ -808,7 +808,7 @@ def test_reset_prefix_cache(): blocks = manager.allocate_slots(req1, 7, len(computed_blocks.blocks) * 16, computed_blocks) - assert blocks.get_block_ids() == [[5]] + assert blocks.get_block_ids() == [5] # Failed to reset prefix cache because some blocks are not freed yet. assert not manager.reset_prefix_cache() diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 638f5bedcfca..7b1359c8576f 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -9,11 +9,9 @@ from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available, make_tensor_with_pad -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheTensor) from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable -from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState, + InputBatch) VOCAB_SIZE = 1024 NUM_OUTPUT_TOKENS = 20 @@ -24,27 +22,6 @@ MAX_NUM_PROMPT_TOKENS = 64 -def get_kv_cache_config() -> KVCacheConfig: - return KVCacheConfig( - num_blocks=10, - tensors={ - "layer.0": KVCacheTensor(size=1024), - }, - kv_cache_groups=[ - KVCacheGroupSpec( - layer_names=["layer.0"], - kv_cache_spec=FullAttentionSpec( - block_size=1, - num_kv_heads=1, - head_size=16, - dtype=torch.float16, - use_mla=False, - ), - ), - ], - ) - - def _compare_objs(obj1, obj2): attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a))) attr_names = set([ @@ -64,10 +41,6 @@ def _compare_objs(obj1, obj2): elif isinstance(a, np.ndarray): if np.allclose(a, b): is_same = True - elif isinstance(a, MultiGroupBlockTable): - for a_i, b_i in zip(a.block_tables, b.block_tables): - _compare_objs(a_i, b_i) - is_same = True elif isinstance(a, (BlockTable, SamplingMetadata)): _compare_objs(a, b) is_same = True # if we make it here must be same @@ -225,7 +198,7 @@ def _construct_cached_request_state(req_id_suffix: int): sampling_params=_create_sampling_params(), mm_inputs=[], mm_positions=[], - block_ids=[[]], + block_ids=[], generator=None, num_computed_tokens=len(output_token_ids), output_token_ids=output_token_ids, @@ -247,11 +220,11 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): input_batch: InputBatch = InputBatch( max_num_reqs=batch_size, max_model_len=1024, + max_num_blocks_per_req=10, max_num_batched_tokens=1024, device=torch.device(device), pin_memory=is_pin_memory_available(), vocab_size=1024, - kv_cache_config=get_kv_cache_config(), ) reqs: list[CachedRequestState] = [] req_id_reqs = {} @@ -337,20 +310,20 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, input_batch: InputBatch = InputBatch( max_num_reqs=batch_size, max_model_len=1024, + max_num_blocks_per_req=10, max_num_batched_tokens=1024, device=torch.device(device), pin_memory=is_pin_memory_available(), vocab_size=1024, - kv_cache_config=get_kv_cache_config(), ) ref_input_batch: InputBatch = InputBatch( max_num_reqs=batch_size, max_model_len=1024, + max_num_blocks_per_req=10, max_num_batched_tokens=1024, device=torch.device(device), pin_memory=is_pin_memory_available(), vocab_size=1024, - kv_cache_config=get_kv_cache_config(), ) reqs: list[CachedRequestState] = [] diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index e44660525763..725747294fd8 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -1,16 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 +import weakref import pytest +import torch -from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig, VllmConfig) +from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.sampling_params import SamplingParams from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheTensor) +from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -18,34 +17,13 @@ def initialize_kv_cache(runner: GPUModelRunner): """ Only perform necessary steps in GPUModelRunner.initialize_kv_cache() """ - kv_cache_config = KVCacheConfig( - num_blocks=10, - tensors={ - "layer.0": KVCacheTensor(size=1024), - }, - kv_cache_groups=[ - KVCacheGroupSpec( - layer_names=["layer.0"], - kv_cache_spec=FullAttentionSpec( - block_size=16, - num_kv_heads=runner.model_config.get_num_kv_heads( - runner.parallel_config), - head_size=runner.model_config.get_head_size(), - dtype=runner.kv_cache_dtype, - use_mla=False, - )) - ]) - runner.kv_cache_config = kv_cache_config - runner.input_batch = InputBatch( - max_num_reqs=runner.max_num_reqs, - max_model_len=runner.max_model_len, - max_num_batched_tokens=runner.max_num_tokens, - device=runner.device, - pin_memory=runner.pin_memory, - vocab_size=runner.model_config.get_vocab_size(), - kv_cache_config=kv_cache_config, - ) - runner.initialize_attn_backend(kv_cache_config) + kv_cache_spec = FullAttentionSpec(block_size=16, + num_kv_heads=1, + head_size=64, + dtype=torch.float16, + use_mla=False) + runner.attn_metadata_builder = runner.attn_backend.get_builder_cls()( + weakref.proxy(runner), kv_cache_spec, runner.input_batch.block_table) @pytest.fixture @@ -70,12 +48,10 @@ def model_runner(): swap_space=0, cache_dtype="auto", ) - parallel_config = ParallelConfig() vllm_config = VllmConfig( model_config=model_config, cache_config=cache_config, scheduler_config=scheduler_config, - parallel_config=parallel_config, ) device = "cuda" @@ -97,7 +73,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: mm_hashes=[], mm_positions=[], sampling_params=SamplingParams(), - block_ids=[[0]], + block_ids=[0], num_computed_tokens=0, lora_request=None, )) @@ -135,14 +111,13 @@ def _is_sampling_metadata_changed(model_runner, def _is_req_state_block_table_match(model_runner, req_id: str) -> bool: req_index = model_runner.input_batch.req_id_to_index[req_id] - block_table = model_runner.input_batch.block_table[0] + block_table = model_runner.input_batch.block_table req_state = model_runner.requests[req_id] - if block_table.num_blocks_per_row[req_index] != len( - req_state.block_ids[0]): + if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids): return False num_blocks = block_table.num_blocks_per_row[req_index] return (block_table.block_table_np[req_index, :num_blocks] == - req_state.block_ids[0]).all() + req_state.block_ids).all() def test_update_states_new_request(model_runner): @@ -225,7 +200,7 @@ def test_update_states_request_resumed(model_runner): req_id=req_id, resumed_from_preemption=False, new_token_ids=[], - new_block_ids=[[]], + new_block_ids=[], num_computed_tokens=0, ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index 0421a65a2c81..0fedb6fd5ed9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -288,7 +288,7 @@ def build_connector_meta( for new_req in scheduler_output.scheduled_new_reqs: if new_req.req_id in self._requests_need_load: meta.add_request(token_ids=new_req.prompt_token_ids, - block_ids=new_req.block_ids[0], + block_ids=new_req.block_ids, block_size=self._block_size, is_store=False) total_need_load += 1 @@ -299,7 +299,7 @@ def build_connector_meta( # the original prompt tokens. if not self._found_match_for_request(new_req): meta.add_request(token_ids=new_req.prompt_token_ids, - block_ids=new_req.block_ids[0], + block_ids=new_req.block_ids, block_size=self._block_size, is_store=True) @@ -319,7 +319,7 @@ def build_connector_meta( # NOTE(rob): For resumed req, new_block_ids is all # of the block_ids for the request. - block_ids = cached_req.new_block_ids[0] + block_ids = cached_req.new_block_ids meta.add_request(token_ids=token_ids, block_ids=block_ids, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 7ce39110ac01..3abb185c5b8f 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -67,13 +67,13 @@ def __init__(self, runner, kv_cache_spec: AttentionSpec, max_model_len = self.runner.model_config.max_model_len assert max_model_len == 32768,\ "AITER MLA requires max_model_len=32768" - assert self.kv_cache_spec.block_size == 1, "AITER MLA" \ + assert self.runner.block_size == 1, "AITER MLA" \ "only supports block size 1." def _get_paged_kv_tensors( self, block_table: torch.Tensor, seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]: - page_size = self.kv_cache_spec.block_size + page_size = self.runner.block_size block_table_bounds = (seq_lens + page_size - 1) // page_size mask = (torch.arange(block_table.size(1), diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index da18ece7555a..598fc871110e 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -32,16 +32,9 @@ def create_empty(cls) -> "KVCacheBlocks": """Creates a new KVCacheBlocks instance with no blocks.""" return cls([]) - def get_block_ids(self) -> list[list[int]]: - """ - Converts the KVCacheBlocks instance to block_ids. - - Returns: - list[list[int]]: A two-level list where - * the outer list corresponds to KV cache groups (only 1 group now) - * each inner list contains the block_ids of the blocks in that group - """ - return [[block.block_id for block in self.blocks]] + def get_block_ids(self) -> list[int]: + """Converts the KVCacheBlocks instance to a list of block IDs.""" + return [block.block_id for block in self.blocks] def get_unhashed_block_ids(self) -> list[int]: """Get block_ids of unhashed blocks from KVCacheBlocks instance.""" @@ -307,9 +300,9 @@ def get_num_common_prefix_blocks( self, request: Request, num_running_requests: int, - ) -> list[int]: + ) -> int: """Calculate the number of common prefix blocks shared by all requests - in the RUNNING state for each kv cache group. + in the RUNNING state. The function determines this by selecting any request and iterating through its blocks. A block is considered a common prefix block if its @@ -339,14 +332,11 @@ def get_num_common_prefix_blocks( requests in the current step. Returns: - list[int]: The number of common prefix blocks for each kv cache - group. + int: The number of common prefix blocks. """ assert request.status == RequestStatus.RUNNING - return [ - self.single_type_manager.get_num_common_prefix_blocks( - request.request_id, num_running_requests) - ] + return self.single_type_manager.get_num_common_prefix_blocks( + request.request_id, num_running_requests) def free_block_hashes(self, request: Request) -> None: """Discard the block hashes for the request. @@ -364,8 +354,10 @@ def take_events(self) -> list[KVCacheEvent]: """ return self.block_pool.take_events() - def get_block_ids(self, request_id: str) -> list[list[int]]: + def get_block_ids(self, request_id: str) -> list[int]: """Get the block ids of a request.""" assert request_id in self.single_type_manager.req_to_blocks - return KVCacheBlocks(self.single_type_manager.req_to_blocks[request_id] - ).get_block_ids() + return [ + block.block_id + for block in self.single_type_manager.req_to_blocks[request_id] + ] diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 403b5401be75..27c515835087 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -577,12 +577,14 @@ def create_kv_cache_group_specs( """ kv_cache_groups = [] for layer_names_one_group in grouped_layer_names: - layer_specs = [ - kv_cache_spec[layer_name] for layer_name in layer_names_one_group - ] - merged_layer_spec = layer_specs[0].merge(layer_specs) + layer_spec = kv_cache_spec[layer_names_one_group[0]] + assert all( + kv_cache_spec[layer_name] == layer_spec + for layer_name in layer_names_one_group[1:]), ( + "All layers in the same KV cache group must share the same " + "KVCacheSpec.") kv_cache_groups.append( - KVCacheGroupSpec(layer_names_one_group, merged_layer_spec)) + KVCacheGroupSpec(layer_names_one_group, layer_spec)) return kv_cache_groups @@ -681,7 +683,6 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): head_size=spec.head_size, dtype=spec.dtype, use_mla=spec.use_mla, - sliding_window=spec.sliding_window, ) diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 257234430983..24032498e50b 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -26,7 +26,7 @@ class NewRequestData: mm_hashes: list[str] mm_positions: list[PlaceholderRange] sampling_params: SamplingParams - block_ids: list[list[int]] + block_ids: list[int] num_computed_tokens: int lora_request: Optional[LoRARequest] @@ -34,7 +34,7 @@ class NewRequestData: def from_request( cls, request: Request, - block_ids: list[list[int]], + block_ids: list[int], ) -> NewRequestData: return cls( req_id=request.request_id, @@ -85,7 +85,7 @@ class CachedRequestData: # request's block IDs instead of appending to the existing block IDs. resumed_from_preemption: bool new_token_ids: list[int] - new_block_ids: list[list[int]] + new_block_ids: list[int] num_computed_tokens: int @classmethod @@ -94,7 +94,7 @@ def from_request( request: Request, resumed_from_preemption: bool, new_token_ids: list[int], - new_block_ids: list[list[int]], + new_block_ids: list[int], ) -> CachedRequestData: return cls( req_id=request.request_id, @@ -131,9 +131,9 @@ class SchedulerOutput: # E.g., if a request has [0, 1], it could mean the vision encoder needs # to process that the request's 0-th and 1-th images in the current step. scheduled_encoder_inputs: dict[str, list[int]] - # Number of common prefix blocks for all requests in each KV cache group. + # Number of common prefix blocks for all requests. # This can be used for cascade attention. - num_common_prefix_blocks: list[int] + num_common_prefix_blocks: int # Request IDs that are finished in between the previous and the current # steps. This is used to notify the workers about the finished requests diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d8fd67e232cb..2152409019b9 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -173,7 +173,7 @@ def schedule(self) -> SchedulerOutput: # uses structured decoding. structured_output_request_ids: dict[str, int] = {} - req_to_new_block_ids: dict[str, list[list[int]]] = {} + req_to_new_block_ids: dict[str, list[int]] = {} num_scheduled_tokens: dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. @@ -484,8 +484,7 @@ def schedule(self) -> SchedulerOutput: # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. - num_common_prefix_blocks = [0] * len( - self.kv_cache_config.kv_cache_groups) + num_common_prefix_blocks = 0 if self.running: any_request = self.running[0] num_common_prefix_blocks = ( @@ -572,7 +571,7 @@ def _make_cached_request_data( request: Request, num_scheduled_tokens: int, num_scheduled_spec_tokens: int, - new_block_ids: list[list[int]], + new_block_ids: list[int], resumed_from_preemption: bool, ) -> CachedRequestData: # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating @@ -947,9 +946,7 @@ def _connector_finished( """ if self.connector is None: return False, None - assert len(self.kv_cache_config.kv_cache_groups - ) == 1, "KV connector only supports one KV cache group now" - block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0] + block_ids = self.kv_cache_manager.get_block_ids(request.request_id) return self.connector.request_finished(request, block_ids) def _update_waiting_for_remote_kv(self, request: Request) -> bool: @@ -966,10 +963,9 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool: """ if request.request_id not in self.finished_recving_kv_req_ids: return False - assert len(self.kv_cache_config.kv_cache_groups - ) == 1, "KV connector only supports one KV cache group now" + # Now that the blocks are ready, actually cache them. - block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0] + block_ids = self.kv_cache_manager.get_block_ids(request.request_id) num_computed_tokens = len(block_ids) * self.block_size if num_computed_tokens == request.num_tokens: num_computed_tokens -= 1 diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 2747fc7fabd1..4fc0844cd1f4 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -1,11 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 -import copy from dataclasses import dataclass -from typing import Optional import torch -from typing_extensions import Self from vllm.config import VllmConfig from vllm.logger import init_logger @@ -56,16 +53,6 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: """ raise NotImplementedError - @classmethod - def merge(cls, specs: list[Self]) -> Self: - """ - Merge a list of KVCacheSpec objects into a single KVCacheSpec object. - """ - assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), ( - "All layers in the same KV cache group must share the same " - "type_id.") - return copy.deepcopy(specs[0]) - @dataclass class AttentionSpec(KVCacheSpec): @@ -84,16 +71,6 @@ def page_size_bytes(self) -> int: @dataclass class FullAttentionSpec(AttentionSpec): - sliding_window: Optional[int] = None - """ - When hybrid allocator is disabled and the model contains both full - attention layers and sliding window attention layers, sliding - window attention are regarded as full attention in KV cache manager - (blocks are allocated for all tokens), while computed as sliding window - attention in model runner. - In this case, we use FullAttentionSpec and record the sliding window size. - Default to None for not using sliding window attention. - """ @property def type_id(self) -> str: @@ -103,25 +80,6 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len return cdiv(max_model_len, self.block_size) * self.page_size_bytes - @classmethod - def merge(cls, specs: list[Self]) -> Self: - """ - Merge a list of FullAttentionSpec objects into a single - FullAttentionSpec object. - """ - merged_spec = super().merge(specs) - sliding_window = set(spec.sliding_window for spec in specs - if spec.sliding_window is not None) - if len(sliding_window) == 0: - merged_spec.sliding_window = None - elif len(sliding_window) == 1: - merged_spec.sliding_window = sliding_window.pop() - else: - raise ValueError( - "All sliding window layers in the same KV cache group " - "must have the same window size.") - return merged_spec - @dataclass class SlidingWindowSpec(AttentionSpec): diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 0c3341691509..581d3d9bd11b 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -4,8 +4,6 @@ import torch from vllm.logger import init_logger -from vllm.utils import cdiv -from vllm.v1.kv_cache_interface import KVCacheConfig logger = init_logger(__name__) @@ -98,48 +96,3 @@ def get_cpu_tensor(self) -> torch.Tensor: def get_numpy_array(self) -> np.ndarray: """Returns the numpy array of the block table.""" return self.block_table_np - - -class MultiGroupBlockTable: - """The BlockTables for each KV cache group.""" - - def __init__(self, max_num_reqs: int, max_model_len: int, - max_num_batched_tokens: int, pin_memory: bool, - device: torch.device, kv_cache_config: KVCacheConfig) -> None: - max_num_blocks_per_req = [ - cdiv(max_model_len, g.kv_cache_spec.block_size) - for g in kv_cache_config.kv_cache_groups - ] - self.block_tables = [ - BlockTable(max_num_reqs, max_num_blocks_per_req[i], - max_num_batched_tokens, pin_memory, device) - for i in range(len(kv_cache_config.kv_cache_groups)) - ] - - def append_row(self, block_ids: list[list[int]], row_idx: int) -> None: - for i, block_table in enumerate(self.block_tables): - block_table.append_row(block_ids[i], row_idx) - - def add_row(self, block_ids: list[list[int]], row_idx: int) -> None: - for i, block_table in enumerate(self.block_tables): - block_table.add_row(block_ids[i], row_idx) - - def move_row(self, src: int, tgt: int) -> None: - for block_table in self.block_tables: - block_table.move_row(src, tgt) - - def swap_row(self, src: int, tgt: int) -> None: - for block_table in self.block_tables: - block_table.swap_row(src, tgt) - - def commit(self, num_reqs: int) -> None: - for block_table in self.block_tables: - block_table.commit(num_reqs) - - def clear(self) -> None: - for block_table in self.block_tables: - block_table.clear() - - def __getitem__(self, idx: int) -> "BlockTable": - """Returns the BlockTable for the i-th KV cache group.""" - return self.block_tables[idx] diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 570de9bddd29..871654fca366 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -11,11 +11,10 @@ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import swap_dict_values -from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import LogprobsTensors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import copy_slice -from vllm.v1.worker.block_table import MultiGroupBlockTable +from vllm.v1.worker.block_table import BlockTable _SAMPLING_EPS = 1e-5 @@ -30,7 +29,7 @@ class CachedRequestState: sampling_params: SamplingParams generator: Optional[torch.Generator] - block_ids: list[list[int]] + block_ids: list[int] num_computed_tokens: int output_token_ids: list[int] @@ -59,14 +58,15 @@ def __init__( self, max_num_reqs: int, max_model_len: int, + max_num_blocks_per_req: int, max_num_batched_tokens: int, device: torch.device, pin_memory: bool, vocab_size: int, - kv_cache_config: KVCacheConfig, ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len + self.max_num_blocks_per_req = max_num_blocks_per_req self.max_num_batched_tokens = max_num_batched_tokens self.device = device self.pin_memory = pin_memory @@ -99,13 +99,12 @@ def __init__( self.num_computed_tokens_cpu_tensor.numpy() # Block table. - self.block_table = MultiGroupBlockTable( + self.block_table = BlockTable( max_num_reqs=max_num_reqs, - max_model_len=max_model_len, + max_num_blocks_per_req=max_num_blocks_per_req, max_num_batched_tokens=max_num_batched_tokens, pin_memory=pin_memory, device=device, - kv_cache_config=kv_cache_config, ) # Sampling-related. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 201796c96ee5..e26f97d816ae 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -12,8 +12,6 @@ import torch.nn as nn from vllm.attention import AttentionType, get_attn_backend -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadataBuilder) from vllm.attention.layer import Attention from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.config import (CompilationLevel, VllmConfig, @@ -34,8 +32,8 @@ from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LazyLoader, cdiv, check_use_alibi, - is_pin_memory_available) + GiB_bytes, LayerBlockType, LazyLoader, cdiv, + check_use_alibi, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget @@ -53,7 +51,6 @@ from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.utils import bind_kv_cache -from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -105,17 +102,59 @@ def __init__( self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ cache_config.cache_dtype] + # NOTE(woosuk): sliding_window is None for models with interleaved + # attention. Use interleaved_sliding_window instead. + self.sliding_window = model_config.get_sliding_window() + self.interleaved_sliding_window = getattr( + model_config.hf_text_config, "interleaved_sliding_window", None) + self.window_size = (self.sliding_window + or self.interleaved_sliding_window) + self.is_multimodal_model = model_config.is_multimodal_model + self.block_size = cache_config.block_size self.max_model_len = model_config.max_model_len + self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs # Model-related. + self.num_attn_layers = model_config.get_num_layers_by_block_type( + parallel_config, LayerBlockType.attention) self.num_query_heads = model_config.get_num_attention_heads( parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() self.attention_chunk_size = model_config.attention_chunk_size + self.attn_backend = get_attn_backend( + self.head_size, + self.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + use_mla=self.model_config.use_mla, + ) + if self.attn_backend is None: + error_msg = ( + f"Error with get_att_backend: {self.head_size=}, " + f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, " + f"{self.model_config.is_attention_free=}, " + f"{self.model_config.use_mla=}") + logger.error(error_msg) + raise NotImplementedError( + "Non-Attention backend is not supported by V1 GPUModelRunner.") + + if self.vllm_config.compilation_config.full_cuda_graph: + attn_backend_name = self.attn_backend.__name__ + flash_attn_version = get_flash_attn_version() + if attn_backend_name != "FlashAttentionBackend" or \ + flash_attn_version != 3: + raise ValueError( + f"full_cuda_graph is only supported with " + f"FA3. Current attention backend is {attn_backend_name}, " + f"FlashAttention version is {flash_attn_version}.") + self.cascade_attn_enabled = not self.model_config.disable_cascade_attn # Multi-modal data support @@ -137,10 +176,8 @@ def __init__( # self.model: nn.Module # Set after load_model # Initialize in initialize_kv_cache self.kv_caches: list[torch.Tensor] = [] - self.attn_metadata_builders: list[AttentionMetadataBuilder] = [] - self.attn_backends: list[type[AttentionBackend]] = [] # self.kv_cache_config: KVCacheConfig - # self.input_batch: InputBatch # Persistent batch. + # self.attn_metadata_builder: type[AttentionMetadataBuilder] # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} @@ -169,6 +206,16 @@ def __init__( # Request states. self.requests: dict[str, CachedRequestState] = {} + # Persistent batch. + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=model_config.get_vocab_size(), + ) self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE @@ -263,31 +310,6 @@ def __init__( pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() - def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: - """ - Update the order of requests in the batch based on the attention - backend's needs. For example, some attention backends (namely MLA) may - want to separate requests based on if the attention computation will be - compute-bound or memory-bound. - - Args: - scheduler_output: The scheduler output. - - Returns: - True if the batch was reordered, False otherwise. - """ - batch_reordered = self.attn_metadata_builders[0].reorder_batch( - self.input_batch, scheduler_output) - - # For models with multiple KV cache groups, the groups should agree on - # the same order of requests. We ensure this by only allowing the first - # group to reorder the batch and asserting that all other groups do not - # reorder the batch. - for i in range(1, len(self.kv_cache_config.kv_cache_groups)): - assert not self.attn_metadata_builders[i].reorder_batch( - self.input_batch, scheduler_output) - return batch_reordered - def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output. @@ -424,8 +446,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the block IDs. if not req_data.resumed_from_preemption: # Append the new blocks to the existing block IDs. - for i in range(len(self.kv_cache_config.kv_cache_groups)): - req_state.block_ids[i].extend(req_data.new_block_ids[i]) + req_state.block_ids.extend(req_data.new_block_ids) else: # The request is resumed from preemption. # Replace the existing block IDs with the new ones. @@ -483,7 +504,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: if removed_req_indices: self.input_batch.condense(removed_req_indices) - batch_reordered = self._may_reorder_batch(scheduler_output) + # Some attention backends (namely MLA) may want to separate requests + # based on if the attention computation will be compute-bound or + # memory-bound. This gives them a hook to do that. + batch_reordered = self.attn_metadata_builder.reorder_batch( + self.input_batch, scheduler_output) if batch_changed or batch_reordered: self.input_batch.refresh_sampling_metadata() @@ -551,29 +576,21 @@ def _prepare_inputs( torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) - # Calculate the slot mapping for each KV cache group. - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): - block_size = kv_cache_group_spec.kv_cache_spec.block_size - block_table: BlockTable = self.input_batch.block_table[ - kv_cache_group_id] - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` - # here because M (max_model_len) is not necessarily divisible by - # block_size. - block_table_indices = ( - req_indices * block_table.max_num_blocks_per_req + - positions_np // block_size) - block_table_cpu = block_table.get_cpu_tensor() - block_numbers = block_table_cpu.flatten( - )[block_table_indices].numpy() - block_offsets = positions_np % block_size - np.add( - block_numbers * block_size, - block_offsets, - out=block_table.slot_mapping_np[:total_num_scheduled_tokens]) + # Calculate the slot mapping. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` here + # because M (max_model_len) is not necessarily divisible by block_size. + block_table_indices = (req_indices * self.max_num_blocks_per_req + + positions_np // self.block_size) + block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() + block_offsets = positions_np % self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.input_batch.block_table. + slot_mapping_np[:total_num_scheduled_tokens]) # Prepare the attention metadata. self.query_start_loc_np[0] = 0 @@ -615,6 +632,10 @@ def _prepare_inputs( attn_metadata: dict[str, FlashAttentionMetadata] = {} # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. + # NOTE(Chen): there is exactly one KV cache group that contains all + # attetnion layers in the model for now, so the current logic for + # getting attn_metadata is not related to kv_cache_group information. + # Will extend this part to support multiple KV cache groups later. for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): @@ -623,19 +644,15 @@ def _prepare_inputs( if self.cascade_attn_enabled: common_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, - scheduler_output. - num_common_prefix_blocks[kv_cache_group_id], - kv_cache_group_spec.kv_cache_spec, - self.attn_metadata_builders[kv_cache_group_id], + scheduler_output.num_common_prefix_blocks, ) - attn_metadata_i = ( - self.attn_metadata_builders[kv_cache_group_id].build( - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata)) + attn_metadata_i = self.attn_metadata_builder.build( + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -673,8 +690,6 @@ def _compute_cascade_attn_prefix_len( self, num_scheduled_tokens: np.ndarray, num_common_prefix_blocks: int, - kv_cache_spec: KVCacheSpec, - attn_metadata_builder: AttentionMetadataBuilder, ) -> int: """Compute the length of the common prefix for cascade attention. @@ -693,7 +708,7 @@ def _compute_cascade_attn_prefix_len( Returns: int: Length of common prefix in tokens. """ - common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size + common_prefix_len = num_common_prefix_blocks * self.block_size if common_prefix_len == 0: # Common case. return 0 @@ -742,19 +757,15 @@ def _compute_cascade_attn_prefix_len( common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) # common_prefix_len should be a multiple of the block size. - common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * - kv_cache_spec.block_size) - use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or - (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.sliding_window is not None)) - assert isinstance(kv_cache_spec, AttentionSpec) - use_cascade = attn_metadata_builder.use_cascade_attention( + common_prefix_len = (common_prefix_len // self.block_size * + self.block_size) + use_cascade = self.attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, num_query_heads=self.num_query_heads, - num_kv_heads=kv_cache_spec.num_kv_heads, + num_kv_heads=self.num_kv_heads, use_alibi=self.use_alibi, - use_sliding_window=use_sliding_window, + use_sliding_window=self.window_size is not None, num_sms=self.num_sms, ) return common_prefix_len if use_cascade else 0 @@ -1640,7 +1651,7 @@ def _dummy_run( dtype=np.int32) if skip_attn: - attn_metadata: Optional[dict[str, FlashAttentionMetadata]] = None + attn_metadata = None else: query_start_loc = self.query_start_loc[:num_reqs + 1] seq_lens = self.seq_lens[:num_reqs] @@ -1648,19 +1659,13 @@ def _dummy_run( common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, seq_lens=seq_lens) - attn_metadata = {} - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): - attn_metadata_i = ( - self.attn_metadata_builders[kv_cache_group_id].build( - num_reqs=num_tokens, - num_actual_tokens=num_tokens, - max_query_len=num_tokens, - common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - )) - for layer_name in kv_cache_group_spec.layer_names: - attn_metadata[layer_name] = attn_metadata_i + attn_metadata = self.attn_metadata_builder.build( + num_reqs=num_tokens, + num_actual_tokens=num_tokens, + max_query_len=num_tokens, + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): @@ -1890,56 +1895,6 @@ def capture_model(self) -> None: logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) - def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: - """ - Initialize the attention backends and attention metadata builders. - """ - assert len(self.attn_backends) == 0 and len( - self.attn_metadata_builders - ) == 0, "Attention backends are already initialized" - for i, kv_cache_group_spec in enumerate( - kv_cache_config.kv_cache_groups): - kv_cache_spec = kv_cache_group_spec.kv_cache_spec - if not isinstance(kv_cache_spec, AttentionSpec): - raise NotImplementedError( - "Only AttentionSpec is supported for now.") - attn_backend_i = get_attn_backend( - kv_cache_spec.head_size, - self.dtype, - kv_cache_spec.dtype, - kv_cache_spec.block_size, - self.model_config.is_attention_free, - use_mla=kv_cache_spec.use_mla, - ) - if attn_backend_i is None: - error_msg = ( - f"Error with get_attn_backend: {kv_cache_spec.head_size=}, " - f"{self.dtype=}, {kv_cache_spec.dtype=}, " - f"{kv_cache_spec.block_size=}, " - f"{self.model_config.is_attention_free=}, " - f"{kv_cache_spec.use_mla=}") - logger.error(error_msg) - raise NotImplementedError( - "Non-Attention backend is not supported by V1 " - "GPUModelRunner.") - - if self.vllm_config.compilation_config.full_cuda_graph: - attn_backend_name = attn_backend_i.__name__ - flash_attn_version = get_flash_attn_version() - if attn_backend_name != "FlashAttentionBackend" or \ - flash_attn_version != 3: - raise ValueError( - f"full_cuda_graph is only supported with " - f"FA3. Current attention backend is " - f"{attn_backend_name}, FlashAttention version is " - f"{flash_attn_version}.") - - block_table_i = self.input_batch.block_table[i] - attn_metadata_builder_i = attn_backend_i.get_builder_cls()( - weakref.proxy(self), kv_cache_spec, block_table_i) - self.attn_backends.append(attn_backend_i) - self.attn_metadata_builders.append(attn_metadata_builder_i) - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. @@ -1947,21 +1902,15 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ + if len(kv_cache_config.kv_cache_groups) > 1: + raise NotImplementedError( + "Hybrid models with more than one KV cache type are not " + "supported yet.") self.kv_cache_config = kv_cache_config - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_batched_tokens=self.max_num_tokens, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=self.model_config.get_vocab_size(), - kv_cache_config=kv_cache_config, - ) - self.initialize_attn_backend(kv_cache_config) kv_caches: dict[str, torch.Tensor] = {} - for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): + for kv_cache_group in kv_cache_config.kv_cache_groups: kv_cache_spec = kv_cache_group.kv_cache_spec for layer_name in kv_cache_group.layer_names: tensor_config = kv_cache_config.tensors[layer_name] @@ -1976,7 +1925,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: # the min of all `num_blocks`. Verify it here. assert num_blocks >= kv_cache_config.num_blocks if isinstance(kv_cache_spec, AttentionSpec): - kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( + kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype @@ -1996,6 +1945,11 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( + weakref.proxy(self), + kv_cache_config.kv_cache_groups[0].kv_cache_spec, + self.input_batch.block_table) + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 2da99696445e..b4daf5a34678 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -171,10 +171,19 @@ def __init__( self.kv_caches: list[torch.Tensor] = [] # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} - # self.input_batch: InputBatch # Persistent batch. # Request states. self.requests: dict[str, CachedRequestState] = {} + # Persistent batch. + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_blocks_per_req=self.max_num_blocks_per_req, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.vocab_size, + ) # Cached torch/numpy tensor # The pytorch tensor and numpy array share the same buffer. @@ -190,7 +199,7 @@ def __init__( self.block_table_cpu = torch.zeros( (self.max_num_reqs, self.max_num_blocks_per_req), - dtype=torch.int32, + dtype=self.input_batch.block_table.get_cpu_tensor().dtype, device="cpu") self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1, @@ -515,12 +524,12 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() + block_table_cpu = self.input_batch.block_table.get_cpu_tensor() block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() block_offsets = positions_np % self.block_size np.add(block_numbers * self.block_size, block_offsets, - out=self.input_batch.block_table[0]. + out=self.input_batch.block_table. slot_mapping_np[:total_num_scheduled_tokens]) # Prepare the attention metadata. @@ -545,15 +554,15 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): self.position_ids = self.positions_cpu[: padded_total_num_scheduled_tokens].to( self.device) - self.input_batch.block_table[0].slot_mapping_cpu[ + self.input_batch.block_table.slot_mapping_cpu[ total_num_scheduled_tokens:] = _PAD_SLOT_ID slot_mapping = ( - self.input_batch.block_table[0]. + self.input_batch.block_table. slot_mapping_cpu[:padded_total_num_scheduled_tokens].to( self.device)) block_tables = self.block_table_cpu[:self.max_num_reqs] block_tables[:num_reqs, :self.max_num_blocks_per_req] = ( - self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs]) + self.input_batch.block_table.get_cpu_tensor()[:num_reqs]) block_tables = block_tables.to(self.device) query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to( self.device) @@ -1254,18 +1263,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: "Hybrid models with more than one KV cache type are not " "supported yet.") - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_batched_tokens=self.max_num_tokens, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=self.model_config.get_vocab_size(), - kv_cache_config=kv_cache_config, - ) - assert self.block_table_cpu.dtype == self.input_batch.block_table[ - 0].get_cpu_tensor().dtype - kv_caches: dict[str, torch.Tensor] = {} for kv_cache_group in kv_cache_config.kv_cache_groups: