Skip to content

Conversation

ilmarkov
Copy link
Contributor

@ilmarkov ilmarkov commented Sep 4, 2025

First part of spliting #22086

Purpose

Adds support of matching to not only custom ops. So now in order to enable allreduce fusion users don't have to enable rms norm and quant_fp8 custom ops. Also, it allows to keep torch native operations in non-fused parts which are known to be more performant than custom ops alternatives.

Adds a benchmark for allreduce fusion to determine input size thresholds for flashinfer allreduce.
Updates thresholds for flashinfer allreduce (as well as adding two shot algorithm usage when it has better performance) on Hopper and Blackwell devices

Moves allreduce out of moe_forward custom op in order to be able to match for fusion for moe models.

Test Plan

Added tests for non custom ops fusion

Copy link

mergify bot commented Sep 4, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ilmarkov.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request is a significant enhancement to the all-reduce fusion capabilities, adding support for matching native PyTorch operations in addition to custom ops. This greatly improves usability and performance flexibility. The introduction of a comprehensive benchmark for tuning fusion thresholds is also a valuable addition. The changes are extensive, particularly with the large number of new fusion patterns in vllm/compilation/collective_fusion.py. While the overall approach is sound, I've identified several critical issues in the implementation of these new patterns. Specifically, the return values from some pattern and replacement functions appear to be incorrect, which could lead to fusion failures or incorrect model outputs. I've provided detailed comments and suggestions for these issues. The configuration updates and the new benchmark script are well-implemented and welcome improvements.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The return values from the replacement function are incorrect. The pattern returns (rms_output, allreduce_output), which correspond to the normalized output and the all-reduced tensor. The replacement function should return the same structure.

auto_functionalized(flashinfer_trtllm_fused_allreduce_norm, ...) returns a tuple of 5 mutated arguments: (allreduce_in, residual, norm_out, quant_out, scale_out).

The rms_result corresponds to norm_out, which is allreduce[2].
The allreduce_in (which is input to the replacement function) corresponds to allreduce[0].

Therefore, the return statement should be return allreduce[2], allreduce[0].

The current code returns allreduce[3], allreduce[1], which corresponds to (quant_out, residual). This is incorrect and will lead to fusion failures or wrong results.

Suggested change
return allreduce[3], allreduce[1]
return allreduce[2], allreduce[0]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The return values from the replacement function are incorrect. The pattern returns (rms_output, rms_residual), which are the normalized output and the residual output. The replacement function should return the same structure.

When norm_out=None is passed to flashinfer_trtllm_fused_allreduce_norm, the allreduce_in tensor is used as the output buffer for the normalization result and is mutated. auto_functionalized will return a tuple where the first element (allreduce[0]) is the mutated allreduce_in (i.e., norm_out), and the second element (allreduce[1]) is the mutated residual.

Therefore, the correct return should be return allreduce[0], allreduce[1].

The current code returns allreduce[1], allreduce[2], which corresponds to (residual, norm_out). Since norm_out is None in the call, this is incorrect.

Suggested change
return allreduce[1], allreduce[2]
return allreduce[0], allreduce[1]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious: why is the threshold still so low for TP8? I think AR+Norm should have pretty good perf up to some larger message sizes for TP8?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering if this can be moved to a util file (like native_op_patterns.py or something like that) so that it can be reused by other fusions like RMSNorm+Q fusions. @ProExpertProg what do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could. I think @ProExpertProg will add another approach to support this kind of matching in #24188

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is it 1MB for TP8?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nvpohanh Here are the results for TP=8 Blackwell with torch symm mem (VLLM_ALLREDUCE_USE_SYMM_MEM=1) enabled (see the set of results below). I used the best performant alternative to fused allreduce. Probably, we can condition on it checking if symm mem is available and enabled, it will overcomplicate the configuration, in my opinion. Compared default allreduce flashinfer fused alternative is not significantly better in 4-16MB region (see results in the end)

Symm mem enabled

World Size: 8
Hidden Dimension: 8192
Warmup Iterations: 5
Benchmark Trials: 20
Quantization Mode: none


Configuration: seq_len=32, dtype=bfloat16, no residual

Input Size: 0.50 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.029 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.030 0.99x
Flashinfer Fused Allreduce Rmsnorm Oneshot 0.012 2.39x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.086 0.34x

Configuration: seq_len=64, dtype=bfloat16, no residual

Input Size: 1.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.030 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.030 0.99x
Flashinfer Fused Allreduce Rmsnorm Oneshot 0.018 1.62x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.056 0.54x

