[PERF] Allreduce fusion. Support torch native matching. Tuning of the thresholds #24248
Conversation
Code Review

This pull request is a significant enhancement to the all-reduce fusion capabilities, adding support for matching native PyTorch operations in addition to custom ops. This greatly improves usability and performance flexibility. The introduction of a comprehensive benchmark for tuning fusion thresholds is also a valuable addition. The changes are extensive, particularly the large number of new fusion patterns in `vllm/compilation/collective_fusion.py`. While the overall approach is sound, I've identified several critical issues in the implementation of these new patterns. Specifically, the return values from some `pattern` and `replacement` functions appear to be incorrect, which could lead to fusion failures or incorrect model outputs. I've provided detailed comments and suggestions for these issues. The configuration updates and the new benchmark script are well-implemented and welcome improvements.
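For readers less familiar with the mechanism being extended here, below is a generic sketch of how a pattern/replacement pair is registered with Inductor's pattern matcher. It uses a toy relu pattern for illustration only; it is not one of the patterns added in this PR.

```python
import torch
from torch._inductor.pattern_matcher import (
    PatternMatcherPass,
    fwd_only,
    register_replacement,
)

def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Subgraph to search for in the compiled FX graph.
    return torch.relu(x + y)

def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Stand-in for a fused kernel; a real pass would call a fused custom op here.
    return torch.clamp(x + y, min=0)

patterns = PatternMatcherPass()
example_inputs = [torch.empty(8, 8), torch.empty(8, 8)]
register_replacement(pattern, replacement, example_inputs, fwd_only, patterns)
```

The patterns added in this PR register pairs of this kind for the torch-native (non-custom-op) forms of RMSNorm and quantization, so that they can be rewritten into the fused flashinfer allreduce+norm op.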
The return values from the `replacement` function are incorrect. The `pattern` returns `(rms_output, allreduce_output)`, which correspond to the normalized output and the all-reduced tensor. The `replacement` function should return the same structure.

`auto_functionalized(flashinfer_trtllm_fused_allreduce_norm, ...)` returns a tuple of 5 mutated arguments: `(allreduce_in, residual, norm_out, quant_out, scale_out)`.

The `rms_result` corresponds to `norm_out`, which is `allreduce[2]`. The `allreduce_in` (which is `input` to the replacement function) corresponds to `allreduce[0]`. Therefore, the return statement should be `return allreduce[2], allreduce[0]`.

The current code returns `allreduce[3], allreduce[1]`, which corresponds to `(quant_out, residual)`. This is incorrect and will lead to fusion failures or wrong results.

Suggested change:

```diff
- return allreduce[3], allreduce[1]
+ return allreduce[2], allreduce[0]
```
The return values from the `replacement` function are incorrect. The `pattern` returns `(rms_output, rms_residual)`, which are the normalized output and the residual output. The `replacement` function should return the same structure.

When `norm_out=None` is passed to `flashinfer_trtllm_fused_allreduce_norm`, the `allreduce_in` tensor is used as the output buffer for the normalization result and is mutated. `auto_functionalized` will return a tuple where the first element (`allreduce[0]`) is the mutated `allreduce_in` (i.e., `norm_out`), and the second element (`allreduce[1]`) is the mutated `residual`. Therefore, the correct return should be `return allreduce[0], allreduce[1]`.

The current code returns `allreduce[1], allreduce[2]`, which corresponds to `(residual, norm_out)`. Since `norm_out` is `None` in the call, this is incorrect.

Suggested change:

```diff
- return allreduce[1], allreduce[2]
+ return allreduce[0], allreduce[1]
```
Just curious: why is the threshold still so low for TP8? I think AR+Norm should have pretty good perf up to some larger message sizes for TP8?
I am wondering if this can be moved to a util file (like `native_op_patterns.py` or something like that) so that it can be reused by other fusions like RMSNorm+Q fusions. @ProExpertProg what do you think?
We could. I think @ProExpertProg will add another approach to support this kind of matching in #24188
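For context, the torch-native RMSNorm form that such a shared helper (the hypothetically named `native_op_patterns.py` above) could expose looks roughly like the following. This is a minimal sketch, not the exact code from this PR:

```python
import torch

def rms_norm_native(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize in float32, as the native (non-custom-op) path does,
    # then cast back to the input dtype before applying the weight.
    x_f32 = x.to(torch.float32)
    x_f32 = x_f32 * torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x_f32.to(x.dtype) * weight
```

A pattern traced from a function like this can be matched regardless of whether the rms_norm custom op is enabled, which is the point of the torch-native matching added here.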
vllm/config/compilation.py
why is it 1MB for TP8?
@nvpohanh Here are the results for TP=8 on Blackwell with torch symm mem (`VLLM_ALLREDUCE_USE_SYMM_MEM=1`) enabled (see the first set of results below). I used the best-performing alternative to fused allreduce as the baseline. We could condition the threshold on whether symm mem is available and enabled, but in my opinion that would overcomplicate the configuration. Compared to the default allreduce, the flashinfer fused alternative is not significantly better in the 4-16 MB region (see the results at the end).
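For reference, the Input Size values in the tables below follow directly from the tensor shape in each configuration (seq_len x hidden dim of 8192, bfloat16 at 2 bytes per element); a quick sanity check under that interpretation:

```python
def input_size_mb(seq_len: int, hidden_dim: int = 8192, bytes_per_elem: int = 2) -> float:
    """Per-rank allreduce input size in MB for a bfloat16 activation."""
    return seq_len * hidden_dim * bytes_per_elem / (1024 ** 2)

assert input_size_mb(32) == 0.5   # matches the 0.50 MB row
assert input_size_mb(256) == 4.0  # matches the 4.00 MB row
```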
Symm mem enabled
World Size: 8
Hidden Dimension: 8192
Warmup Iterations: 5
Benchmark Trials: 20
Quantization Mode: none
Configuration: seq_len=32, dtype=bfloat16, no residual
Input Size: 0.50 MB
Operation | Time (ms) | Speedup |
---|---|---|
Standard Allreduce Rmsnorm | 0.029 | 1.00x |
Standard Allreduce Rmsnorm Native Compiled | 0.030 | 0.99x |
Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.012 | 2.39x |
Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.086 | 0.34x |
Configuration: seq_len=64, dtype=bfloat16, no residual
Input Size: 1.00 MB
Operation | Time (ms) | Speedup |
---|---|---|
Standard Allreduce Rmsnorm | 0.030 | 1.00x |
Standard Allreduce Rmsnorm Native Compiled | 0.030 | 0.99x |
Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.018 | 1.62x |
Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.056 | 0.54x |
Configuration: seq_len=128, dtype=bfloat16, no residual
Input Size: 2.00 MB
Operation | Time (ms) | Speedup |
---|---|---|
Standard Allreduce Rmsnorm | 0.023 | 1.00x |
Standard Allreduce Rmsnorm Native Compiled | 0.024 | 0.99x |
Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.033 | 0.71x |
Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.052 | 0.45x |
Configuration: seq_len=256, dtype=bfloat16, no residual
Input Size: 4.00 MB
Operation | Time (ms) | Speedup |
---|---|---|
Standard Allreduce Rmsnorm | 0.031 | 0.97x |
Standard Allreduce Rmsnorm Native Compiled | 0.030 | baseline |
Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.064 | 0.47x |
Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.050 | 0.60x |
Configuration: seq_len=256, dtype=bfloat16, no residual
Input Size: 4.00 MB
Operation | Time (ms) | Speedup |
---|---|---|
Standard Allreduce Rmsnorm | 0.031 | 0.97x |
Standard Allreduce Rmsnorm Native Compiled | 0.030 | baseline |
Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.049 | 0.61x |
Configuration: seq_len=512, dtype=bfloat16, no residual
Input Size: 8.00 MB
Operation | Time (ms) | Speedup |
---|---|---|
Standard Allreduce Rmsnorm | 0.044 | 0.98x |
Standard Allreduce Rmsnorm Native Compiled | 0.043 | baseline |
Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.297 | 0.15x |
Configuration: seq_len=1024, dtype=bfloat16, no residual
Input Size: 16.00 MB
Operation | Time (ms) | Speedup |
---|---|---|
Standard Allreduce Rmsnorm | 0.071 | 1.00x |
Standard Allreduce Rmsnorm Native Compiled | 0.077 | 0.93x |
Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.109 | 0.66x |
Configuration: seq_len=2048, dtype=bfloat16, no residual
Input Size: 32.00 MB
Operation | Time (ms) | Speedup |
---|---|---|
Standard Allreduce Rmsnorm | 0.135 | 1.00x |
Standard Allreduce Rmsnorm Native Compiled | 0.143 | 0.94x |
Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.205 | 0.66x |
Default allreduce
Configuration: seq_len=32, dtype=bfloat16, no residual
Input Size: 0.50 MB
Operation | Time (ms) | Speedup |
---|---|---|
Standard Allreduce Rmsnorm | 0.029 | 1.00x |
Standard Allreduce Rmsnorm Native Compiled | 0.030 | 0.99x |
Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.012 | 2.44x |
Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.087 | 0.34x |
Configuration: seq_len=64, dtype=bfloat16, no residual
Input Size: 1.00 MB
Operation | Time (ms) | Speedup |
---|---|---|
Standard Allreduce Rmsnorm | 0.030 | 1.00x |
Standard Allreduce Rmsnorm Native Compiled | 0.030 | 1.00x |
Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.019 | 1.63x |
Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.056 | 0.54x |
Configuration: seq_len=128, dtype=bfloat16, no residual
Input Size: 2.00 MB
Operation | Time (ms) | Speedup |
---|---|---|
Standard Allreduce Rmsnorm | 0.032 | 1.00x |
Standard Allreduce Rmsnorm Native Compiled | 0.032 | 1.00x |
Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.033 | 0.97x |
Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.052 | 0.62x |
Configuration: seq_len=256, dtype=bfloat16, no residual
Input Size: 4.00 MB
Operation | Time (ms) | Speedup |
---|---|---|
Standard Allreduce Rmsnorm | 0.051 | 0.98x |
Standard Allreduce Rmsnorm Native Compiled | 0.050 | baseline |
Flashinfer Fused Allreduce Rmsnorm Oneshot | 0.064 | 0.77x |
Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.050 | 1.00x |
Configuration: seq_len=512, dtype=bfloat16, no residual
Input Size: 8.00 MB
Operation | Time (ms) | Speedup |
---|---|---|
Standard Allreduce Rmsnorm | 0.079 | 1.00x |
Standard Allreduce Rmsnorm Native Compiled | 0.081 | 0.97x |
Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.068 | 1.17x |
Configuration: seq_len=1024, dtype=bfloat16, no residual
Input Size: 16.00 MB
Operation | Time (ms) | Speedup |
---|---|---|
Standard Allreduce Rmsnorm | 0.119 | 1.00x |
Standard Allreduce Rmsnorm Native Compiled | 0.125 | 0.95x |
Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.109 | 1.09x |
Configuration: seq_len=2048, dtype=bfloat16, no residual
Input Size: 32.00 MB
Operation | Time (ms) | Speedup |
---|---|---|
Standard Allreduce Rmsnorm | 0.195 | 1.00x |
Standard Allreduce Rmsnorm Native Compiled | 0.211 | 0.93x |
Flashinfer Fused Allreduce Rmsnorm Twoshot | 0.204 | 0.96x |
@ilmarkov Is `VLLM_ALLREDUCE_USE_SYMM_MEM=1` something that normal vLLM users would set by default? If it's good for performance, why can't we enable it by default? Does it require a special environment or special builds? cc @ProExpertProg

@nvjullin Could you check if @ilmarkov's measurements above match our understanding? Also, could you try whether `VLLM_ALLREDUCE_USE_SYMM_MEM=1` works in our case? Thanks!
Yes, it can be enabled by default. There is a PR for it. It works on Hopper and Blackwell.
Got it! We will try both your PRs and run some experiments on our side.
@ilmarkov Just to clarify: the PyTorch SYMM_MEM implementation does not support AR+Norm fusion, right? So only the AR part uses SYMM_MEM while the Norm part is based on native PyTorch?
Yes, symm mem is only for the allreduce part; the norm and quantization parts are in native PyTorch.

cc @nvjullin @elvischenv for vis
```python
input = input.to(torch.float32)
if residual is not None:
    input = input + residual.to(torch.float32)
    # residual = input.to(orig_dtype)
```
Just found the issue with this: it turns out Inductor eliminates this cast, because residual is cast right back to float32 and so never needs to be down-converted.
First part of splitting #22086
Purpose
- Adds support for matching torch-native operations, not only custom ops. Users no longer have to enable the rms_norm and quant_fp8 custom ops to get allreduce fusion, and the non-fused parts of the graph can keep using torch-native operations, which are known to be more performant than the custom-op alternatives (see the usage sketch after this list).
- Adds a benchmark for allreduce fusion to determine the input-size thresholds for flashinfer allreduce.
- Updates the thresholds for flashinfer allreduce on Hopper and Blackwell devices, and enables the two-shot algorithm where it performs better.
- Moves the allreduce out of the moe_forward custom op so that the fusion pattern can also match in MoE models.
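As a rough usage sketch of what the first point enables: the fusion can be turned on without enabling any custom ops. The flag name and model below are illustrative assumptions, not taken from this PR; check `vllm/config/compilation.py` for the actual field.

```python
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # example model, not from this PR
    tensor_parallel_size=8,
    compilation_config={
        # Assumed pass-config flag name for the flashinfer allreduce fusion.
        "pass_config": {"enable_fi_allreduce_fusion": True},
        # No "+rms_norm" / "+quant_fp8" custom ops are needed for the pattern to match.
    },
)
```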
Test Plan
Added tests for fusion without custom ops.