Conversation

@ilmarkov (Contributor) commented Aug 1, 2025

Pull Request Description

This pull request introduces several key enhancements that improve the flexibility and performance of vLLM, particularly for allreduce operations and dynamic graph compilation.

Summary of Changes

  • Allreduce Fusion without Custom Ops: Adds support for allreduce fusion, even when custom operations are not enabled.

  • Thresholds Adjustment: Adjusts the thresholds for allreduce fusion to optimize performance across various configurations.

  • Dynamic Graph Dispatch via compile_ranges: Introduces a new configuration option, compile_ranges, as an alternative to compile_sizes. This enables dynamic dispatch to different compiled graphs based on the input batch size.
    With this approach, when allreduce fusion is enabled, vLLM adds an additional compile-range split point to separate the graphs (see the illustrative config sketch after this list):

  1. One with fused allreduce for small- to medium-shape inputs
  2. One with NCCL-based allreduce for large-shape inputs
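
As a rough illustration (not a prescribed setup), the new option could be passed through the compilation config as in the sketch below. The compile_ranges_split_points field name comes from this PR; the model name, TP size, and optimization level are placeholder assumptions.

```python
# Illustrative sketch only: enable range-based compilation via the
# compile_ranges_split_points field introduced in this PR.
from vllm import LLM
from vllm.config import CompilationConfig

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    tensor_parallel_size=2,                    # placeholder TP size
    compilation_config=CompilationConfig(
        level=3,                               # piecewise torch.compile
        compile_ranges_split_points=[32, 64],  # ranges: [1,32), [32,64), [64, max)
    ),
)
```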

Detailed Breakdown

1. Allreduce Fusion without Custom Operations

This change extends the allreduce fusion pattern matching to work without custom ops, supporting the torch-native rmsnorm and quant_fp8 implementations. It allows keeping the more performant non-fused operations in the graph.
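
For context, here is a minimal sketch of the kind of torch-native implementations the pattern matcher now targets. This is illustrative only; the function names and details are assumptions, not vLLM's exact rmsnorm/quant_fp8 code.

```python
import torch

def rms_norm_native(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm expressed with plain PyTorch ops, traceable by the pattern matcher.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(weight.dtype)

def quant_fp8_native(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Static-scale FP8 quantization expressed with native ops.
    finfo = torch.finfo(torch.float8_e4m3fn)
    q = (x.float() / scale).clamp(finfo.min, finfo.max)
    return q.to(torch.float8_e4m3fn)
```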

2. Dynamic Graph Dispatch with compile_ranges

The existing compile_sizes feature is extended and generalized with compile_ranges. Defined by split points, these ranges allow vLLM to dynamically dispatch requests to specific, pre-compiled graphs based on the input batch size. For example, a configuration of (32, 64) defines three distinct ranges: [1, 32), [32, 64), and [64, max_num_batched_tokens). This provides granular control, allowing developers to statically enable or disable fusions within each graph to optimize performance for different batch sizes.
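
A small, hypothetical helper (the name is made up for illustration) makes the mapping from split points to ranges concrete:

```python
def ranges_from_split_points(split_points, max_num_batched_tokens):
    # Half-open ranges between consecutive boundaries, starting at 1.
    bounds = [1, *sorted(split_points), max_num_batched_tokens]
    return list(zip(bounds, bounds[1:]))

print(ranges_from_split_points((32, 64), 8192))
# [(1, 32), (32, 64), (64, 8192)]
```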

Motivation

Corresponding RFC: #23113
The primary motivation for these changes is to enhance vLLM's performance and adaptability for diverse workloads. By supporting allreduce fusion without custom ops and introducing dynamic graph dispatch, we empower users to fine-tune vLLM for more efficient and scalable inference.

Update

The PR will be split up following @ProExpertProg's recommendation.
The first part adds torch-native ops fusion support: #24248
The second part adds conditional compilation based on ranges: #24252

github-actions bot commented Aug 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist bot (Contributor) left a comment:

Code Review

This pull request provides several follow-up changes for Allreduce fusion, including cleaning up the fusion pass, moving the all-reduce operation out of the fused_moe custom op, and disabling a test for FP4 fusion. The refactoring in fused_moe/layer.py is a good improvement for modularity. However, I've identified a critical issue in vllm/compilation/collective_fusion.py where the fallback path for FP8 quantization appears to have been unintentionally removed. This needs to be addressed to ensure correctness.

Comment on lines 491 to 486
torch.ops._C.scaled_fp4_quant(quant_out, norm_out,
                              scale_out, scale_factor)

critical

It appears the logic for FP8 quantization in the fallback path has been unintentionally removed. The original code correctly handled both FP4 and FP8 quantization by checking if scale_out is None. The new code unconditionally calls torch.ops._C.scaled_fp4_quant, which will likely fail at runtime when scale_out is None during FP8 quantization.

This change breaks FP8 quantization in the non-fused path. Please restore the conditional logic to support both FP4 and FP8 quantization.

Suggested change
torch.ops._C.scaled_fp4_quant(quant_out, norm_out,
                              scale_out, scale_factor)
if scale_out is not None:
    torch.ops._C.scaled_fp4_quant(quant_out, norm_out,
                                  scale_out, scale_factor)
else:
    torch.ops._C.static_scaled_fp8_quant(
        quant_out, norm_out, scale_factor)

@ilmarkov (Contributor, Author):

We only use the fused fp8 logic; we don't go into this branch in the case of fp8 quantization.

@ilmarkov (Contributor, Author):

The fp8 path is intentionally removed in this if-else branch

mergify bot commented Aug 1, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ilmarkov.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 1, 2025
@ilmarkov ilmarkov force-pushed the imarkov/flashinfer_allreduce_fusion_follow_ups branch from 903664f to 1ff71f7 Compare August 11, 2025 09:31
@mergify mergify bot added performance Performance-related issues and removed needs-rebase labels Aug 11, 2025
@ilmarkov (Contributor, Author):

Results from the isolated fused allreduce benchmark:

torchrun --nproc_per_node=$n benchmarks/kernels/benchmark_fused_collective.py [--no-quant|--quant-fp8|--quant-fp4] --hidden-dim 8192 --seq-len 32 64 96 128 160 192 256 --trials 100

H100 TP=2 No quant

Configuration: seq_len=32, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.010 | 0.96x |
| Standard Allreduce Rmsnorm Native Compiled | 0.010 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.010 | 1.02x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.079 | 0.13x |

Configuration: seq_len=64, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.012 | baseline |
| Standard Allreduce Rmsnorm Native Compiled | 0.012 | 0.98x |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.016 | 0.79x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.085 | 0.14x |

Configuration: seq_len=128, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.017 | 1.00x |
| Standard Allreduce Rmsnorm Native Compiled | 0.017 | 0.98x |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.030 | 0.55x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.099 | 0.17x |

H100 TP=4 no quant

Configuration: seq_len=32, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.017 | 0.97x |
| Standard Allreduce Rmsnorm Native Compiled | 0.016 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.013 | 1.27x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.084 | 0.20x |

Configuration: seq_len=64, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.019 | 1.00x |
| Standard Allreduce Rmsnorm Native Compiled | 0.019 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.023 | 0.83x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.082 | 0.23x |

Configuration: seq_len=128, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.024 | 1.00x |
| Standard Allreduce Rmsnorm Native Compiled | 0.025 | 0.97x |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.044 | 0.55x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.096 | 0.25x |

H100 TP=2 quant fp8

Configuration: seq_len=32, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp8 Quant | 0.011 | 0.97x |
| Standard Allreduce Rmsnorm Fp8 Quant Native Compiled | 0.011 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Oneshot | 0.010 | 1.08x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Twoshot | 0.078 | 0.14x |

Configuration: seq_len=64, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp8 Quant | 0.013 | 1.00x |
| Standard Allreduce Rmsnorm Fp8 Quant Native Compiled | 0.013 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Oneshot | 0.016 | 0.84x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Twoshot | 0.087 | 0.15x |

Configuration: seq_len=128, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp8 Quant | 0.018 | 0.99x |
| Standard Allreduce Rmsnorm Fp8 Quant Native Compiled | 0.017 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Oneshot | 0.030 | 0.58x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Twoshot | 0.096 | 0.18x |

H100 TP=4 fp8

Configuration: seq_len=32, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp8 Quant | 0.018 | 0.97x |
| Standard Allreduce Rmsnorm Fp8 Quant Native Compiled | 0.017 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Oneshot | 0.013 | 1.31x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Twoshot | 0.085 | 0.20x |

Configuration: seq_len=64, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp8 Quant | 0.020 | 0.99x |
| Standard Allreduce Rmsnorm Fp8 Quant Native Compiled | 0.020 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Oneshot | 0.023 | 0.84x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Twoshot | 0.082 | 0.24x |

Configuration: seq_len=128, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp8 Quant | 0.026 | 1.00x |
| Standard Allreduce Rmsnorm Fp8 Quant Native Compiled | 0.026 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Oneshot | 0.045 | 0.58x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Twoshot | 0.097 | 0.26x |

@ilmarkov (Contributor, Author):

B200 TP=2 no quant

Configuration: seq_len=32, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.014 | 0.99x |
| Standard Allreduce Rmsnorm Native Compiled | 0.013 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.008 | 1.76x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.031 | 0.44x |

Configuration: seq_len=64, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.018 | 1.00x |
| Standard Allreduce Rmsnorm Native Compiled | 0.018 | 0.98x |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.012 | 1.50x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.040 | 0.44x |

Configuration: seq_len=128, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.025 | 1.00x |
| Standard Allreduce Rmsnorm Native Compiled | 0.025 | 0.99x |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.020 | 1.22x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.049 | 0.51x |

Configuration: seq_len=256, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.040 | 0.98x |
| Standard Allreduce Rmsnorm Native Compiled | 0.039 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.037 | 1.05x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.071 | 0.55x |

B200 TP=4 no quant

Configuration: seq_len=32, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.022 | 1.00x |
| Standard Allreduce Rmsnorm Native Compiled | 0.022 | 0.98x |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.011 | 2.08x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.044 | 0.51x |

Configuration: seq_len=64, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.022 | 1.00x |
| Standard Allreduce Rmsnorm Native Compiled | 0.023 | 0.98x |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.017 | 1.30x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.038 | 0.58x |

Configuration: seq_len=128, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.031 | 1.00x |
| Standard Allreduce Rmsnorm Native Compiled | 0.032 | 0.98x |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.032 | 0.97x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.049 | 0.64x |

Configuration: seq_len=256, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.050 | 0.98x |
| Standard Allreduce Rmsnorm Native Compiled | 0.049 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.060 | 0.82x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.069 | 0.71x |

B200 TP=8 no quant

Configuration: seq_len=32, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.030 | 1.00x |
| Standard Allreduce Rmsnorm Native Compiled | 0.030 | 0.99x |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.014 | 2.22x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.086 | 0.35x |

Configuration: seq_len=64, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.031 | 1.00x |
| Standard Allreduce Rmsnorm Native Compiled | 0.031 | 0.99x |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.025 | 1.25x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.054 | 0.57x |

Configuration: seq_len=128, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.033 | 1.00x |
| Standard Allreduce Rmsnorm Native Compiled | 0.033 | 0.99x |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.047 | 0.69x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.052 | 0.63x |

Configuration: seq_len=256, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.052 | 0.99x |
| Standard Allreduce Rmsnorm Native Compiled | 0.051 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.093 | 0.55x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.071 | 0.72x |

B200 TP=2 fp8

Configuration: seq_len=32, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp8 Quant | 0.014 | 1.00x |
| Standard Allreduce Rmsnorm Fp8 Quant Native Compiled | 0.014 | 1.00x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Oneshot | 0.008 | 1.79x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Twoshot | 0.032 | 0.45x |

Configuration: seq_len=64, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp8 Quant | 0.018 | 1.00x |
| Standard Allreduce Rmsnorm Fp8 Quant Native Compiled | 0.018 | 0.98x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Oneshot | 0.012 | 1.54x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Twoshot | 0.039 | 0.46x |

Configuration: seq_len=128, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp8 Quant | 0.026 | 1.00x |
| Standard Allreduce Rmsnorm Fp8 Quant Native Compiled | 0.026 | 1.00x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Oneshot | 0.021 | 1.24x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Twoshot | 0.048 | 0.53x |

Configuration: seq_len=256, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp8 Quant | 0.041 | 0.97x |
| Standard Allreduce Rmsnorm Fp8 Quant Native Compiled | 0.040 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Oneshot | 0.038 | 1.06x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Twoshot | 0.071 | 0.56x |

B200 TP=4 fp8

Configuration: seq_len=32, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp8 Quant | 0.023 | 1.00x |
| Standard Allreduce Rmsnorm Fp8 Quant Native Compiled | 0.023 | 0.98x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Oneshot | 0.010 | 2.25x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Twoshot | 0.134 | 0.17x |

Configuration: seq_len=64, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp8 Quant | 0.023 | 1.00x |
| Standard Allreduce Rmsnorm Fp8 Quant Native Compiled | 0.141 | 0.16x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Oneshot | 0.016 | 1.46x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Twoshot | 0.039 | 0.59x |

Configuration: seq_len=128, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp8 Quant | 0.032 | 1.00x |
| Standard Allreduce Rmsnorm Fp8 Quant Native Compiled | 0.033 | 0.98x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Oneshot | 0.030 | 1.06x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Twoshot | 0.050 | 0.64x |

Configuration: seq_len=256, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp8 Quant | 0.051 | 0.98x |
| Standard Allreduce Rmsnorm Fp8 Quant Native Compiled | 0.050 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Oneshot | 0.055 | 0.91x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Twoshot | 0.070 | 0.71x |

B200 TP=8 fp8

Configuration: seq_len=32, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp8 Quant | 0.031 | 1.00x |
| Standard Allreduce Rmsnorm Fp8 Quant Native Compiled | 0.031 | 0.98x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Oneshot | 0.014 | 2.22x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Twoshot | 0.085 | 0.36x |

Configuration: seq_len=64, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp8 Quant | 0.032 | 1.00x |
| Standard Allreduce Rmsnorm Fp8 Quant Native Compiled | 0.032 | 0.98x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Oneshot | 0.025 | 1.28x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Twoshot | 0.056 | 0.56x |

Configuration: seq_len=128, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp8 Quant | 0.033 | 1.00x |
| Standard Allreduce Rmsnorm Fp8 Quant Native Compiled | 0.034 | 0.98x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Oneshot | 0.047 | 0.69x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Twoshot | 0.053 | 0.62x |

Configuration: seq_len=256, dtype=bfloat16, no residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp8 Quant | 0.054 | 0.97x |
| Standard Allreduce Rmsnorm Fp8 Quant Native Compiled | 0.052 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Oneshot | 0.095 | 0.55x |
| Flashinfer Fused Allreduce Rmsnorm Fp8 Quant Twoshot | 0.081 | 0.64x |

B200 TP=2 fp4

Configuration: seq_len=32, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp4 Quant | 0.016 | 0.99x |
| Standard Allreduce Rmsnorm Fp4 Quant Native Compiled | 0.016 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Oneshot | 0.009 | 1.72x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Twoshot | 0.033 | 0.48x |

Configuration: seq_len=64, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp4 Quant | 0.020 | 1.00x |
| Standard Allreduce Rmsnorm Fp4 Quant Native Compiled | 0.020 | 0.99x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Oneshot | 0.014 | 1.47x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Twoshot | 0.041 | 0.49x |

Configuration: seq_len=128, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp4 Quant | 0.028 | 1.00x |
| Standard Allreduce Rmsnorm Fp4 Quant Native Compiled | 0.028 | 0.99x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Oneshot | 0.025 | 1.09x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Twoshot | 0.051 | 0.54x |

Configuration: seq_len=256, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp4 Quant | 0.043 | 0.97x |
| Standard Allreduce Rmsnorm Fp4 Quant Native Compiled | 0.042 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Oneshot | 0.046 | 0.91x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Twoshot | 0.075 | 0.55x |

B200 TP=4 fp4

Configuration: seq_len=32, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp4 Quant | 0.024 | 1.00x |
| Standard Allreduce Rmsnorm Fp4 Quant Native Compiled | 0.024 | 1.00x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Oneshot | 0.011 | 2.22x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Twoshot | 0.043 | 0.57x |

Configuration: seq_len=64, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp4 Quant | 0.025 | 1.00x |
| Standard Allreduce Rmsnorm Fp4 Quant Native Compiled | 0.025 | 0.99x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Oneshot | 0.018 | 1.40x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Twoshot | 0.039 | 0.63x |

Configuration: seq_len=128, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp4 Quant | 0.034 | 1.00x |
| Standard Allreduce Rmsnorm Fp4 Quant Native Compiled | 0.034 | 0.99x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Oneshot | 0.034 | 1.01x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Twoshot | 0.051 | 0.66x |

Configuration: seq_len=256, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp4 Quant | 0.053 | 0.98x |
| Standard Allreduce Rmsnorm Fp4 Quant Native Compiled | 0.052 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Oneshot | 0.063 | 0.82x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Twoshot | 0.071 | 0.72x |

B200 TP=8 fp4

Configuration: seq_len=32, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp4 Quant | 0.033 | 1.00x |
| Standard Allreduce Rmsnorm Fp4 Quant Native Compiled | 0.033 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Oneshot | 0.014 | 2.37x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Twoshot | 0.088 | 0.37x |

Configuration: seq_len=64, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp4 Quant | 0.033 | 1.00x |
| Standard Allreduce Rmsnorm Fp4 Quant Native Compiled | 0.033 | 1.00x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Oneshot | 0.025 | 1.35x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Twoshot | 0.058 | 0.58x |

Configuration: seq_len=128, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp4 Quant | 0.035 | 1.00x |
| Standard Allreduce Rmsnorm Fp4 Quant Native Compiled | 0.036 | 0.98x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Oneshot | 0.047 | 0.74x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Twoshot | 0.054 | 0.65x |

Configuration: seq_len=256, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm Fp4 Quant | 0.055 | 0.98x |
| Standard Allreduce Rmsnorm Fp4 Quant Native Compiled | 0.054 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Oneshot | 0.093 | 0.58x |
| Flashinfer Fused Allreduce Rmsnorm Fp4 Quant Twoshot | 0.075 | 0.72x |

These results show that the one-shot FlashInfer fusion can give up to a 2x speedup on small inputs, whereas two-shot is always slower than the non-fused alternatives.
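
In practice this boils down to a size-based cutoff. A minimal sketch of such a check is below; the constant and function names are hypothetical, and the actual tuned thresholds live in the fusion pass and its config.

```python
FI_ONESHOT_MAX_BYTES = 1 * 1024 * 1024  # example cutoff, not the tuned value

def use_flashinfer_oneshot(num_tokens: int, hidden_dim: int,
                           dtype_bytes: int = 2) -> bool:
    # The decision is based on the input tensor size in bytes,
    # not on the token count alone.
    return num_tokens * hidden_dim * dtype_bytes <= FI_ONESHOT_MAX_BYTES
```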

@ilmarkov (Contributor, Author):

To confirm that the threshold has to be based on tensor size rather than the number of tokens, here are the benchmark results for hidden size 1024.

B200 TP=2 no quant

Configuration: seq_len=32, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.010 | 0.89x |
| Standard Allreduce Rmsnorm Native Compiled | 0.009 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.005 | 1.90x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.121 | 0.08x |

Configuration: seq_len=64, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.010 | 0.94x |
| Standard Allreduce Rmsnorm Native Compiled | 0.010 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.005 | 1.93x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.065 | 0.15x |

Configuration: seq_len=128, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.011 | 0.90x |
| Standard Allreduce Rmsnorm Native Compiled | 0.010 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.005 | 1.85x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.037 | 0.26x |

Configuration: seq_len=256, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.013 | 0.95x |
| Standard Allreduce Rmsnorm Native Compiled | 0.012 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.006 | 1.95x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.022 | 0.55x |

Configuration: seq_len=512, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.017 | 0.97x |
| Standard Allreduce Rmsnorm Native Compiled | 0.016 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.008 | 2.12x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.027 | 0.60x |

Configuration: seq_len=1024, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.024 | 0.98x |
| Standard Allreduce Rmsnorm Native Compiled | 0.024 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.011 | 2.24x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.035 | 0.69x |

Configuration: seq_len=2048, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.040 | 0.97x |
| Standard Allreduce Rmsnorm Native Compiled | 0.039 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.020 | 1.89x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.045 | 0.86x |

B200 TP=4 no quant

Configuration: seq_len=32, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.011 | 0.94x |
| Standard Allreduce Rmsnorm Native Compiled | 0.010 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.005 | 1.83x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.238 | 0.04x |

Configuration: seq_len=64, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.011 | 0.94x |
| Standard Allreduce Rmsnorm Native Compiled | 0.010 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.006 | 1.71x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.133 | 0.08x |

Configuration: seq_len=128, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.011 | 0.95x |
| Standard Allreduce Rmsnorm Native Compiled | 0.011 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.006 | 1.68x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.068 | 0.16x |

Configuration: seq_len=256, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.021 | 0.98x |
| Standard Allreduce Rmsnorm Native Compiled | 0.020 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.009 | 2.29x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.040 | 0.51x |

Configuration: seq_len=512, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.022 | 0.98x |
| Standard Allreduce Rmsnorm Native Compiled | 0.022 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.012 | 1.73x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.027 | 0.81x |

Configuration: seq_len=1024, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.031 | 1.00x |
| Standard Allreduce Rmsnorm Native Compiled | 0.032 | 0.99x |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.021 | 1.50x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.035 | 0.91x |

Configuration: seq_len=2048, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.050 | 0.98x |
| Standard Allreduce Rmsnorm Native Compiled | 0.049 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.039 | 1.25x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.049 | 1.00x |

B200 TP=8 no quant

Configuration: seq_len=32, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.011 | 0.95x |
| Standard Allreduce Rmsnorm Native Compiled | 0.011 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.006 | 1.82x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.512 | 0.02x |

Configuration: seq_len=64, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.012 | 0.95x |
| Standard Allreduce Rmsnorm Native Compiled | 0.012 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.007 | 1.72x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.265 | 0.04x |

Configuration: seq_len=128, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.029 | 0.98x |
| Standard Allreduce Rmsnorm Native Compiled | 0.028 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.008 | 3.39x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.141 | 0.20x |

Configuration: seq_len=512, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.030 | 0.98x |
| Standard Allreduce Rmsnorm Native Compiled | 0.030 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.019 | 1.55x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.048 | 0.62x |

Configuration: seq_len=1024, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.032 | 0.98x |
| Standard Allreduce Rmsnorm Native Compiled | 0.032 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.036 | 0.88x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.035 | 0.91x |

Configuration: seq_len=2048, dtype=bfloat16, with residual

| Operation | Time (ms) | Speedup |
|---|---|---|
| Standard Allreduce Rmsnorm | 0.052 | 0.98x |
| Standard Allreduce Rmsnorm Native Compiled | 0.050 | baseline |
| Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.070 | 0.72x |
| Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.052 | 0.97x |

@nvpohanh (Contributor):

Hi @ilmarkov, thanks for working on this! I am just wondering whether you will continue to work on it? If so, do you have a rough estimate of when this PR will be ready? Just want to avoid duplicated work from our side. Thanks!

@ilmarkov (Contributor, Author):

Hi @nvpohanh. Yes, I am still working on this. The plan is to enable FlashInfer allreduce fusion by default. At the moment, we can't do that due to the performance of the fallback path (for medium and large input sizes).
We are looking into ways to use a torch.compile'd fallback.
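
For illustration only, a torch.compile'd fallback for the non-fused path could look roughly like the sketch below; whether the collective stays inside the compiled graph or causes a graph break depends on the backend, and the names here are assumptions rather than vLLM code.

```python
import torch
import torch.distributed as dist

def allreduce_rmsnorm_fallback(x, residual, weight, eps: float = 1e-6):
    # Non-fused path: NCCL allreduce followed by torch-native add + RMSNorm,
    # compiled so Inductor can still fuse the surrounding elementwise work.
    dist.all_reduce(x)
    x = x + residual
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    out = (x * torch.rsqrt(variance + eps)).to(weight.dtype) * weight
    return out, x

compiled_fallback = torch.compile(allreduce_rmsnorm_fallback, dynamic=True)
```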

mergify bot commented Aug 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ilmarkov.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 19, 2025
@nvpohanh (Contributor):

@ilmarkov Could you re-tune the thresholds using a newer FlashInfer with this fix? flashinfer-ai/flashinfer#1507

We found that FlashInfer uses a suboptimal CUDA grid/block size when calling the AllReduce+RMSNorm kernel; the fix can lead to a >2x perf boost for some cases we tested. Thanks!

@ilmarkov ilmarkov force-pushed the imarkov/flashinfer_allreduce_fusion_follow_ups branch from c269312 to 3e4e159 Compare August 20, 2025 14:34
@mergify mergify bot removed the needs-rebase label Aug 20, 2025
@nvjullin (Contributor) commented Aug 22, 2025

I ran benchmarks/kernels/benchmark_fused_collective.py with hidden_dim 1024/8192, TP 2/4, and seq_len 1/2/4/8/16/32/64/128/256/512 again with flashinfer-ai/flashinfer#1507 and found that on B200:

  1. oneshot is always faster than native
  2. for hidden_dim 8192, TP4, seq_len 512, twoshot is faster than oneshot

no-quant-tp2-hd1024.txt
no-quant-tp2-hd8192.txt
no-quant-tp4-hd1024.txt
no-quant-tp4-hd8192.txt
quant-fp4-tp2-hd1024.txt
quant-fp4-tp2-hd8192.txt
quant-fp4-tp4-hd1024.txt
quant-fp4-tp4-hd8192.txt
quant-fp8-tp2-hd1024.txt
quant-fp8-tp2-hd8192.txt
quant-fp8-tp4-hd1024.txt
quant-fp8-tp4-hd8192.txt

@ilmarkov I think we can simply default to flashinfer oneshot now that it's always faster (on B200 at least)?

@ilmarkov (Contributor, Author):

@nvjullin Thank you for the benchmark results. Yes, we default to oneshot. I will update the thresholds accordingly.

mergify bot commented Aug 26, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ilmarkov.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 26, 2025
@ilmarkov ilmarkov force-pushed the imarkov/flashinfer_allreduce_fusion_follow_ups branch from ad72306 to e68e14f Compare August 26, 2025 09:51
@mergify mergify bot removed the needs-rebase label Aug 26, 2025

mergify bot commented Aug 27, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ilmarkov.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

ilmarkov and others added 12 commits September 1, 2025 12:56
Signed-off-by: ilmarkov <[email protected]>
Update range based compilation

Signed-off-by: ilmarkov <[email protected]>
Signed-off-by: ilmarkov <[email protected]>
Signed-off-by: ilmarkov <[email protected]>
Signed-off-by: ilmarkov <[email protected]>
Signed-off-by: ilmarkov <[email protected]>
Signed-off-by: ilmarkov <[email protected]>
@ilmarkov ilmarkov force-pushed the imarkov/flashinfer_allreduce_fusion_follow_ups branch from 1cd7acf to ca9f59e Compare September 1, 2025 12:56
@mergify mergify bot removed the needs-rebase label Sep 1, 2025
Signed-off-by: ilmarkov <[email protected]>
@ilmarkov (Contributor, Author) commented Sep 2, 2025

The failing basic correctness and entrypoints tests are due to a torch.standalone_compile issue, which was fixed in pytorch/pytorch#157803.

@ProExpertProg (Collaborator) left a comment:

Good job getting all of this to work! I left a few minor comments. A couple of general things:

  • There are a lot of formatting changes, which make the PR hard to read; could you please remove those?
  • We should split the compile-ranges part from the other allreduce fixes. If it's easier, we can do compile ranges without any allreduce changes and then follow up with everything allreduce-related. Alternatively, we can put the custom-op matching and other things orthogonal to compile ranges in a separate PR and land it first, although I'm working on making that easier anyway.
  • TODO: tell Inductor about the compile range (should ask Meta about it)
  • instead of having a single dynamic shape graph, could we just use a single (1, max_num_batched_tokens) graph if no split points are passed? And so compile_range would never be Optional? Maybe that's something for a follow-up.

torch._inductor.pattern_matcher._seen_patterns.clear()

def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
def is_applicable_for_range(
Collaborator:

Given that the range is accessible through the context anyway, should we just get rid of this method and let each pass exit early inside __call__? That would simplify the interface.

@ilmarkov (Contributor, Author):

It's already done inside __call__ in PostGradPassManager in a unified manner, and the range is taken from the context. With this approach, we can more easily add tracking of which passes are enabled in a given graph.
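
A hypothetical sketch of the two designs being discussed (the class and method names are made up, not vLLM's actual passes): a pass can either expose an applicability check or simply read the active compile range from the context and exit early.

```python
from typing import Optional

CompileRange = tuple[int, int]  # [start, end) in tokens

class FusionPassSketch:
    def __init__(self, max_fused_tokens: int):
        self.max_fused_tokens = max_fused_tokens

    def is_applicable_for_range(self, r: Optional[CompileRange]) -> bool:
        # Only worth running if the range can contain small/medium batches.
        return r is None or r[0] < self.max_fused_tokens

    def __call__(self, graph, compile_range: Optional[CompileRange]):
        # Early-exit variant: the pass manager provides the range via context
        # and each pass decides for itself whether to do any rewriting.
        if not self.is_applicable_for_range(compile_range):
            return graph
        # ... pattern matching / graph rewriting would happen here ...
        return graph
```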

else:
self.compile_ranges.append((split_points[i - 1], s))
if s in self.compile_sizes:
self.compile_ranges.append((s, s))
Collaborator:

If 1 is in compile_sizes, doesn't it get appended twice here? It might be good to extract this conversion into a utility method we can test separately.
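
A hypothetical version of such a utility (names and exact semantics are guesses based on this discussion, intended only to show how it could be unit-tested):

```python
def build_compile_ranges(split_points: list[int], compile_sizes: set[int],
                         max_num_batched_tokens: int) -> list[tuple[int, int]]:
    """Build half-open compile ranges from split points, inserting (s, s)
    singleton ranges for exact compile sizes without duplicating entries."""
    bounds = sorted({1, max_num_batched_tokens, *split_points, *compile_sizes})
    ranges: list[tuple[int, int]] = []
    for lo, hi in zip(bounds, bounds[1:]):
        if lo in compile_sizes:
            ranges.append((lo, lo))
        ranges.append((lo, hi))
    return list(dict.fromkeys(ranges))  # dedupe while preserving order

# e.g. 1 in compile_sizes yields a single (1, 1) range, not two.
assert build_compile_ranges([32, 64], {1}, 8192) == [
    (1, 1), (1, 32), (32, 64), (64, 8192)]
```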

if isinstance(compile_range, tuple):
    # for a specific range of batchsizes, tuning triton kernel parameters
    # can be beneficial
    config["max_autotune"] = True
Collaborator:

Can you add a TODO(luka): max autotune only present with -O3, and this should live in config: https://github.com/vllm-project/vllm/issues/20283

to integers, it also supports "cudagraph_capture_sizes" to
specify the sizes for cudagraph capture."""
compile_ranges_split_points: Optional[list[int]] = None
"""Split points that represent compile ranges for inductor.
Collaborator:

Mention that compile_sizes also plays a role in the splitting.

GroupShape)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils import (_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES,
Collaborator:

I don't think the FI stuff should live in utils

return [input, weight]

def register(self, pm_pass: PatternMatcherPass):
if not self.is_custom_rms_norm:
Collaborator:

Seems like these two should just be separated into separate pattern classes

mergify bot commented Sep 3, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ilmarkov.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 3, 2025
Signed-off-by: ilmarkov <[email protected]>
@ilmarkov ilmarkov force-pushed the imarkov/flashinfer_allreduce_fusion_follow_ups branch from 0a8c72b to 82057db Compare September 3, 2025 14:02
@ProExpertProg ProExpertProg moved this from To triage to In progress in torch.compile integration Sep 3, 2025
@ilmarkov (Contributor, Author) commented Sep 4, 2025

The PR will be split up following @ProExpertProg's recommendation.
The first part adds torch-native ops fusion support: #24248
