[Nixl] Heterogeneous TP support FlashInfer #20189


Open · wants to merge 5 commits into base: main
44 changes: 40 additions & 4 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -537,7 +537,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
block_rank = 2 # [block_size, latent_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, kv_latent_dim = block_shape
self.slot_size_bytes = kv_elem_size * kv_latent_dim
else:
# [2 (k and v), num_blocks, ...]
if self._use_flashinfer:
@@ -613,6 +612,16 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
logger.debug("Done registering descs")
self._registered_descs.append(descs)

if self._use_flashinfer:
# NOTE (NickLucche) When FlashInfer is used, memory is registered
# with joint KV for each block. This minimizes the overhead in
# registerMem, allowing faster descs queries. In order to split on
# the kv_heads dim as required by heterogeneous TP, one must be able
# to index K/V separately. Hence we double the number of 'virtual'
# regions here and double the descs below.
self.num_regions *= 2

block_len = self.get_backend_aware_block_len()
# Register local/src descr for NIXL xfer.
blocks_data = []
for base_addr in self.kv_caches_base_addr[self.engine_id]:
@@ -625,7 +634,16 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
block_offset = block_id * self.block_len
addr = base_addr + block_offset
# (addr, len, device id)
blocks_data.append((addr, self.block_len, self.tp_rank))
blocks_data.append((addr, block_len, self.tp_rank))

if self._use_flashinfer:
# To maintain the same descs ordering, K/V must be interleaved.
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
addr = base_addr + block_offset
# Register addresses for V cache.
v_addr = addr + block_len
blocks_data.append((v_addr, block_len, self.tp_rank))
Comment on lines +639 to +646
critical

The current implementation for creating blocks_data for FlashInfer seems to be incorrect. The preceding loop (outside this diff) adds all K blocks, and this new block adds all V blocks. This results in a [K0..Kn, V0..Vn] layout for each layer's descriptors.

However, the existing logic in _get_block_descs_ids for FlashInfer expects an interleaved layout of [K0, V0, K1, V1, ...]. This mismatch will likely cause incorrect data to be transferred.

To fix this, the descriptor creation for K and V should be interleaved within a single loop. You'll need to modify the loop at line 631 to handle both FlashInfer and other backends correctly.

For example (conceptual):

if self._use_flashinfer:
    # Interleave K and V block registrations: [K0, V0, K1, V1, ...]
    for block_id in range(self.num_blocks):
        block_offset = block_id * self.block_len
        addr = base_addr + block_offset
        blocks_data.append((addr, block_len, self.tp_rank))  # K half
        blocks_data.append((addr + block_len, block_len, self.tp_rank))  # V half
else:
    # Original logic for other backends
    for block_id in range(self.num_blocks):
        block_offset = block_id * self.block_len
        addr = base_addr + block_offset
        blocks_data.append((addr, block_len, self.tp_rank))

logger.debug("Created %s blocks for src engine %s and rank %s",
len(blocks_data), self.engine_id, self.tp_rank)
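
For intuition, here is a minimal standalone sketch (illustrative names, not the PR's code) of the interleaved per-block layout that the review comment above argues _get_block_descs_ids expects on the FlashInfer path: each joint [K | V] block contributes two half-length descriptors, V immediately after K.

def build_flashinfer_descs(base_addr, num_blocks, joint_block_len, tp_rank):
    # Each joint block holds K then V; indexing each half separately is what
    # lets heterogeneous TP split on the kv_heads dim.
    half = joint_block_len // 2
    descs = []
    for block_id in range(num_blocks):
        k_addr = base_addr + block_id * joint_block_len
        descs.append((k_addr, half, tp_rank))         # K half
        descs.append((k_addr + half, half, tp_rank))  # V half
    return descs

# Two 64-byte joint blocks:
# [(0, 32, 0), (32, 32, 0), (64, 32, 0), (96, 32, 0)]
print(build_flashinfer_descs(base_addr=0, num_blocks=2, joint_block_len=64, tp_rank=0))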

@@ -754,7 +772,8 @@ def add_remote_agent(self,
# Only register the remote's descriptors if current rank pulls from it.
self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr
rank_offset = self.tp_rank % tp_ratio * self.block_len \
block_len = self.get_backend_aware_block_len()
rank_offset = self.tp_rank % tp_ratio * block_len \
if not (self.use_mla or is_kv_replicated) else 0
# Register all remote blocks, but only the corresponding kv heads.
for base_addr in nixl_agent_meta.kv_caches_base_addr:
@@ -765,7 +784,16 @@ def add_remote_agent(self,
# self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset
# (addr, len, device id)
blocks_data.append((addr, self.block_len, remote_tp_rank))
blocks_data.append((addr, block_len, remote_tp_rank))

if self._use_flashinfer:
# When FlashInfer is used, index the Vs separately.
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_len
addr = base_addr + block_offset + rank_offset
v_addr = addr + nixl_agent_meta.block_len // 2
blocks_data.append((v_addr, block_len, remote_tp_rank))
Comment on lines +789 to +795

critical

Similar to the issue in register_kv_caches, the logic here for creating remote block descriptors for FlashInfer seems incorrect. The preceding loop adds K blocks, and this new block adds V blocks, resulting in a non-interleaved [K0..Kn, V0..Vn] layout.

This contradicts the interleaved layout [K0, V0, K1, V1, ...] expected by _get_block_descs_ids for FlashInfer, which will likely lead to incorrect data transfers.

The fix would be to interleave the K and V block descriptor creation within a single loop, similar to the suggestion for register_kv_caches.
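
A conceptual sketch of that interleaving for the remote path, assembled from the names in the surrounding diff (a suggestion, not the PR's code):

if self._use_flashinfer:
    # Interleave remote K and V registrations: [K0, V0, K1, V1, ...]
    for block_id in range(nixl_agent_meta.num_blocks):
        block_offset = block_id * nixl_agent_meta.block_len
        k_addr = base_addr + block_offset + rank_offset
        v_addr = k_addr + nixl_agent_meta.block_len // 2
        blocks_data.append((k_addr, block_len, remote_tp_rank))
        blocks_data.append((v_addr, block_len, remote_tp_rank))

Note how rank_offset composes with this: with tp_ratio=2, a local rank with tp_rank % tp_ratio == 0 reads each remote half at offset 0, while one with tp_rank % tp_ratio == 1 reads at offset block_len, so each rank pulls only its own kv_heads slice.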


logger.debug(
"Created %s blocks for dst engine %s with remote rank %s and "
"local rank %s", len(blocks_data), engine_id, remote_tp_rank,
@@ -1079,6 +1107,14 @@ def _get_block_descs_ids(self,
descs_ids.append(reg_id * num_blocks + block_id)
return descs_ids

def get_backend_aware_block_len(self):
if self._use_flashinfer:
# For indexing only half (either just the K or V part).
block_len = self.block_len // 2
else:
block_len = self.block_len
return block_len
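
A quick worked example with illustrative numbers (not taken from the PR) of what this helper returns: on the FlashInfer path self.block_len spans a joint [K | V] block, so each K-only or V-only descriptor covers half of it.

kv_elem_size = 2                       # fp16, illustrative
block_size, num_heads, head_dim = 16, 8, 128
joint_block_len = 2 * block_size * num_heads * head_dim * kv_elem_size
assert joint_block_len == 65536        # self.block_len on the FlashInfer path
assert joint_block_len // 2 == 32768   # what get_backend_aware_block_len() returns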


@contextlib.contextmanager
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
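
To close the loop on why append order matters: a hedged sketch of the desc-id arithmetic from the tail of _get_block_descs_ids shown above (the full method is not in this diff). Ids are region-major, so blocks_data entries must have been appended in exactly the order these ids enumerate them.

def block_descs_ids(region_ids, block_ids, num_blocks):
    # Mirrors the visible line: descs_ids.append(reg_id * num_blocks + block_id)
    descs_ids = []
    for reg_id in region_ids:
        for block_id in block_ids:
            descs_ids.append(reg_id * num_blocks + block_id)
    return descs_ids

# Two regions (e.g. the K and V halves of one layer), blocks 0 and 2 of 3:
# block_descs_ids([0, 1], [0, 2], num_blocks=3) -> [0, 2, 3, 5]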