[PERF] Allreduce Fusion tuning and compile_ranges introduction #22086
Conversation
Code Review
This pull request provides several follow-up changes for Allreduce fusion, including cleaning up the fusion pass, moving the all-reduce operation out of the `fused_moe` custom op, and disabling a test for FP4 fusion. The refactoring in `fused_moe/layer.py` is a good improvement for modularity. However, I've identified a critical issue in `vllm/compilation/collective_fusion.py` where the fallback path for FP8 quantization appears to have been unintentionally removed. This needs to be addressed to ensure correctness.
```python
torch.ops._C.scaled_fp4_quant(quant_out, norm_out,
                              scale_out, scale_factor)
```
It appears the logic for FP8 quantization in the fallback path has been unintentionally removed. The original code correctly handled both FP4 and FP8 quantization by checking whether `scale_out` is `None`. The new code unconditionally calls `torch.ops._C.scaled_fp4_quant`, which will likely fail at runtime when `scale_out` is `None` during FP8 quantization.
This change breaks FP8 quantization in the non-fused path. Please restore the conditional logic to support both FP4 and FP8 quantization.
```diff
-torch.ops._C.scaled_fp4_quant(quant_out, norm_out,
-                              scale_out, scale_factor)
+if scale_out is not None:
+    torch.ops._C.scaled_fp4_quant(quant_out, norm_out,
+                                  scale_out, scale_factor)
+else:
+    torch.ops._C.static_scaled_fp8_quant(
+        quant_out, norm_out, scale_factor)
```
We only do the fused FP8 logic; we don't go into this branch in the case of FP8 quantization.
The FP8 path is intentionally removed from this if/else branch.
Results from the isolated fused allreduce benchmark.

Configurations covered (per-configuration result tables omitted here):
- H100, TP=2 and TP=4, no quant and FP8: seq_len 32/64/128, dtype=bfloat16, with residual
- B200, TP=2/4/8, no quant, FP8, and FP4: seq_len 32/64/128/256, dtype=bfloat16, with residual
The results show that the one-shot FlashInfer fusion can give up to a 2x speedup on small inputs, whereas the two-shot variant is always slower than the non-fused alternatives.
To confirm that the threshold has to be based on tensor size rather than the number of tokens, here are the benchmark results for hidden size 1024 (B200, TP=2/4/8, no quant, seq_len 32 to 2048, dtype=bfloat16, with residual; per-configuration result tables omitted here).
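As a minimal sketch of what a tensor-size-based threshold looks like (the byte limits and the helper below are hypothetical illustrations, not the values or API tuned in this PR):

```python
import torch

# Hypothetical per-world-size byte limits for the FlashInfer one-shot kernel;
# the PR tunes the real thresholds per GPU generation and TP size.
ONE_SHOT_MAX_SIZE_BYTES = {2: 1 * 2**20, 4: 512 * 1024, 8: 256 * 1024}


def use_one_shot(num_tokens: int, hidden_size: int, dtype: torch.dtype,
                 world_size: int) -> bool:
    """Decide based on the input tensor size in bytes, not the token count."""
    input_bytes = num_tokens * hidden_size * (torch.finfo(dtype).bits // 8)
    return input_bytes <= ONE_SHOT_MAX_SIZE_BYTES.get(world_size, 0)


# 128 tokens at hidden_size=8192 in bf16 is 2 MiB, but only 256 KiB at
# hidden_size=1024, so the same token count can land on different sides
# of the threshold.
print(use_one_shot(128, 8192, torch.bfloat16, 2))  # False
print(use_one_shot(128, 1024, torch.bfloat16, 2))  # True
```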
Hi @ilmarkov, thanks for working on this! I am just wondering whether you will continue to work on this? If so, do you have a rough estimate of when this PR will be ready? Just want to avoid duplicated work on our side. Thanks!
Hi @nvpohanh. Yes, I am still working on this. The plan is to enable FlashInfer allreduce fusion by default. At the moment, we can't do that because of the performance of the fallback path (for medium and large input sizes).
@ilmarkov Could you re-tune the thresholds using a newer FlashInfer with this fix: flashinfer-ai/flashinfer#1507? We found that FlashInfer uses suboptimal CUDA grid/block sizes when calling the AllReduce+RMSNorm kernel. The fix can lead to a >2x perf boost for some cases we tested. Thanks!
I ran benchmarks/kernels/benchmark_fused_collective.py with hidden_dim 1024/8192, TP 2/4, and seq_len 1/2/4/8/16/32/64/128/256/512 again with flashinfer-ai/flashinfer#1507; results for B200 are attached (e.g. no-quant-tp2-hd1024.txt).
@ilmarkov I think we can simply default to FlashInfer oneshot now that it's always faster (on B200 at least)?
@nvjullin Thank you for the benchmark results. Yes, we will default to oneshot. I will update the thresholds accordingly.
The failing basic correctness and entrypoints tests are caused by a torch.standalone_compile issue that was fixed in pytorch/pytorch#157803.
Good job getting all of this to work! I left a few minor comments. A couple of general things:
- There are a lot of formatting changes, which make the PR hard to read; could you please remove those?
- We should split up the compile ranges part and the other allreduce fixes. If it's easier, we can do compile ranges without any allreduce changes and then follow up with everything allreduce-related. Alternatively, we can put the custom-op matching and other things orthogonal to compile ranges in a separate PR to land it first, although I'm working on making that easier anyway.
- TODO: tell Inductor about the compile range (should ask Meta about it).
- Instead of having a single dynamic-shape graph, could we just use a single (1, max_num_batched_tokens) graph if no split points are passed? Then `compile_range` would never be Optional. Maybe that's something for a follow-up.
```diff
     torch._inductor.pattern_matcher._seen_patterns.clear()

-    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
+    def is_applicable_for_range(
```
Given that the range is accessible through the context anyway, should we just get rid of this method and let each pass exit early inside `__call__`? That would simplify the interface.
It's already done inside `__call__` in `PostGradPassManager` in a unified manner, and the range is taken from the context. With this approach, we can more easily add tracking of which passes are enabled in a graph.
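For illustration, here is a minimal sketch of the arrangement being described, with the applicability check made once by the manager; the class names, signatures, and the way the range reaches the manager are simplified assumptions, not vLLM's actual code:

```python
from typing import Optional

CompileRange = tuple[int, int]  # [start, end) in number of input tokens


class InductorPass:
    """Illustrative base class; the real interface differs in details."""

    def is_applicable_for_range(
            self, compile_range: Optional[CompileRange]) -> bool:
        # Default: applicable everywhere, including the fully dynamic
        # graph, where compile_range is None.
        return True

    def __call__(self, graph) -> None:
        raise NotImplementedError


class PostGradPassManager:
    """Checks applicability in one place instead of inside each pass."""

    def __init__(self, passes: list[InductorPass],
                 compile_range: Optional[CompileRange]):
        self.passes = passes
        self.compile_range = compile_range  # taken from the pass context

    def __call__(self, graph) -> None:
        for p in self.passes:
            if p.is_applicable_for_range(self.compile_range):
                p(graph)
```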
```python
else:
    self.compile_ranges.append((split_points[i - 1], s))
if s in self.compile_sizes:
    self.compile_ranges.append((s, s))
```
If 1 is in compile_sizes, doesn't it get appended twice here? It might be good to extract this conversion into a utility method we can test separately.
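A hypothetical helper along these lines could make the conversion testable on its own (the name and exact semantics are assumptions, not the PR's code):

```python
def build_compile_ranges(
        split_points: list[int],
        compile_sizes: set[int],
        max_num_batched_tokens: int) -> list[tuple[int, int]]:
    """Build half-open (start, end) ranges from split points, plus
    degenerate (s, s) entries for exact compile sizes, without duplicates."""
    ranges: list[tuple[int, int]] = []
    points = sorted(set(split_points) | {max_num_batched_tokens})
    prev = 1
    for s in points:
        if s > prev:
            ranges.append((prev, s))
        if s in compile_sizes and (s, s) not in ranges:
            # Guards against appending (s, s) twice, e.g. when s == 1.
            ranges.append((s, s))
        prev = s
    return ranges


# Example: split points (32, 64), exact compile size 64, max 8192.
print(build_compile_ranges([32, 64], {64}, 8192))
# [(1, 32), (32, 64), (64, 64), (64, 8192)]
```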
```python
if isinstance(compile_range, tuple):
    # for a specific range of batch sizes, tuning triton kernel parameters
    # can be beneficial
    config["max_autotune"] = True
```
Can you add a TODO(luka): max_autotune should only be present with -O3, and this should live in config: https://github.com/vllm-project/vllm/issues/20283
```python
    to integers, it also supports "cudagraph_capture_sizes" to
    specify the sizes for cudagraph capture."""
compile_ranges_split_points: Optional[list[int]] = None
"""Split points that represent compile ranges for inductor.
```
Mention that compile_sizes also plays a role in the splitting.
```python
                        GroupShape)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils import (_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES,
```
I don't think the FI stuff should live in utils
```python
        return [input, weight]

    def register(self, pm_pass: PatternMatcherPass):
        if not self.is_custom_rms_norm:
```
Seems like these two should just be separated into separate pattern classes
The PR will be split up following @ProExpertProg's recommendation.
Pull Request Description
This pull request introduces several key enhancements for improving the flexibility and performance of vLLM, particularly for `allreduce` operations and dynamic graph compilation.

Summary of Changes
- Allreduce Fusion without Custom Ops: Adds support for `allreduce` fusion even when custom operations are not enabled.
- Thresholds Adjustment: Adjusts the thresholds for `allreduce` fusion to optimize performance across various configurations.
- Dynamic Graph Dispatch via `compile_ranges`: Introduces a new configuration option, `compile_ranges`, as an alternative to `compile_sizes`. This enables dynamic dispatch to different compiled graphs based on the input batch size.

With this approach, when allreduce fusion is enabled, vLLM adds an additional compile range split point in order to separate the graphs:
1. One with the fused (FlashInfer) allreduce for small shape inputs
2. One with the NCCL-based allreduce for large shape inputs
Detailed Breakdown
1. Allreduce Fusion without Custom Operations

This change extends the `allreduce` fusion matching to work without custom ops, supporting the torch-native rmsnorm and quant_fp8 implementations. It allows keeping the more performant non-fused operations in the graph.

2. Dynamic Graph Dispatch with `compile_ranges`
The existing `compile_sizes` feature is extended and generalized with `compile_ranges`. Defined by split points, these ranges allow vLLM to dynamically dispatch requests to specific, pre-compiled graphs based on the input batch size. For example, a configuration of (32, 64) defines three distinct ranges: [1, 32), [32, 64), and [64, max_num_batched_tokens). This provides granular control, allowing developers to statically enable or disable fusions within each graph to optimize performance for different batch sizes. A dispatch sketch is shown below.
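A rough illustration of the dispatch described above, using hypothetical names rather than vLLM's actual implementation:

```python
CompiledGraphs = dict[tuple[int, int], object]  # (start, end) -> graph


def select_compiled_graph(num_tokens: int, graphs: CompiledGraphs):
    """Pick the graph whose half-open range [start, end) contains
    num_tokens; exact compile sizes appear as degenerate (s, s) entries."""
    for (start, end), graph in graphs.items():
        if (start == end == num_tokens) or (start <= num_tokens < end):
            return graph
    raise ValueError(f"No compiled graph covers num_tokens={num_tokens}")


# Split points (32, 64) with max_num_batched_tokens=8192:
graphs = {(1, 32): "graph_small", (32, 64): "graph_mid",
          (64, 8192): "graph_large"}
assert select_compiled_graph(33, graphs) == "graph_mid"
assert select_compiled_graph(64, graphs) == "graph_large"
```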
Motivation

Corresponding RFC: #23113
The primary motivation for these changes is to enhance vLLM's performance and adaptability for diverse workloads. By supporting `allreduce` fusion without custom ops and introducing dynamic graph dispatch, we empower users to fine-tune vLLM for more efficient and scalable inference.

Update
The PR will be split up following @ProExpertProg's recommendation.
First part, torch-native ops fusion support: #24248
Second part, conditional compilation based on ranges: #24252