Configuration: seq_len=128, dtype=bfloat16, no residual

Input Size: 2.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.023 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.024 0.99x
Flashinfer Fused Allreduce Rmsnorm Oneshot 0.033 0.71x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.052 0.45x

Configuration: seq_len=256, dtype=bfloat16, no residual

Input Size: 4.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.031 0.97x
Standard Allreduce Rmsnorm Native Compiled 0.030 baseline
Flashinfer Fused Allreduce Rmsnorm Oneshot 0.064 0.47x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.050 0.60x

Configuration: seq_len=256, dtype=bfloat16, no residual

Input Size: 4.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.031 0.97x
Standard Allreduce Rmsnorm Native Compiled 0.030 baseline
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.049 0.61x

Configuration: seq_len=512, dtype=bfloat16, no residual

Input Size: 8.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.044 0.98x
Standard Allreduce Rmsnorm Native Compiled 0.043 baseline
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.297 0.15x

Configuration: seq_len=1024, dtype=bfloat16, no residual

Input Size: 16.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.071 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.077 0.93x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.109 0.66x

Configuration: seq_len=2048, dtype=bfloat16, no residual

Input Size: 32.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.135 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.143 0.94x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.205 0.66x

Default allreduce

Configuration: seq_len=32, dtype=bfloat16, no residual

Input Size: 0.50 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.029 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.030 0.99x
Flashinfer Fused Allreduce Rmsnorm Oneshot 0.012 2.44x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.087 0.34x

Configuration: seq_len=64, dtype=bfloat16, no residual

Input Size: 1.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.030 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.030 1.00x
Flashinfer Fused Allreduce Rmsnorm Oneshot 0.019 1.63x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.056 0.54x

Configuration: seq_len=128, dtype=bfloat16, no residual

Input Size: 2.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.032 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.032 1.00x
Flashinfer Fused Allreduce Rmsnorm Oneshot 0.033 0.97x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.052 0.62x

Configuration: seq_len=256, dtype=bfloat16, no residual

Input Size: 4.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.051 0.98x
Standard Allreduce Rmsnorm Native Compiled 0.050 baseline
Flashinfer Fused Allreduce Rmsnorm Oneshot 0.064 0.77x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.050 1.00x

Configuration: seq_len=512, dtype=bfloat16, no residual

Input Size: 8.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.079 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.081 0.97x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.068 1.17x

Configuration: seq_len=1024, dtype=bfloat16, no residual

Input Size: 16.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.119 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.125 0.95x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.109 1.09x

Configuration: seq_len=2048, dtype=bfloat16, no residual

Input Size: 32.00 MB

Operation Time (ms) Speedup
Standard Allreduce Rmsnorm 0.195 1.00x
Standard Allreduce Rmsnorm Native Compiled 0.211 0.93x
Flashinfer Fused Allreduce Rmsnorm Twoshot 0.204 0.96x

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ilmarkov Is VLLM_ALLREDUCE_USE_SYMM_MEM=1 something that normal vLLM users would set by default? If it's good for performance, why can't we enable it by default? Does it require special environment or special builds? cc @ProExpertProg

@nvjullin Could you check if @ilmarkov 's measurements above match our understanding? Also, could you try if VLLM_ALLREDUCE_USE_SYMM_MEM=1 works in our case? Thanks!

Copy link
Contributor Author

@ilmarkov ilmarkov Sep 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it can be enabled by default. There is a PR for it. It works on Hopper and Blackwell.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it! we will try both your PRs and run some experiments on our side.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ilmarkov Just to clarify: the PyTorch SYMM_MEM implementation does not support AR+Norm fusion, right? So only the AR part uses SYMM_MEM while Norm part is based on native PyT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, symm mem is only for allreduce part, Norm and quantization parts are in native pytorch.

@nvpohanh
Copy link
Contributor

nvpohanh commented Sep 5, 2025

cc @nvjullin @elvischenv for vis

@ilmarkov ilmarkov force-pushed the imarkov/fused_allreduce_torch_native branch from e808818 to 61ebc95 Compare September 8, 2025 12:02
@mergify mergify bot removed the needs-rebase label Sep 8, 2025
Copy link

mergify bot commented Sep 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ilmarkov.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 10, 2025
input = input.to(torch.float32)
if residual is not None:
input = input + residual.to(torch.float32)
# residual = input.to(orig_dtype)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just found the issue with this, turns out Inductor eliminates this cast because residual is casted right back to float32 so it never needs to be down-converted.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
needs-rebase performance Performance-related issues
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants