[Nixl] Heterogeneous TP support FlashInfer #20189

Open · wants to merge 5 commits into base: main

Conversation

@NickLucche (Contributor) commented Jun 27, 2025

This PR enables the use of FlashInfer in a heterogeneous TP setting when using NixlConnector. This is particularly important for Blackwell systems, since they will default to FlashInfer.

The main difference from FA is that the cache layout goes from [2, num_blocks, HND] to [num_blocks, 2, HND], where the 2 is the K/V dimension.
With homogeneous TP, this layout change has no downside: quite the contrary, we can actually read both K and V in a single message (of doubled size).
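
To make the layout difference concrete, here is a minimal sketch of the per-block address math (illustrative names, not the connector's actual variables; half_block_len is the byte size of one K or V half of a block):

# FA layout [2, num_blocks, HND]: K and V live in two separate
# contiguous regions, each num_blocks * half_block_len bytes long.
k_addr_fa = base_addr + block_id * half_block_len
v_addr_fa = base_addr + num_blocks * half_block_len + block_id * half_block_len

# FlashInfer layout [num_blocks, 2, HND]: each block stores its K half
# immediately followed by its V half, so with homogeneous TP a single
# read of 2 * half_block_len bytes fetches both.
k_addr_fi = base_addr + block_id * (2 * half_block_len)
v_addr_fi = k_addr_fi + half_block_len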

In heterogeneous TP, each worker needs to read a portion of the heads (given by tp_ratio, e.g. half the heads). With FA we can do that efficiently by leveraging the HND layout: we can just say, e.g., "read cache[:2, :, :]" for worker 1 and "read cache[2:, :, :]" for worker 2, indexing on H.
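
As a toy illustration of that head slicing (hypothetical shapes, not the connector's code):

import torch

H, N, D = 4, 16, 128            # heads, tokens per block, head dim
block = torch.randn(H, N, D)    # one HND-layout block on the prefill worker

tp_ratio = 2                    # decode TP / prefill TP
heads_per_rank = H // tp_ratio
# Because H is the outermost dim, each slice below is one contiguous
# memory chunk, i.e. a single transfer descriptor per block.
worker0_heads = block[:heads_per_rank]
worker1_heads = block[heads_per_rank:]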

With FlashInfer, K and V are not interleaved in memory the way the descriptors need, as the "2" dim now sits right before HND.
Naively reading, e.g., "half" the kv cache would read all the Ks, rather than half the heads for both K and V.

To address that, this PR adds a virtual split: when FlashInfer is detected, K and V are alternated when creating descriptors, just as with FA. This lets us reuse the same logic as before for getting block_ids, while retaining the memory-registration advantage: we effectively register just num_layers regions in NIXL, down from 2*num_layers with FA.

TL;DR: K/V must be alternated when reading to maintain consistency with FA. The number of regions actually registered is half that of FA, but the number of descs is the same, so at the logical level you won't notice a difference.
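
Conceptually, the per-layer FlashInfer descriptor list ends up alternating K and V halves (a sketch under the same illustrative names as above):

descs = []
for block_id in range(num_blocks):
    k_addr = base_addr + block_id * block_len   # block_len covers K + V
    descs.append(k_addr)                        # K half of the block
    descs.append(k_addr + block_len // 2)       # V half of the block
# One registered region per layer (num_layers total), with
# 2 * num_blocks descs per layer: the same desc count as FA,
# but half the regions.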

Test with

# FlashInfer run
VLLM_ATTENTION_BACKEND=FLASHINFER NUM_DECODE_INSTANCES=1 DECODER_TP_SIZE=2 bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
./script.sh  1429.72s user 106.44s system 580% cpu 4:24.47 total

# FlashAttention run
NUM_DECODE_INSTANCES=1 DECODER_TP_SIZE=2 bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
./script.sh  1450.54s user 107.23s system 584% cpu 4:26.68 total

Signed-off-by: NickLucche <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist bot left a comment

Summary of Changes

Hello @NickLucche, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces support for FlashInfer in a heterogeneous Tensor Parallel (TP) environment when utilizing the NixlConnector. The primary goal is to reconcile the differing KV cache memory layouts between FlashInfer and FlashAttention, enabling efficient and consistent block-level data transfer for distributed inference by introducing a virtual splitting mechanism for FlashInfer's KV cache.

Highlights

  • FlashInfer Heterogeneous TP Support: Enabled FlashInfer to work efficiently in heterogeneous Tensor Parallel (TP) settings by addressing its distinct KV cache memory layout compared to FlashAttention. This ensures consistent block-level data transfer for distributed inference.
  • Virtual KV Cache Splitting: Implemented a 'virtual split' mechanism for FlashInfer's KV cache. This involves doubling the logical number of regions and interleaving K and V block addresses during descriptor registration, allowing for separate indexing of K and V components while retaining FlashInfer's memory registration advantages.
  • Dynamic Block Length Calculation: Introduced a new helper method, get_backend_aware_block_len, to dynamically determine the effective block length based on the attention backend. For FlashInfer, it returns half the block_len to facilitate separate K and V indexing (a sketch follows this list).
  • Updated Descriptor Registration Logic: Modified the register_kv_caches and add_remote_agent methods to correctly register and retrieve block descriptors, accounting for the new virtual K/V splitting and adjusted block lengths specific to FlashInfer.
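
A minimal sketch of what such a helper might look like, assuming a _use_flashinfer flag and a block_len attribute as seen in the snippets below (the actual implementation may differ):

def get_backend_aware_block_len(self) -> int:
    if self._use_flashinfer:
        # A FlashInfer block holds K then V; descriptors address
        # each half separately.
        return self.block_len // 2
    # Other backends keep K and V in separate regions; a descriptor
    # covers the full block.
    return self.block_len
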
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. commenting on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature | Command | Description
--- | --- | ---
Code Review | /gemini review | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state.
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help | /gemini help | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot left a comment

Code Review

This PR aims to add support for FlashInfer in a heterogeneous Tensor Parallelism setting within the NixlConnector. The core idea is to create a "virtual" split for K/V caches to align with FlashAttention's behavior, which is a clever approach.

The changes look mostly correct and well-reasoned. I've identified a couple of critical issues related to descriptor creation logic that could lead to incorrect data transfers, and a minor issue with a data type change.

  • Descriptor Interleaving: The logic for creating blocks_data for FlashInfer in both register_kv_caches and add_remote_agent seems to produce a non-interleaved list of K and V blocks. This contradicts the logic in _get_block_descs_ids, which expects an interleaved layout. This needs to be fixed to ensure correct KV cache transfers.
  • Data Type: There's a change from integer division to float division for remote_block_size, which could cause issues. This variable also appears to be unused.

Once these points are addressed, the PR should be in good shape. Great work on tackling this complex integration!

Comment on lines +639 to +646
if self._use_flashinfer:
    # To maintain the same descs ordering, K/V must be interleaved.
    for block_id in range(self.num_blocks):
        block_offset = block_id * self.block_len
        addr = base_addr + block_offset
        # Register addresses for V cache.
        v_addr = addr + block_len
        blocks_data.append((v_addr, block_len, self.tp_rank))

critical

The current implementation for creating blocks_data for FlashInfer seems to be incorrect. The preceding loop (outside this diff) adds all K blocks, and this new block adds all V blocks. This results in a [K0..Kn, V0..Vn] layout for each layer's descriptors.

However, the existing logic in _get_block_descs_ids for FlashInfer expects an interleaved layout of [K0, V0, K1, V1, ...]. This mismatch will likely cause incorrect data to be transferred.

To fix this, the descriptor creation for K and V should be interleaved within a single loop. You'll need to modify the loop at line 631 to handle both FlashInfer and other backends correctly.

For example (conceptual):

if self._use_flashinfer:
    # Interleave K and V block registrations: [K0, V0, K1, V1, ...]
    for block_id in range(self.num_blocks):
        block_offset = block_id * self.block_len
        addr = base_addr + block_offset
        blocks_data.append((addr, block_len, self.tp_rank))    # K half
        v_addr = addr + block_len
        blocks_data.append((v_addr, block_len, self.tp_rank))  # V half
else:
    # Original logic for other backends
    for block_id in range(self.num_blocks):
        block_offset = block_id * self.block_len
        addr = base_addr + block_offset
        blocks_data.append((addr, block_len, self.tp_rank))

Comment on lines +789 to +795
if self._use_flashinfer:
    # When FlashInfer is used, index the Vs separately.
    for block_id in range(nixl_agent_meta.num_blocks):
        block_offset = block_id * nixl_agent_meta.block_len
        addr = base_addr + block_offset + rank_offset
        v_addr = addr + nixl_agent_meta.block_len // 2
        blocks_data.append((v_addr, block_len, remote_tp_rank))

critical

Similar to the issue in register_kv_caches, the logic here for creating remote block descriptors for FlashInfer seems incorrect. The preceding loop adds K blocks, and this new block adds V blocks, resulting in a non-interleaved [K0..Kn, V0..Vn] layout.

This contradicts the interleaved layout [K0, V0, K1, V1, ...] expected by _get_block_descs_ids for FlashInfer, which will likely lead to incorrect data transfers.

The fix would be to interleave the K and V block descriptor creation within a single loop, similar to the suggestion for register_kv_caches.
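
One way the interleaving could look, reusing the variable names from the snippet above (a sketch, not the actual patch):

if self._use_flashinfer:
    # Interleave remote K and V descriptors: [K0, V0, K1, V1, ...]
    for block_id in range(nixl_agent_meta.num_blocks):
        block_offset = block_id * nixl_agent_meta.block_len
        addr = base_addr + block_offset + rank_offset
        blocks_data.append((addr, block_len, remote_tp_rank))
        v_addr = addr + nixl_agent_meta.block_len // 2
        blocks_data.append((v_addr, block_len, remote_tp_rank))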

@NickLucche (Contributor, Author) commented:

@wseaton already tested this on B200, thanks!
