diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index c5ca7df83685..0d39b96fd467 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -217,7 +217,7 @@ def _nixl_handshake(self, host: str, port: int, remote_tp_size: int,
         # These should've been done in register_kv_caches(), called by
         # gpu_model_runner. Here we just hardcode some dummy values.
         self.slot_size_bytes = 4096
-        self.block_len = self.slot_size_bytes * self.block_size
+        self.block_len = self.slot_size_bytes * self.block_size * 2
         self.num_blocks = 1
         self.dst_num_blocks[self.engine_id] = self.num_blocks
 
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index e151d388c293..2e4ab11c5922 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -99,6 +99,8 @@ def get_vllm_config():
 def model_runner():
     vllm_config = get_vllm_config()
     model_config = vllm_config.model_config
+    # set to bf16, otherwise FlexAttention is always chosen
+    torch.set_default_dtype(torch.bfloat16)
     num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
     head_size = model_config.get_head_size()
     vllm_config.compilation_config.static_forward_context[
@@ -405,7 +407,7 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
     n_heads = model_runner.model_config.get_num_kv_heads(
         model_runner.parallel_config)
     expected_kv_cache_shape = [
-        2, NUM_BLOCKS, BLOCK_SIZE, n_heads,
+        NUM_BLOCKS, 2, BLOCK_SIZE, n_heads,
         model_runner.model_config.get_head_size()
     ]
     # TODO mla test
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index e7fc2b118145..b02caef31054 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -545,6 +545,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.backend_name = backend.get_name()
         attn_backend = backend_name_to_enum(self.backend_name)
         self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
+        self._use_flashattn = attn_backend == _Backend.FLASH_ATTN_VLLM_V1
         self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
         logger.debug("Detected attention backend %s", self.backend_name)
 
@@ -735,8 +736,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             self.slot_size_bytes = kv_elem_size * kv_latent_dim
         else:
             # [2 (k and v), num_blocks, ...]
-            if self._use_flashinfer:
-                # FlashInfer swaps 2<->num_blocks dimensions.
+            if self._use_flashinfer or self._use_flashattn:
+                # FlashInfer and FlashAttn swap 2<->num_blocks dimensions.
                 self.num_blocks = first_kv_cache.shape[0]
                 block_rank = 4  # [2, block_size, kv_heads, head_dim]
             else:
@@ -772,13 +773,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         # are non-contiguous (it's not locally guaranteed that they will be)
         # Disadvantage is that the encoded NixlAgentMetadata is now larger
         # (roughly 8KB vs 5KB).
-        # Conversely for FlashInfer, K and V are transferred in the same tensor
-        # to better exploit the memory layout (ie num_blocks is the first dim).
+        # Conversely for FlashInfer and FlashAttn, K and V are transferred
+        # in the same tensor to better exploit the memory layout (ie
+        # num_blocks is the first dim).
         for cache_or_caches in xfer_buffers.values():
             # Normalize to always be a list of caches
             cache_list = [cache_or_caches] if use_mla \
                 or self._use_pallas_v1 or self._use_flashinfer \
-                else cache_or_caches
+                or self._use_flashattn else cache_or_caches
             for cache in cache_list:
                 base_addr = cache.data_ptr()
                 region_len = self.num_blocks * self.block_len
@@ -935,8 +937,8 @@ def add_remote_agent(self,
         else:
             remote_block_size = nixl_agent_meta.block_len // (
                 self.slot_size_bytes * tp_ratio)
-            if self._use_flashinfer:
-                # Account for joint KV in FlashInfer.
+            if self._use_flashinfer or self._use_flashattn:
+                # Account for joint KV in FlashInfer and FlashAttn.
                 remote_block_size //= 2
 
         assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
index fd79387269d5..a1e42b0b07af 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
@@ -130,11 +130,14 @@ def inject_kv_into_layer(
                 dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
                 dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
             else:
-                num_pages = dst_kv_cache_layer_shape[1]
+                num_pages = dst_kv_cache_layer_shape[0]
                 page_size = dst_kv_cache_layer_shape[2]
                 dst_kv_cache_layer = dst_kv_cache_layer.reshape(
-                    2, num_pages * page_size, -1)
-                dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
+                    num_pages, 2, page_size, -1)
+                slot_mapping_page = slot_mapping // page_size
+                slot_mapping_tok = slot_mapping % page_size
+                dst_kv_cache_layer[slot_mapping_page, :, slot_mapping_tok,
+                                   ...] = src_kv_cache
                 dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
 
         # Get the metadata
@@ -216,9 +219,12 @@ def extract_kv_from_layer(
                 num_pages, page_size = layer.shape[0], layer.shape[1]
                 return layer.reshape(num_pages * page_size, -1)[slot_mapping,
                                                                 ...]
-            num_pages, page_size = layer.shape[1], layer.shape[2]
-            return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
-                                                               ...]
+            num_pages, page_size = layer.shape[0], layer.shape[2]
+            slot_mapping_page = slot_mapping // page_size
+            slot_mapping_tok = slot_mapping % page_size
+            return layer.reshape(num_pages, 2, page_size,
+                                 -1)[slot_mapping_page, :, slot_mapping_tok,
+                                     ...]
 
         connector_metadata = self._get_connector_metadata()
         assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 95ba56b35937..0a987f4f9f80 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -85,7 +85,7 @@ def get_kv_cache_shape(
     ) -> tuple[int, ...]:
         if block_size % 16 != 0:
             raise ValueError("Block size must be a multiple of 16.")
-        return (2, num_blocks, block_size, num_kv_heads, head_size)
+        return (num_blocks, 2, block_size, num_kv_heads, head_size)
 
     @staticmethod
     def get_kv_cache_stride_order() -> tuple[int, ...]:
@@ -436,7 +436,7 @@ def forward(
             query: shape = [num_tokens, num_heads, head_size]
             key: shape = [num_tokens, num_kv_heads, head_size]
             value: shape = [num_tokens, num_kv_heads, head_size]
-            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
+            kv_cache = [num_blocks, 2, block_size, num_kv_heads, head_size]
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
@@ -479,7 +479,7 @@ def forward(
                                                    attn_metadata, layer)
 
         # For decoder and cross-attention, use KV cache as before
-        key_cache, value_cache = kv_cache.unbind(0)
+        key_cache, value_cache = kv_cache.unbind(1)
         if self.kv_sharing_target_layer_name is None:
             # Reshape the input keys and values and store them in the cache.
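
A minimal standalone sketch (not part of the patch; shapes and slot values are made up) of why the indexing in shared_storage_connector.py changes: with the FlashAttention cache now laid out as (num_blocks, 2, block_size, head_dim), a flat slot index must be split into a block index and an in-block offset, while K and V are selected together along dim 1.

import torch

num_blocks, block_size, head_dim = 4, 16, 8

# Old layout: (2, num_blocks, block_size, head_dim); the new layout swaps the
# first two dims so num_blocks comes first and K/V are joint per block.
old = torch.arange(2 * num_blocks * block_size * head_dim, dtype=torch.float32)
old = old.reshape(2, num_blocks, block_size, head_dim)
new = old.permute(1, 0, 2, 3).contiguous()

slot_mapping = torch.tensor([3, 17, 40])   # flat token slots (made up)
page = slot_mapping // block_size          # which block each slot lives in
tok = slot_mapping % block_size            # offset inside that block

# Old-layout gather: flatten (num_blocks, block_size) and index dim 1.
old_kv = old.reshape(2, num_blocks * block_size, head_dim)[:, slot_mapping]

# New-layout gather: index block and in-block offset separately; the K/V pair
# is picked up together along dim 1, as in inject_kv_into_layer /
# extract_kv_from_layer above.
new_kv = new[page, :, tok]                 # -> (num_slots, 2, head_dim)

assert torch.equal(old_kv.permute(1, 0, 2), new_kv)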
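
Likewise, a back-of-the-envelope check of the NIXL bookkeeping (assumed element size and head geometry, not taken from the patch): with joint K+V per block, block_len covers both halves, which is why the unit test doubles it and add_remote_agent() divides remote_block_size back by 2.

kv_elem_size = 2                  # bytes per element, e.g. bf16 (assumed)
kv_heads, head_dim = 8, 128       # assumed head geometry
block_size = 16

slot_size_bytes = kv_elem_size * kv_heads * head_dim  # one token's K (or V)
block_len = slot_size_bytes * block_size * 2          # joint K+V per block

tp_ratio = 1
remote_block_size = block_len // (slot_size_bytes * tp_ratio)
remote_block_size //= 2           # account for joint KV, mirroring the patch
assert remote_block_size == block_size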