diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 7a077dce770..1102935e5ea 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -537,7 +537,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             block_rank = 2  # [block_size, latent_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
             block_size, kv_latent_dim = block_shape
-            self.slot_size_bytes = kv_elem_size * kv_latent_dim
         else:
             # [2 (k and v), num_blocks, ...]
             if self._use_flashinfer:
@@ -613,6 +612,16 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         logger.debug("Done registering descs")
         self._registered_descs.append(descs)
 
+        if self._use_flashinfer:
+            # NOTE (NickLucche) When FlashInfer is used, memory is registered
+            # with joint KV for each block. This minimizes the overhead in
+            # registerMem, allowing faster descs queries. To be able to split
+            # on the kv_heads dim as required by heterogeneous TP, one must be
+            # able to index K/V separately. Hence we double the number of
+            # 'virtual' regions here and double the descs below.
+            self.num_regions *= 2
+
+        block_len = self.get_backend_aware_block_len()
         # Register local/src descr for NIXL xfer.
         blocks_data = []
         for base_addr in self.kv_caches_base_addr[self.engine_id]:
@@ -625,7 +634,16 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 block_offset = block_id * self.block_len
                 addr = base_addr + block_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, self.block_len, self.tp_rank))
+                blocks_data.append((addr, block_len, self.tp_rank))
+
+            if self._use_flashinfer:
+                # To maintain the same descs ordering, K/V must be interleaved.
+                for block_id in range(self.num_blocks):
+                    block_offset = block_id * self.block_len
+                    addr = base_addr + block_offset
+                    # Register addresses for V cache.
+                    v_addr = addr + block_len
+                    blocks_data.append((v_addr, block_len, self.tp_rank))
         logger.debug("Created %s blocks for src engine %s and rank %s",
                      len(blocks_data), self.engine_id, self.tp_rank)
 
@@ -754,7 +772,8 @@ def add_remote_agent(self,
         # Only register the remote's descriptors if current rank pulls from it.
         self.kv_caches_base_addr[
             engine_id] = nixl_agent_meta.kv_caches_base_addr
-        rank_offset = self.tp_rank % tp_ratio * self.block_len \
+        block_len = self.get_backend_aware_block_len()
+        rank_offset = self.tp_rank % tp_ratio * block_len \
             if not (self.use_mla or is_kv_replicated) else 0
         # Register all remote blocks, but only the corresponding kv heads.
         for base_addr in nixl_agent_meta.kv_caches_base_addr:
@@ -765,7 +784,16 @@ def add_remote_agent(self,
                 # self.block_len == remote_block_len//tp_ratio bytes.
                 addr = base_addr + block_offset + rank_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, self.block_len, remote_tp_rank))
+                blocks_data.append((addr, block_len, remote_tp_rank))
+
+            if self._use_flashinfer:
+                # When FlashInfer is used, index the Vs separately.
+                for block_id in range(nixl_agent_meta.num_blocks):
+                    block_offset = block_id * nixl_agent_meta.block_len
+                    addr = base_addr + block_offset + rank_offset
+                    v_addr = addr + nixl_agent_meta.block_len // 2
+                    blocks_data.append((v_addr, block_len, remote_tp_rank))
+
         logger.debug(
             "Created %s blocks for dst engine %s with remote rank %s and "
             "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
@@ -1079,6 +1107,14 @@ def _get_block_descs_ids(self,
             descs_ids.append(reg_id * num_blocks + block_id)
         return descs_ids
 
+    def get_backend_aware_block_len(self):
+        if self._use_flashinfer:
+            # For indexing only half (either just the K or V part).
+            block_len = self.block_len // 2
+        else:
+            block_len = self.block_len
+        return block_len
+
 
 @contextlib.contextmanager
 def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
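The net effect of the FlashInfer path is easiest to see in isolation. Below is a minimal standalone sketch, not part of the patch: `build_block_descs` and all constants are hypothetical names for illustration, and the heterogeneous-TP `rank_offset` applied in `add_remote_agent` is omitted for brevity. It reproduces the descriptor layout the diff builds: each joint [K|V] block of `self.block_len` bytes yields two half-length descriptors, with all K halves of a region emitted before all V halves, matching the doubled `num_regions` and keeping the descs ordering stable.

```python
def build_block_descs(
    base_addrs: list[int],
    num_blocks: int,
    joint_block_len: int,
    use_flashinfer: bool,
    device_id: int,
) -> list[tuple[int, int, int]]:
    """Return (addr, len, device_id) descriptors for every cache block.

    With FlashInfer, each physical region holds joint [K|V] blocks, so we
    emit two descriptors per block: the K half at the block start and the
    V half at block start + joint_block_len // 2. Each base address thus
    contributes two 'virtual' regions, mirroring `self.num_regions *= 2`.
    """
    # Half-length when K and V must be indexed separately; mirrors
    # get_backend_aware_block_len() in the patch.
    desc_len = joint_block_len // 2 if use_flashinfer else joint_block_len

    descs: list[tuple[int, int, int]] = []
    for base in base_addrs:
        # K descriptors (or whole joint blocks when not using FlashInfer).
        for block_id in range(num_blocks):
            addr = base + block_id * joint_block_len
            descs.append((addr, desc_len, device_id))
        if use_flashinfer:
            # V descriptors follow the K ones for the same region, so the
            # ordering stays consistent across the doubled region count.
            for block_id in range(num_blocks):
                addr = base + block_id * joint_block_len
                descs.append((addr + desc_len, desc_len, device_id))
    return descs


# Toy check: 2 regions x 3 blocks of 1024B joint KV -> 12 half-block descs.
descs = build_block_descs([0x1000, 0x9000], 3, 1024, True, device_id=0)
assert len(descs) == 12
assert descs[0] == (0x1000, 512, 0)        # K half of block 0
assert descs[3] == (0x1000 + 512, 512, 0)  # V half of block 0
```

Registering memory once per joint block but describing K and V separately is what lets the remote path slice out only the needed kv heads without paying the registerMem cost twice.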