Add FlashInfer allreduce RMSNorm Quant fusion #21069
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Code Review

This pull request introduces support for fusing all-reduce, RMSNorm, and quantization operations using FlashInfer. The changes primarily involve adding new pattern-matching classes to `vllm/compilation/collective_fusion.py` and new tests.

My review identified several critical issues in the fusion logic within `vllm/compilation/collective_fusion.py`. The return-value indices from the fused operations in the `replacement` functions are consistently incorrect across multiple new patterns. These bugs will cause the fusion pass to generate incorrect computation graphs. I've provided specific suggestions to fix these indexing issues. Additionally, there's a minor but important issue with an incorrect `pattern_code` being used for an FP8 quantization pattern.
The return values from the `replacement` function seem to be incorrect. Based on the `mutates_args` list `["allreduce_in", "residual", "norm_out", "quant_out", "scale_out"]`, the `auto_functionalized` op will return a tuple of 5 tensors in that order. The pattern returns `(rms_result, allreduce_output)`. In the `replacement` function, `rms_result` corresponds to the mutated `norm_out` (`allreduce[2]`), and `allreduce_output` corresponds to the mutated `allreduce_in` (`allreduce[0]`). However, the current code returns `allreduce[3], allreduce[1]`, which corresponds to `(quant_out, residual)`. This is incorrect and will lead to a faulty fusion.

```python
return allreduce[2], allreduce[0]
```
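A small self-contained illustration of the index bookkeeping described in this comment, under the convention the review assumes (the fused op's output tuple is ordered exactly like `mutates_args`); this is a toy sketch, not vLLM's actual code:

```python
# Toy sketch of the review's bookkeeping: map each mutated tensor name to the
# tuple index it would occupy if outputs follow the mutates_args ordering.
mutates_args = ["allreduce_in", "residual", "norm_out", "quant_out", "scale_out"]

def mutated_index(name: str) -> int:
    return mutates_args.index(name)

# The pattern returns (rms_result, allreduce_output) == (norm_out, allreduce_in),
# which under this convention are indices 2 and 0, not 3 and 1.
assert (mutated_index("norm_out"), mutated_index("allreduce_in")) == (2, 0)
```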
The return values from the `replacement` function are incorrect and will cause an `IndexError`. The `auto_functionalized` op returns a tuple of 5 elements based on `mutates_args`, so `allreduce[5]` is out of bounds. The pattern returns `(quant_out, residual_output, output_scale)`, which corresponds to the mutated `quant_out`, `residual`, and `scale_out`. The correct indices should be `allreduce[3]`, `allreduce[1]`, and `allreduce[4]`.

```diff
-return allreduce[4], allreduce[2], allreduce[5]
+return allreduce[3], allreduce[1], allreduce[4]
```
This pattern is for FP8 quantization, but the `pattern_code` is set to `kARResidualRMSNormFP4Quant`. This appears to be a copy-paste error and should be corrected to use the FP8 pattern code.

```diff
-pattern_code=flashinfer_comm.AllReduceFusionPattern.
-kARResidualRMSNormFP4Quant,
+pattern_code=flashinfer_comm.AllReduceFusionPattern.
+kARResidualRMSNormFP8Quant,
```
Good catch.
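For context on the fix above, a hedged sketch of selecting the FlashInfer pattern code from the quantization dtype; the helper and the dtype mapping are assumptions for illustration, not vLLM's actual code, and only the enum members are the ones referenced in this PR:

```python
# Illustrative helper (assumed, not vLLM's actual code): choose the FlashInfer
# fusion pattern code from the quantization dtype, so the FP8 and FP4 patterns
# cannot be swapped by a copy-paste mistake like the one flagged above.
import torch
import flashinfer.comm as flashinfer_comm  # assumed import behind the flashinfer_comm alias

def select_pattern_code(quant_dtype: torch.dtype):
    if quant_dtype == torch.float8_e4m3fn:
        return flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
    if quant_dtype == torch.uint8:  # assumption: packed NVFP4 carried as uint8
        return flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
    raise ValueError(f"Unsupported quantization dtype: {quant_dtype}")
```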
@ilmarkov To help me understand the roadmap better, will you continue to work on AR+RMSNorm+FP4-quantization fusions in the future, or will you stop at FP8 quantization? Thanks! Never mind, I see.
@ilmarkov
and max_fusion_size is calculated based on HERE, which is actually …
@weireweire I am working on the thresholds in this PR. We need to divide by tp_size because of the one-shot initialization in FlashInfer: it allocates this buffer for each peer rank, so we need to make sure that it doesn't fail. Probably, with proper threshold adjustment, we can get rid of this division.
@ilmarkov thanks, if so can we just change … to …? Otherwise it seems redundant.
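To make the sizing argument in this thread concrete, here is a minimal sketch, assuming the one-shot workspace is replicated once per peer rank; the function name and the 16 MiB budget in the example are illustrative assumptions, not vLLM's actual code or defaults:

```python
# Illustrative sketch only (assumed names and numbers, not vLLM's implementation):
# if the one-shot allreduce workspace is allocated once per peer rank, the number
# of tokens eligible for fusion is bounded by the total budget divided by tp_size,
# which is why the max fusion size discussed above is divided by the TP size.
def max_fusion_token_num(max_workspace_bytes: int, hidden_dim: int,
                         dtype_bytes: int, tp_size: int) -> int:
    per_token_bytes = hidden_dim * dtype_bytes
    # Each of the tp_size peer ranks holds its own copy of the buffer.
    return max_workspace_bytes // (per_token_bytes * tp_size)

# Example: 16 MiB budget, hidden size 4096, bf16 (2 bytes per element), TP=2 -> 1024 tokens
print(max_fusion_token_num(16 * 1024 * 1024, 4096, 2, 2))
```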
Could you add a warning HERE to make users enable the noop pass, like the other backends do? Otherwise the fusion won't happen.
@weireweire Sure, I'll add it in the next PR. Thanks for the note!
Purpose
Add integration of the FlashInfer fused allreduce + RMSNorm + quantization kernel. Enabled with:
-O '{"pass_config": {"enable_fi_allreduce_fusion": true}, "level":3, "custom_ops":["+quant_fp8","+rms_norm"]}'
Validation
lm_eval --model local-completions --model_args model=${model},base_url=http://localhost:8000/v1/completions --batch_size auto --trust_remote_code --cache_requests true --tasks gsm8k --num_fewshot 5 --batch_size 200
FP8
model: RedHatAI/Meta-Llama-3-8B-Instruct-FP8, TP=2
Baseline:
PR:
FP4
model: RedHatAI/Qwen3-32B-NVFP4, TP=2
Baseline:
PR:
Benchmarks
Client:
Input len: 550, Output len: 150. B200 GPUs
FP8
RedHatAI/Meta-Llama-3-8B-Instruct-FP8 TP=2
Baseline:
Fused Custom ops:
-O '{"pass_config": {"enable_fi_allreduce_fusion": false, "enable_noop": true, "enable_fusion": true}, "level":3, "custom_ops":["+quant_fp8","+rms_norm"]}'
Non fused Custom ops:
-O '{"pass_config": {"enable_fi_allreduce_fusion": false, "enable_noop": false, "enable_fusion": false}, "level":3, "custom_ops":["+quant_fp8","+rms_norm"]}'
PR:
-O '{"pass_config": {"enable_fi_allreduce_fusion": true, "enable_noop": false, "enable_fusion": false}, "level":3, "custom_ops":["+quant_fp8","+rms_norm"]}'
NVFP4
model: RedHatAI/Qwen3-32B-NVFP4 TP=2
Baseline:
Non Fused Custom Ops:
-O '{"pass_config": {"enable_fi_allreduce_fusion": false, "enable_noop": false, "enable_fusion": false}, "level":3, "custom_ops":["+rms_norm"]}'
Fused Custom Ops:
-O '{"pass_config": {"enable_fi_allreduce_fusion": false, "enable_noop": true, "enable_fusion": true}, "level":3, "custom_ops":["+rms_norm"]}'
PR:
-O '{"pass_config": {"enable_fi_allreduce_fusion": true, "enable_noop": false, "enable_fusion": false}, "level":3, "custom_ops":["+rms_norm"]}'
In the case of FP8 quantization, the PR improves TPOT by ~7-8% compared to custom ops, with no speedup over torch.compiled (default) ops, and increases TTFT by up to 8%.
In the case of NVFP4, the PR improves TPOT by up to 5% and TTFT by up to 5% compared to torch.compiled ops. Compared to custom ops, the PR improves TPOT by 10-15% and TTFT by up to 7%.
Profiling
model: RedHatAI/Qwen3-32B-NVFP4 TP=2
Baseline:

PR:
