2 changes: 1 addition & 1 deletion tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -217,7 +217,7 @@ def _nixl_handshake(self, host: str, port: int, remote_tp_size: int,
# These should've been done in register_kv_caches(), called by
# gpu_model_runner. Here we just hardcode some dummy values.
self.slot_size_bytes = 4096
self.block_len = self.slot_size_bytes * self.block_size
self.block_len = self.slot_size_bytes * self.block_size * 2
self.num_blocks = 1
self.dst_num_blocks[self.engine_id] = self.num_blocks

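The `* 2` above reflects the new joint-KV block layout: with the FlashAttention cache reordered to `[num_blocks, 2, block_size, kv_heads, head_dim]`, one block region now holds both K and V, so its byte length doubles. A minimal standalone sketch of that arithmetic (the sizes below are illustrative; the test simply hardcodes `slot_size_bytes = 4096`):

```python
# Illustrative sizes only; the test hardcodes slot_size_bytes = 4096 instead.
kv_elem_size = 2                     # bytes per bf16 element
kv_heads, head_dim, block_size = 8, 128, 16

slot_size_bytes = kv_elem_size * kv_heads * head_dim  # one token's K (or V)

# Old split-KV layout [2, num_blocks, block_size, ...]: a region covers K *or* V.
block_len_split = slot_size_bytes * block_size
# New joint-KV layout [num_blocks, 2, block_size, ...]: K and V share the block.
block_len_joint = slot_size_bytes * block_size * 2

assert block_len_joint == 2 * block_len_split
```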
4 changes: 3 additions & 1 deletion tests/v1/worker/test_gpu_model_runner.py
@@ -99,6 +99,8 @@ def get_vllm_config():
def model_runner():
vllm_config = get_vllm_config()
model_config = vllm_config.model_config
# Set the default dtype to bf16; otherwise FlexAttention is always chosen.
torch.set_default_dtype(torch.bfloat16)
num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
head_size = model_config.get_head_size()
vllm_config.compilation_config.static_forward_context[
@@ -405,7 +407,7 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
n_heads = model_runner.model_config.get_num_kv_heads(
model_runner.parallel_config)
expected_kv_cache_shape = [
2, NUM_BLOCKS, BLOCK_SIZE, n_heads,
NUM_BLOCKS, 2, BLOCK_SIZE, n_heads,
model_runner.model_config.get_head_size()
]
# TODO mla test
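For reference, a minimal sketch of the layout the updated test now expects: `num_blocks` leads and the K/V axis is dim 1, so consumers split K and V by unbinding dim 1. The sizes below are made up; the real test derives `NUM_BLOCKS`, `BLOCK_SIZE`, and the head dimensions from its own config, and the bf16 default dtype matches the change above.

```python
import torch

# Made-up sizes; the real test computes these from the vLLM config.
num_blocks, block_size, n_kv_heads, head_size = 4, 16, 2, 64

# New FlashAttention ordering: num_blocks first, K/V second.
kv_cache = torch.zeros(num_blocks, 2, block_size, n_kv_heads, head_size,
                       dtype=torch.bfloat16)

key_cache, value_cache = kv_cache.unbind(1)
assert key_cache.shape == (num_blocks, block_size, n_kv_heads, head_size)
```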
16 changes: 9 additions & 7 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -545,6 +545,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.backend_name = backend.get_name()
attn_backend = backend_name_to_enum(self.backend_name)
self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
self._use_flashattn = attn_backend == _Backend.FLASH_ATTN_VLLM_V1
self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
logger.debug("Detected attention backend %s", self.backend_name)

@@ -735,8 +736,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
self.slot_size_bytes = kv_elem_size * kv_latent_dim
else:
# [2 (k and v), num_blocks, ...]
if self._use_flashinfer:
# FlashInfer swaps 2<->num_blocks dimensions.
if self._use_flashinfer or self._use_flashattn:
# FlashInfer and FlashAttn swap the 2<->num_blocks dimensions.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 4 # [2, block_size, kv_heads, head_dim]
else:
@@ -772,13 +773,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
# are non-contiguous (it's not locally guaranteed that they will be)
# Disadvantage is that the encoded NixlAgentMetadata is now larger
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor
# to better exploit the memory layout (ie num_blocks is the first dim).
# Conversely for FlashInfer and FlashAttn, K and V are transferred
# in the same tensor to better exploit the memory layout (ie
# num_blocks is the first dim).
for cache_or_caches in xfer_buffers.values():
# Normalize to always be a list of caches
cache_list = [cache_or_caches] if use_mla \
or self._use_pallas_v1 or self._use_flashinfer \
else cache_or_caches
or self._use_flashattn else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len
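The two `register_kv_caches` hunks above derive the transfer geometry from the cache shape: with the FlashAttention (and FlashInfer) joint-KV layout, `num_blocks` comes from dim 0 and a block's byte length covers both K and V. A minimal standalone sketch of that bookkeeping with made-up sizes (the real code stores these on `self` and registers the regions with NIXL):

```python
import math
import torch

# A per-layer cache in the joint-KV layout [num_blocks, 2, block_size, kv_heads, head_dim].
first_kv_cache = torch.zeros(4, 2, 16, 8, 128, dtype=torch.bfloat16)
kv_elem_size = first_kv_cache.element_size()

num_blocks = first_kv_cache.shape[0]
block_rank = 4                                     # [2, block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, kv_heads, head_dim = block_shape[-3:]

slot_size_bytes = kv_elem_size * kv_heads * head_dim
block_len = kv_elem_size * math.prod(block_shape)  # bytes per block, K and V included
region_len = num_blocks * block_len                # one contiguous region per layer

assert block_len == slot_size_bytes * block_size * 2
assert region_len == first_kv_cache.numel() * kv_elem_size
```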
@@ -935,8 +937,8 @@ def add_remote_agent(self,
else:
remote_block_size = nixl_agent_meta.block_len // (
self.slot_size_bytes * tp_ratio)
if self._use_flashinfer:
# Account for joint KV in FlashInfer.
if self._use_flashinfer or self._use_flashattn:
# Account for joint KV in FlashInfer and FlashAttn.
remote_block_size //= 2

assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
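On the decode side, `add_remote_agent` recovers the remote block size from the advertised `block_len`; since a joint-KV block spans both K and V, the token count per block is halved. A minimal arithmetic sketch with made-up numbers (in the connector these values come from the handshake metadata):

```python
# Made-up values; the connector reads these from NixlAgentMetadata.
slot_size_bytes = 2048          # bytes for one token's K (or V) on this worker
tp_ratio = 1                    # homogeneous TP for simplicity
remote_block_len = 65536        # bytes per remote block, K and V included

remote_block_size = remote_block_len // (slot_size_bytes * tp_ratio)
remote_block_size //= 2         # joint KV: halve to get tokens per block

assert remote_block_size == 16
```

The hunks below adapt the shared-storage connector's inject/extract helpers to the same layout (presumably `shared_storage_connector.py`, given the `SharedStorageConnectorMetadata` assertion they sit next to).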
@@ -130,11 +130,14 @@ def inject_kv_into_layer(
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
else:
num_pages = dst_kv_cache_layer_shape[1]
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[2]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
2, num_pages * page_size, -1)
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
num_pages, 2, page_size, -1)
slot_mapping_page = slot_mapping // page_size
slot_mapping_tok = slot_mapping % page_size
dst_kv_cache_layer[slot_mapping_page, :, slot_mapping_tok,
...] = src_kv_cache
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)

# Get the metadata
@@ -216,9 +219,12 @@ def extract_kv_from_layer(
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
...]
num_pages, page_size = layer.shape[0], layer.shape[2]
slot_mapping_page = slot_mapping // page_size
slot_mapping_tok = slot_mapping % page_size
return layer.reshape(num_pages, 2, page_size,
-1)[slot_mapping_page, :, slot_mapping_tok,
...]

connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
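Because the non-MLA layout is now `[num_pages, 2, page_size, ...]`, a flat slot id has to be split into a page index and an in-page offset before indexing, which is what the `//` and `%` above do. A minimal sketch of the round trip with made-up sizes (tensor names here are illustrative, not the connector's):

```python
import torch

num_pages, page_size, kv_heads, head_dim = 4, 16, 2, 8
layer = torch.randn(num_pages, 2, page_size, kv_heads, head_dim)

# Flat slot ids (page * page_size + offset) for three cached tokens.
slot_mapping = torch.tensor([3, 17, 40])
slot_mapping_page = slot_mapping // page_size   # tensor([0, 1, 2])
slot_mapping_tok = slot_mapping % page_size     # tensor([3, 1, 8])

# Extract: one [2, hidden] row (K and V) per token.
flat = layer.reshape(num_pages, 2, page_size, -1)
kv_rows = flat[slot_mapping_page, :, slot_mapping_tok, ...]
assert kv_rows.shape == (3, 2, kv_heads * head_dim)

# Inject the rows back into a destination cache the same way.
dst = torch.zeros_like(flat)
dst[slot_mapping_page, :, slot_mapping_tok, ...] = kv_rows
```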
6 changes: 3 additions & 3 deletions vllm/v1/attention/backends/flash_attn.py
@@ -85,7 +85,7 @@ def get_kv_cache_shape(
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
return (num_blocks, 2, block_size, num_kv_heads, head_size)

@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
@@ -436,7 +436,7 @@ def forward(
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
kv_cache = [num_blocks, 2, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
@@ -479,7 +479,7 @@ def forward(
attn_metadata, layer)

# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(0)
key_cache, value_cache = kv_cache.unbind(1)

if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
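The motivation for moving `num_blocks` to the leading dimension is that each block's K and V become one contiguous byte range, which is what lets the NIXL and shared-storage connectors treat a block as a single region; the attention code itself only has to switch `unbind(0)` to `unbind(1)`. A minimal contiguity check, assuming made-up sizes:

```python
import torch

num_blocks, block_size, n_kv_heads, head_size = 4, 16, 2, 8
kv_cache = torch.zeros(num_blocks, 2, block_size, n_kv_heads, head_size)

# With num_blocks leading, consecutive blocks are laid out back to back:
# the stride of dim 0 equals the element count of one block (K and V together).
assert kv_cache.stride(0) == kv_cache[0].numel()

# Consumers still get separate K and V views, now by unbinding dim 1.
key_cache, value_cache = kv_cache.unbind(1)
assert key_cache.shape == (num_blocks, block_size, n_kv_heads, head_size)
```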