[Nixl] Heterogeneous TP support FlashInfer #20189
@@ -537,7 +537,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             block_rank = 2  # [block_size, latent_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
             block_size, kv_latent_dim = block_shape
-            self.slot_size_bytes = kv_elem_size * kv_latent_dim
         else:
             # [2 (k and v), num_blocks, ...]
             if self._use_flashinfer:
@@ -613,6 +612,16 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         logger.debug("Done registering descs")
         self._registered_descs.append(descs)

+        if self._use_flashinfer:
+            # NOTE (NickLucche) When FlashInfer is used, memory is registered
+            # with joint KV for each block. This minimizes the overhead in
+            # registerMem, allowing faster descs queries. In order to be able
+            # to split on the kv_heads dim as required by heterogeneous TP,
+            # one must be able to index K/V separately. Hence we double the
+            # number of 'virtual' regions here and double the descs below.
+            self.num_regions *= 2
+
+        block_len = self.get_backend_aware_block_len()
         # Register local/src descr for NIXL xfer.
         blocks_data = []
         for base_addr in self.kv_caches_base_addr[self.engine_id]:
@@ -625,7 +634,16 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 block_offset = block_id * self.block_len
                 addr = base_addr + block_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, self.block_len, self.tp_rank))
+                blocks_data.append((addr, block_len, self.tp_rank))
+
+            if self._use_flashinfer:
+                # To maintain the same descs ordering, K/V must be interleaved.
+                for block_id in range(self.num_blocks):
+                    block_offset = block_id * self.block_len
+                    addr = base_addr + block_offset
+                    # Register addresses for V cache.
+                    v_addr = addr + block_len
+                    blocks_data.append((v_addr, block_len, self.tp_rank))

         logger.debug("Created %s blocks for src engine %s and rank %s",
                      len(blocks_data), self.engine_id, self.tp_rank)
@@ -754,7 +772,8 @@ def add_remote_agent(self,
         # Only register the remote's descriptors if current rank pulls from it.
         self.kv_caches_base_addr[
             engine_id] = nixl_agent_meta.kv_caches_base_addr
-        rank_offset = self.tp_rank % tp_ratio * self.block_len \
+        block_len = self.get_backend_aware_block_len()
+        rank_offset = self.tp_rank % tp_ratio * block_len \
             if not (self.use_mla or is_kv_replicated) else 0
         # Register all remote blocks, but only the corresponding kv heads.
         for base_addr in nixl_agent_meta.kv_caches_base_addr:
@@ -765,7 +784,16 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 # self.block_len == remote_block_len//tp_ratio bytes.
                 addr = base_addr + block_offset + rank_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, self.block_len, remote_tp_rank))
+                blocks_data.append((addr, block_len, remote_tp_rank))
+
+            if self._use_flashinfer:
+                # When FlashInfer is used, index the Vs separately.
+                for block_id in range(nixl_agent_meta.num_blocks):
+                    block_offset = block_id * nixl_agent_meta.block_len
+                    addr = base_addr + block_offset + rank_offset
+                    v_addr = addr + nixl_agent_meta.block_len // 2
+                    blocks_data.append((v_addr, block_len, remote_tp_rank))
Comment on lines +789 to +795

Similar to the issue in `register_kv_caches`, this appends all of a layer's V descriptors after all of its K descriptors. This contradicts the interleaved layout that `_get_block_descs_ids` expects for FlashInfer. The fix would be to interleave the K and V block descriptor creation within a single loop, similar to the suggestion for `register_kv_caches`; a conceptual sketch follows.
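A minimal sketch of how that interleaving could look in `add_remote_agent`, reusing the names from the hunk above (`base_addr`, `rank_offset`, `block_len`, `nixl_agent_meta`, `remote_tp_rank`, `blocks_data`); it illustrates the suggested ordering only and is not code from the PR:

```python
# Sketch only: emit each block's K descriptor immediately followed by its
# V descriptor, so the ordering is [K0, V0, K1, V1, ...] per layer.
for block_id in range(nixl_agent_meta.num_blocks):
    block_offset = block_id * nixl_agent_meta.block_len
    addr = base_addr + block_offset + rank_offset
    # K half (or the whole joint block when FlashInfer is not in use).
    blocks_data.append((addr, block_len, remote_tp_rank))
    if self._use_flashinfer:
        # V half sits in the second half of the remote joint KV block.
        v_addr = addr + nixl_agent_meta.block_len // 2
        blocks_data.append((v_addr, block_len, remote_tp_rank))
```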
         logger.debug(
             "Created %s blocks for dst engine %s with remote rank %s and "
             "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
@@ -1079,6 +1107,14 @@ def _get_block_descs_ids(self,
             descs_ids.append(reg_id * num_blocks + block_id)
         return descs_ids

+    def get_backend_aware_block_len(self):
+        if self._use_flashinfer:
+            # For indexing only half (either just the K or V part).
+            block_len = self.block_len // 2
+        else:
+            block_len = self.block_len
+        return block_len
+

 @contextlib.contextmanager
 def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
The current implementation for creating `blocks_data` for FlashInfer seems to be incorrect. The preceding loop (outside this diff) adds all K blocks, and this new block adds all V blocks. This results in a `[K0..Kn, V0..Vn]` layout for each layer's descriptors.

However, the existing logic in `_get_block_descs_ids` for FlashInfer expects an interleaved layout of `[K0, V0, K1, V1, ...]`. This mismatch will likely cause incorrect data to be transferred.

To fix this, the descriptor creation for K and V should be interleaved within a single loop. You'll need to modify the loop at line 631 to handle both FlashInfer and other backends correctly.

For example (conceptual):
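A sketch of what such an interleaved loop might look like for the local descriptors in `register_kv_caches`, reusing the names from the diff above (`base_addr`, `block_len`, `blocks_data`, `self.num_blocks`, `self.block_len`, `self.tp_rank`); it is illustrative only, not the PR's final code:

```python
# Sketch only: append each block's K descriptor and then its V descriptor,
# so the per-layer layout is [K0, V0, K1, V1, ...].
for block_id in range(self.num_blocks):
    block_offset = block_id * self.block_len
    addr = base_addr + block_offset
    # K half (or the whole joint block for non-FlashInfer backends).
    blocks_data.append((addr, block_len, self.tp_rank))
    if self._use_flashinfer:
        # V half starts right after the K half within the joint KV block.
        v_addr = addr + block_len
        blocks_data.append((v_addr, block_len, self.tp_rank))
```

Keeping K and V adjacent per block should preserve the descriptor ordering that `_get_block_descs_ids` already assumes, while still letting K and V be addressed separately for heterogeneous TP.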