
Conversation

ilmarkov
Contributor

@ilmarkov ilmarkov commented Jul 16, 2025

Purpose

Add an integration of FlashInfer's fused allreduce + RMSNorm + quantization kernels. The fusion is enabled with -O '{"pass_config": {"enable_fi_allreduce_fusion": true}, "level":3, "custom_ops":["+quant_fp8","+rms_norm"]}'
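For reference, a minimal offline sketch of the same configuration through the Python API (assuming the LLM entrypoint accepts a compilation_config dict equivalent to the -O JSON above; the model name and prompt are just placeholders):

```python
# Hedged sketch, not part of this PR: enable the FlashInfer allreduce fusion pass
# via the offline LLM API instead of the -O CLI flag.
from vllm import LLM, SamplingParams

llm = LLM(
    model="RedHatAI/Meta-Llama-3-8B-Instruct-FP8",  # placeholder model from the validation runs
    tensor_parallel_size=2,                         # the fusion only matters with TP > 1
    compilation_config={
        "pass_config": {"enable_fi_allreduce_fusion": True},
        "level": 3,
        "custom_ops": ["+quant_fp8", "+rms_norm"],
    },
)
out = llm.generate(["The capital of France is"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```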

Validation

lm_eval --model local-completions --model_args model=${model},base_url=http://localhost:8000/v1/completions --trust_remote_code --cache_requests true --tasks gsm8k --num_fewshot 5 --batch_size 200

FP8

model: RedHatAI/Meta-Llama-3-8B-Instruct-FP8, TP=2
Baseline:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7460|±  | 0.012|
|     |       |strict-match    |     5|exact_match|↑  |0.7483|±  | 0.012|

PR:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7513|±  |0.0119|
|     |       |strict-match    |     5|exact_match|↑  |0.7528|±  |0.0119|

FP4

model: RedHatAI/Qwen3-32B-NVFP4 TP=2

Baseline

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6634|±  | 0.013|
|     |       |strict-match    |     5|exact_match|↑  |0.7468|±  | 0.012|

PR:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6687|±  |0.0130|
|     |       |strict-match    |     5|exact_match|↑  |0.7559|±  |0.0118|

Benchmarks

Client:

vllm bench serve \
        --model "$MODEL" \
        --dataset-name sonnet \
        --dataset-path benchmarks/sonnet.txt \
        --request-rate "$qps" \
        --num-prompts $((DURATION_SECONDS * qps)) \

Input length: 550 tokens, output length: 150 tokens; run on B200 GPUs.

FP8

RedHatAI/Meta-Llama-3-8B-Instruct-FP8 TP=2

Baseline:

| QPS | Mean TTFT (ms) | Median TTFT (ms) | Mean TPOT (ms) | Median TPOT (ms) |
|----:|---------------:|-----------------:|---------------:|-----------------:|
|   1 |         14.987 |           14.301 |          3.944 |            3.852 |
|   5 |         16.165 |           16.135 |          4.302 |            4.305 |
|  10 |         17.422 |           17.458 |          4.543 |            4.537 |

Fused Custom ops -O '{"pass_config": {"enable_fi_allreduce_fusion": false, "enable_noop": true, "enable_fusion": true}, "level":3, "custom_ops":["+quant_fp8","+rms_norm"]}':

| QPS | Mean TTFT (ms) | Median TTFT (ms) | Mean TPOT (ms) | Median TPOT (ms) |
|----:|---------------:|-----------------:|---------------:|-----------------:|
|   1 |         15.084 |           14.685 |          4.150 |            4.094 |
|   5 |         17.001 |           17.048 |          4.533 |            4.531 |
|  10 |         18.065 |           17.950 |          4.781 |            4.780 |

Non fused Custom ops -O '{"pass_config": {"enable_fi_allreduce_fusion": false, "enable_noop": false, "enable_fusion": false}, "level":3, "custom_ops":["+quant_fp8","+rms_norm"]}':

| QPS | Mean TTFT (ms) | Median TTFT (ms) | Mean TPOT (ms) | Median TPOT (ms) |
|----:|---------------:|-----------------:|---------------:|-----------------:|
|   1 |         15.233 |           14.456 |          4.249 |            4.175 |
|   5 |         17.266 |           17.147 |          4.634 |            4.634 |
|  10 |         18.457 |           18.254 |          4.904 |            4.902 |

PR:
-O '{"pass_config": {"enable_fi_allreduce_fusion": true, "enable_noop": false, "enable_fusion": false}, "level":3, "custom_ops":["+quant_fp8","+rms_norm"]}'

| QPS | Mean TTFT (ms) | Median TTFT (ms) | Mean TPOT (ms) | Median TPOT (ms) |
|----:|---------------:|-----------------:|---------------:|-----------------:|
|   1 |         14.807 |           14.384 |          3.843 |            3.753 |
|   5 |         16.479 |           16.508 |          4.232 |            4.238 |
|  10 |         17.307 |           17.100 |          4.464 |            4.457 |

NVFP4

model: RedHatAI/Qwen3-32B-NVFP4 TP=2
Baseline:

| QPS | Mean TTFT (ms) | Median TTFT (ms) | Mean TPOT (ms) | Median TPOT (ms) |
|----:|---------------:|-----------------:|---------------:|-----------------:|
|   1 |         37.002 |           35.057 |         10.754 |           10.824 |
|   5 |         39.015 |           38.259 |         12.020 |           11.960 |
|  10 |         44.029 |           42.588 |         14.114 |           14.170 |

Non Fused Custom Ops:
-O '{"pass_config": {"enable_fi_allreduce_fusion": false, "enable_noop": false, "enable_fusion": false}, "level":3, "custom_ops":["+rms_norm"]}'

| QPS | Mean TTFT (ms) | Median TTFT (ms) | Mean TPOT (ms) | Median TPOT (ms) |
|----:|---------------:|-----------------:|---------------:|-----------------:|
|   1 |         36.409 |           35.648 |         11.025 |           11.102 |
|   5 |         40.440 |           39.751 |         12.404 |           12.353 |
|  10 |         46.272 |           44.590 |         14.689 |           14.761 |

Fused Custom Ops:
-O '{"pass_config": {"enable_fi_allreduce_fusion": false, "enable_noop": true, "enable_fusion": true}, "level":3, "custom_ops":["+rms_norm"]}'

| QPS | Mean TTFT (ms) | Median TTFT (ms) | Mean TPOT (ms) | Median TPOT (ms) |
|----:|---------------:|-----------------:|---------------:|-----------------:|
|   1 |         36.458 |           36.017 |         11.077 |           11.155 |
|   5 |         40.235 |           39.002 |         12.440 |           12.377 |
|  10 |         46.241 |           44.097 |         14.758 |           14.789 |

PR:
-O '{"pass_config": {"enable_fi_allreduce_fusion": true, "enable_noop": false, "enable_fusion": false}, "level":3, "custom_ops":["+rms_norm"]}'

| QPS | Mean TTFT (ms) | Median TTFT (ms) | Mean TPOT (ms) | Median TPOT (ms) |
|----:|---------------:|-----------------:|---------------:|-----------------:|
|   1 |         34.068 |           33.255 |         10.346 |           10.449 |
|   5 |         41.480 |           39.513 |         11.673 |           11.648 |
|  10 |         44.979 |           42.499 |         13.400 |           13.411 |

For FP8 quantization, the PR improves TPOT by ~7-8% compared to the custom-ops configurations, shows no speedup over the torch.compile (default) ops, and increases TTFT by up to 8%.
For NVFP4, the PR improves TPOT by up to 5% and TTFT by up to 5% compared to the torch.compile ops. Compared to the custom-ops configurations, the PR improves TPOT by 10-15% and TTFT by up to 7%.

Profiling

model: RedHatAI/Qwen3-32B-NVFP4 TP=2

Baseline:
[profiler trace screenshot]

PR:
[profiler trace screenshot]


Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for fusing all-reduce, RMSNorm, and quantization operations using FlashInfer. The changes primarily involve adding new pattern matching classes to vllm/compilation/collective_fusion.py and new tests.

My review identified several critical issues in the fusion logic within vllm/compilation/collective_fusion.py. The return value indices from the fused operations in the replacement functions are consistently incorrect across multiple new patterns. These bugs will cause the fusion pass to generate incorrect computation graphs. I've provided specific suggestions to fix these indexing issues. Additionally, there's a minor but important issue with an incorrect pattern_code being used for an FP8 quantization pattern.

Contributor

critical

The return values from the replacement function seem to be incorrect. Based on the mutates_args list ["allreduce_in", "residual", "norm_out", "quant_out", "scale_out"], the auto_functionalized op will return a tuple of 5 tensors in that order. The pattern returns (rms_result, allreduce_output). In the replacement function, rms_result corresponds to the mutated norm_out (allreduce[2]), and allreduce_output corresponds to the mutated allreduce_in (allreduce[0]). However, the current code returns allreduce[3], allreduce[1], which corresponds to (quant_out, residual). This is incorrect and will lead to a faulty fusion.

Suggested change
- return allreduce[3], allreduce[1]
+ return allreduce[2], allreduce[0]

Contributor

critical

The return values from the replacement function are incorrect and will cause an IndexError. The auto_functionalized op returns a tuple of 5 elements based on mutates_args, so allreduce[5] is out of bounds. The pattern returns (quant_out, residual_output, output_scale), which corresponds to mutated quant_out, residual, and scale_out. The correct indices should be allreduce[3], allreduce[1], and allreduce[4].

Suggested change
- return allreduce[4], allreduce[2], allreduce[5]
+ return allreduce[3], allreduce[1], allreduce[4]
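To make the index bookkeeping in the two review comments above concrete, here is a tiny illustrative sketch (plain Python, not vLLM code). It assumes the functionalized op returns the mutated tensors in mutates_args order; any leading element for the op's own return value is folded into an offset parameter, which is an assumption here:

```python
# Illustrative only: map pattern outputs to positions in the functionalized
# op's result tuple, given the mutates_args order used by the fused op.
mutates_args = ["allreduce_in", "residual", "norm_out", "quant_out", "scale_out"]

def result_index(name: str, output_offset: int = 0) -> int:
    """Position of a mutated tensor in the returned tuple (offset is an assumption)."""
    return output_offset + mutates_args.index(name)

# Pattern returning (norm_out, allreduce_in) -> indices 2, 0
print(result_index("norm_out"), result_index("allreduce_in"))
# Pattern returning (quant_out, residual, scale_out) -> indices 3, 1, 4
print(result_index("quant_out"), result_index("residual"), result_index("scale_out"))
```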

Comment on lines 574 to 575
Contributor

high

This pattern is for FP8 quantization, but the pattern_code is set to kARResidualRMSNormFP4Quant. This appears to be a copy-paste error and should be corrected to use the FP8 pattern code.

Suggested change
- pattern_code=flashinfer_comm.AllReduceFusionPattern.
- kARResidualRMSNormFP4Quant,
+ pattern_code=flashinfer_comm.AllReduceFusionPattern.
+ kARResidualRMSNormFP8Quant,

Contributor Author

@ilmarkov ilmarkov Jul 17, 2025

Good catch.

@nvpohanh
Contributor

nvpohanh commented Jul 29, 2025

@ilmarkov To help me understand the roadmap better, will you continue to work on AR+RMSNorm+FP4-quantization fusions in the future? Or will you stop at FP8-quantization? Thanks!

Never mind. I see AllReduceFusedRMSNormStaticQuantNVFP4Pattern now

ilmarkov and others added 13 commits July 29, 2025 03:20
@mergify mergify bot added the ci/build label Jul 31, 2025
@mgoin mgoin enabled auto-merge (squash) July 31, 2025 20:26
@simon-mo simon-mo disabled auto-merge July 31, 2025 20:58
@simon-mo simon-mo merged commit 6e672da into vllm-project:main Jul 31, 2025
98 of 100 checks passed
@tjtanaa tjtanaa mentioned this pull request Aug 1, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
noamgat pushed a commit to noamgat/vllm that referenced this pull request Aug 9, 2025
@weireweire
Contributor

@ilmarkov
Hi, I'm going to complete the threshold handling for the AR+NORM fusion, and I have a question. I see HERE that we take the min of two thresholds:

        use_flashinfer = current_tensor_size <= min(
            _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
            max_fusion_size,
        )

and max_fusion_size is calculated based on HERE, which is actually _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) // self.tp_size. It is always smaller than the other term, _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE) (unless allreduce_in.type != self.model_dtype). Is this by design? And do we really need to divide by tp_size?
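To illustrate the point numerically, here is a small sketch with made-up sizes (the _FI_MAX_SIZES values below are hypothetical, and the dtype adjustment mentioned above is ignored): when world_size equals tp_size, the divided per-rank cap can never exceed the raw table value, so the min() always resolves to max_fusion_size.

```python
# Hypothetical sizes, for illustration only (not vLLM's real table).
_FI_MAX_SIZES = {2: 64 * 2**20, 4: 32 * 2**20}
_DEFAULT_FI_MAX_SIZE = 16 * 2**20

def thresholds(tp_size: int, world_size: int):
    max_fusion_size = _FI_MAX_SIZES.get(tp_size, _DEFAULT_FI_MAX_SIZE) // tp_size
    table_cap = _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE)
    return max_fusion_size, table_cap, min(table_cap, max_fusion_size)

for tp in (2, 4, 8):
    per_rank, table_cap, effective = thresholds(tp, tp)
    assert effective == per_rank  # the divided threshold always wins
    print(tp, per_rank, table_cap, effective)
```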

@ilmarkov
Contributor Author

ilmarkov commented Aug 12, 2025

@weireweire I am working on the thresholds in this PR. We need to divide by tp_size because of one_shot initialization in FlashInfer. It allocates this buffer for each peer rank so we need to make sure that it doesn't fail. Probably, with proper threshold adjustment, we can get rid of this division.
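A rough back-of-the-envelope sketch of that sizing argument (my reading, not verified against FlashInfer's allocator: assume one-shot allreduce stages one copy of the input per peer rank, so the per-rank workspace grows with tp_size):

```python
# Assumption for illustration: one staging copy of the input per peer rank.
def per_rank_workspace_bytes(input_bytes: int, tp_size: int) -> int:
    return input_bytes * tp_size

table_limit = 64 * 2**20              # hypothetical _FI_MAX_SIZES[tp_size]
tp_size = 2
input_cap = table_limit // tp_size    # the divided threshold from this PR
assert per_rank_workspace_bytes(input_cap, tp_size) <= table_limit
```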

@weireweire
Contributor

@ilmarkov Thanks. If so, can we just change

        use_flashinfer = current_tensor_size <= min(
            _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
            max_fusion_size,
        )

to

        use_flashinfer = (current_tensor_size <= max_fusion_size)

otherwise it seems redundant.

paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
@weireweire
Contributor

Could you add a warning HERE prompting users to enable the noop pass, as the other backends do? Otherwise the fusion won't happen.
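For illustration, a hypothetical sketch of such a warning (the config field names are taken from this PR's pass_config keys; the helper itself is not existing vLLM code):

```python
import logging

logger = logging.getLogger(__name__)

def warn_if_noop_disabled(pass_config) -> None:
    # Hypothetical check: FlashInfer allreduce fusion is requested but the
    # noop-elimination pass is off, which can prevent the pattern from matching.
    if getattr(pass_config, "enable_fi_allreduce_fusion", False) and not getattr(
        pass_config, "enable_noop", False
    ):
        logger.warning(
            "enable_fi_allreduce_fusion is enabled but enable_noop is not; "
            "the allreduce+rms_norm fusion may not be applied."
        )
```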

diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
@ilmarkov
Contributor Author

@weireweire Sure, I'll add it in the next PR. Thanks for the note!

epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025

Labels

ci/build, performance (Performance-related issues), ready (ONLY add when PR is ready to merge/full CI is needed)
