-
-
Notifications
You must be signed in to change notification settings - Fork 8.4k
[Nixl] Heterogeneous TP support FlashInfer #20189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[Nixl] Heterogeneous TP support FlashInfer #20189
Conversation
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: NickLucche <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Summary of Changes
Hello @NickLucche, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces support for FlashInfer in a heterogeneous Tensor Parallel (TP) environment when utilizing the NixlConnector. The primary goal is to reconcile the differing KV cache memory layouts between FlashInfer and FlashAttention, enabling efficient and consistent block-level data transfer for distributed inference by introducing a virtual splitting mechanism for FlashInfer's KV cache.
Highlights
- FlashInfer Heterogeneous TP Support: Enabled FlashInfer to work efficiently in heterogeneous Tensor Parallel (TP) settings by addressing its distinct KV cache memory layout compared to FlashAttention. This ensures consistent block-level data transfer for distributed inference.
- Virtual KV Cache Splitting: Implemented a 'virtual split' mechanism for FlashInfer's KV cache. This involves doubling the logical number of regions and interleaving K and V block addresses during descriptor registration, allowing for separate indexing of K and V components while retaining FlashInfer's memory registration advantages.
- Dynamic Block Length Calculation: Introduced a new helper method,
get_backend_aware_block_len
, to dynamically determine the effective block length based on the attention backend. For FlashInfer, this function returns half theblock_len
to facilitate separate K and V indexing. - Updated Descriptor Registration Logic: Modified the
register_kv_caches
andadd_remote_agent
methods to correctly register and retrieve block descriptors, accounting for the new virtual K/V splitting and adjusted block lengths specific to FlashInfer.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command>
or @gemini-code-assist <command>
. Below is a summary of the supported commands.
Feature | Command | Description |
---|---|---|
Code Review | /gemini review |
Performs a code review for the current pull request in its current state. |
Pull Request Summary | /gemini summary |
Provides a summary of the current pull request in its current state. |
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
Help | /gemini help |
Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/
folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
-
Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
Signed-off-by: NickLucche <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This PR aims to add support for FlashInfer in a heterogeneous Tensor Parallelism setting within the NixlConnector
. The core idea is to create a "virtual" split for K/V caches to align with FlashAttention's behavior, which is a clever approach.
The changes look mostly correct and well-reasoned. I've identified a couple of critical issues related to descriptor creation logic that could lead to incorrect data transfers, and a minor issue with a data type change.
- Descriptor Interleaving: The logic for creating
blocks_data
for FlashInfer in bothregister_kv_caches
andadd_remote_agent
seems to produce a non-interleaved list of K and V blocks. This contradicts the logic in_get_block_descs_ids
, which expects an interleaved layout. This needs to be fixed to ensure correct KV cache transfers. - Data Type: There's a change from integer division to float division for
remote_block_size
, which could cause issues. This variable also appears to be unused.
Once these points are addressed, the PR should be in good shape. Great work on tackling this complex integration!
if self._use_flashinfer: | ||
# To maintain the same descs ordering, K/V must be interleaved. | ||
for block_id in range(self.num_blocks): | ||
block_offset = block_id * self.block_len | ||
addr = base_addr + block_offset | ||
# Register addresses for V cache. | ||
v_addr = addr + block_len | ||
blocks_data.append((v_addr, block_len, self.tp_rank)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation for creating blocks_data
for FlashInfer seems to be incorrect. The preceding loop (outside this diff) adds all K blocks, and this new block adds all V blocks. This results in a [K0..Kn, V0..Vn]
layout for each layer's descriptors.
However, the existing logic in _get_block_descs_ids
for FlashInfer expects an interleaved layout of [K0, V0, K1, V1, ...]
. This mismatch will likely cause incorrect data to be transferred.
To fix this, the descriptor creation for K and V should be interleaved within a single loop. You'll need to modify the loop at line 631 to handle both FlashInfer and other backends correctly.
For example (conceptual):
if self._use_flashinfer:
# Interleave K and V block registrations
for block_id in range(self.num_blocks):
# ... calculate addr for K
blocks_data.append(...) # K
# ... calculate addr for V
blocks_data.append(...) # V
else:
# Original logic for other backends
for block_id in range(self.num_blocks):
# ...
blocks_data.append(...)
if self._use_flashinfer: | ||
# When FlashInfer is used, index the Vs separately. | ||
for block_id in range(nixl_agent_meta.num_blocks): | ||
block_offset = block_id * nixl_agent_meta.block_len | ||
addr = base_addr + block_offset + rank_offset | ||
v_addr = addr + nixl_agent_meta.block_len // 2 | ||
blocks_data.append((v_addr, block_len, remote_tp_rank)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the issue in register_kv_caches
, the logic here for creating remote block descriptors for FlashInfer seems incorrect. The preceding loop adds K blocks, and this new block adds V blocks, resulting in a non-interleaved [K0..Kn, V0..Vn]
layout.
This contradicts the interleaved layout [K0, V0, K1, V1, ...]
expected by _get_block_descs_ids
for FlashInfer, which will likely lead to incorrect data transfers.
The fix would be to interleave the K and V block descriptor creation within a single loop, similar to the suggestion for register_kv_caches
.
@wseaton already tested this on B200, thanks! |
This PR enables the use of FlashInfer in a heterogeneous TP setting when using NixlConnector, particularly important for Blackwell systems since they will default to FlashInfer.
The main difference from FA is that the cache layout goes from
[2, num_blocks, HND]
to[num_blocks, 2, HND]
where 2 is K/V.With homogeneous TP, this layout change has no particular implication: quite the contrary, we can actually read both K and V in a single message (of size doubled).
In heterogeneous TP, we need to read a portion of heads (
tp_ratio
, eg half the heads) and we can do that efficiently with FA leveraging the HND layout, as we can just say eg "read cache[:2, :, :]" for worker 1 and "read cache[2:, :, :]" for worker 2, indexing on H.With FlashInfer, we have K and V which are not interleaved in memory, as the dim "2" is now right before HND.
Attempting to read eg "half" the kv cache will result in reading all Ks, rather than half the heads for both K and V.
To address that, this PR will add a virtual split so that when flashinfer is detected K/V will be alternated just like FA when creating descriptors. This allows us to use the same logic as before for getting block_ids, while at the same time retaining the memory_registration advantage, effectively registering just num_layers region in NIXL, down from 2*num_layers for FA.
TL;DR: K/V must be alternated when reading to maintain consistency with FA. The number of regions actually registered is half of that of FA, but the number of descs is the same, so at the logical level you won't notice a difference.
Test with