Conversation

@omera-nv omera-nv commented Jul 28, 2025

  • This is still a draft until the correctness issue is fixed!

Purpose

This PR builds on the great work of @ilmarkov in #20691 and integrates flashinfer's allreduce + RMSNorm + FP8 Quant fusion.
This new fusion relies on the existing RMSNorm + FP8 Quant fusion in order not to collide with the non-FP8 fusion. Therefore, to enable this fusion we must also enable the fusion and noop passes: --compilation-config='{"pass_config": {"enable_noop": true, "enable_fusion": true, "enable_flashinfer_allreduce_fusion": true}, "custom_ops": ["+rms_norm", "+quant_fp8"], "level":3}'

In addition, this PR modifies some of the guards that determine when the flashinfer path should be taken, matching them to the TRTLLM source.
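
For reference, below is a minimal eager-mode sketch of the chain this fusion targets: all-reduce, fused residual add + RMSNorm, then FP8 quantization, which the pass replaces with a single flashinfer_trtllm_fused_allreduce_norm_fp8 call. The sketch is illustrative only; the eps value and per-tensor scale handling are assumptions, not the kernel's actual implementation.

```python
# Illustrative eager-mode reference of the fused chain (not the actual pass
# or kernel): all-reduce -> residual add -> RMSNorm -> static FP8 quant.
import torch


def unfused_reference(x: torch.Tensor, residual: torch.Tensor,
                      weight: torch.Tensor, scale: torch.Tensor,
                      eps: float = 1e-6):
    # Under TP, x would first be summed across ranks, e.g.
    # torch.distributed.all_reduce(x); omitted here to keep the sketch local.
    x = x + residual                 # fused residual add
    new_residual = x                 # residual output of the pattern
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    normed = (x.float() * torch.rsqrt(var + eps)).to(weight.dtype) * weight
    # static per-tensor FP8 quantization (e4m3 representable range is +/-448)
    quant = (normed.float() / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return quant, new_residual
```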

Test Plan

tests/compile/test_fusion_all_reduce.py was expanded to also test modules with FP8 quant, as well as multiple sequence lengths, so that all possible execution paths (FI oneshot, FI twoshot, non-FI) are exercised.
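
A hypothetical sketch of the resulting coverage matrix is shown below; the parameter names and sequence lengths are illustrative, not the values used in the actual test.

```python
# Hypothetical coverage sketch only -- the real test builds the TP module,
# runs the fusion pass, and checks numerical parity against the eager path.
# Sizes are made up; the actual oneshot/twoshot/non-FI cutoffs depend on
# world size, hidden size, and dtype.
import pytest

SEQ_LENS = [8, 2048, 32768]                 # intended to hit oneshot / twoshot / non-FI
MODEL_KINDS = ["rms_norm", "rms_norm_fp8"]  # without / with FP8 quant


@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("model_kind", MODEL_KINDS)
def test_allreduce_fusion_paths(seq_len: int, model_kind: str):
    # Placeholder body standing in for the compile-and-compare logic.
    assert seq_len > 0 and model_kind in MODEL_KINDS
```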

Test Result

Llama-3.3-70B-Instruct-FP8 (modelopt) TP=4 on B200 GPUs.
Flashinfer commit 7253d74.

Benchmark E2E:

Client

DURATION_SECONDS=60;
vllm bench serve --model nvidia/Llama-3.3-70B-Instruct-FP8 --dataset-name sonnet --dataset-path benchmarks/sonnet.txt --request-rate "$qps" --num-prompts $((DURATION_SECONDS * qps))

Server

  • No Fusion:
python -m vllm.entrypoints.openai.api_server --disable-log-requests --no-enable-prefix-caching --dtype auto --kv-cache-dtype fp8 --quantization modelopt -tp 4 --model nvidia/Llama-3.3-70B-Instruct-FP8
| QPS | Mean TTFT (ms) | Median TTFT (ms) | Mean TPOT (ms) | Median TPOT (ms) |
|-----|----------------|------------------|----------------|------------------|
| 1   | 39.84          | 39.92            | 10.66          | 10.70            |
| 5   | 44.10          | 42.62            | 11.44          | 11.36            |
| 10  | 49.24          | 45.67            | 14.10          | 14.16            |
  • AR+RMS:
python -m vllm.entrypoints.openai.api_server --disable-log-requests --no-enable-prefix-caching --dtype auto --kv-cache-dtype fp8 --quantization modelopt -tp 4 --model nvidia/Llama-3.3-70B-Instruct-FP8 --compilation-config '{"pass_config": {"enable_noop": true, "enable_fusion": true, "enable_fi_allreduce_fusion": true}, "custom_ops": ["+rms_norm"], "level": 3}'
| QPS | Mean TTFT (ms) | Median TTFT (ms) | Mean TPOT (ms) | Median TPOT (ms) |
|-----|----------------|------------------|----------------|------------------|
| 1   | 41.73          | 42.04            | 9.65           | 9.63             |
| 5   | 45.80          | 44.08            | 10.91          | 10.85            |
| 10  | 51.87          | 48.03            | 13.82          | 13.88            |
  • AR+RMS+FP8:
vllm serve --disable-log-requests --no-enable-prefix-caching --dtype auto --kv-cache-dtype fp8 --quantization modelopt -tp 4 --model nvidia/Llama-3.3-70B-Instruct-FP8 --compilation-config '{"pass_config": {"enable_noop": true, "enable_fusion": true, "enable_fi_allreduce_fusion": true}, "custom_ops": ["+rms_norm", "+quant_fp8"], "level": 3}'
| QPS | Mean TTFT (ms) | Median TTFT (ms) | Mean TPOT (ms) | Median TPOT (ms) |
|-----|----------------|------------------|----------------|------------------|
| 1   | 40.09          | 39.66            | 9.72           | 9.68             |
| 5   | 45.31          | 43.53            | 10.93          | 10.85            |
| 10  | 50.94          | 47.10            | 13.37          | 13.36            |

GSM8K

Client

lm_eval --model local-completions --tasks gsm8k --model_args model=nvidia/Llama-3.3-70B-Instruct-FP8,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=1,max_retries=3,tokenized_requests=False
| Filter           | No Fusion | AR+RMS | AR+RMS+FP8 |
|------------------|-----------|--------|------------|
| flexible-extract | 0.9393    | 0.9386 | 0.2790     |
| strict-match     | 0.7612    | 0.7817 | 0.0273     |

omera-nv added 13 commits July 27, 2025 18:08

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request integrates FlashInfer's fused All-Reduce, RMSNorm, and FP8 quantization kernel, which is a great step towards improving performance for FP8 models. The changes are well-structured, and the expansion of the test suite to cover FP8 and various sequence lengths is comprehensive.

My main concern is a potential bug in the AllReduceFusedAddRMSNormFP8Pattern. It appears to be missing a necessary workaround for a graph topological sorting issue that is present in a similar pattern within the same file. This could be the root cause of the correctness issue mentioned in the pull request description. Addressing this should be the top priority.

Once the correctness issue is resolved, this will be a valuable contribution.

Comment on lines +582 to +596
def replacement(residual: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, scale: torch.Tensor,
rms_result: torch.Tensor):
allreduce = auto_functionalized(
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm_fp8.
default,
allreduce_in=input,
residual=residual,
rms_gamma=weight,
rms_eps=self.epsilon,
quant_out=rms_result,
scale=scale,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
return allreduce[3], allreduce[2]

critical

The replacement function in AllReduceFusedAddRMSNormFP8Pattern seems to have a potential bug related to graph topology that could be the source of the correctness issue mentioned in the PR description.

Specifically, it uses rms_result as the quant_out buffer for the fused operation. However, in the FX graph, the allocation for rms_result (the output of the original quantization op) likely occurs after the all_reduce operation. Since this pattern replaces the all_reduce node, the new fused op is inserted at a point in the graph where rms_result has not yet been allocated, leading to a topologically invalid graph.

This same issue is handled in AllReduceRMSNORMFP8Pattern (lines 463-469) by creating a new empty tensor as a workaround. A similar fix should be applied here to ensure the graph remains valid after the replacement.

            # the allocation of rms_fp8_result appears after the all_reduce.
            # however, we are replacing the all_reduce call,
            # so our whole fused op appears before the allocation the result.
            # this makes the graph topologically unsorted and causes errors,
            # since our fused op references a tensor that is not yet allocated.
            # therefore we need to allocate a result tensor as a work-around.
            empty = torch.empty_like(input, dtype=current_platform.fp8_dtype())
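
For concreteness, here is a sketch of how that work-around could be applied to this pattern's replacement function, assuming the same op signature as in the excerpt above (a suggestion, not code from the PR):

```python
def replacement(residual: torch.Tensor, input: torch.Tensor,
                weight: torch.Tensor, scale: torch.Tensor,
                rms_result: torch.Tensor):
    # Allocate the quant output here instead of reusing rms_result, whose
    # allocation only appears after the all_reduce in the original graph
    # (mirrors the work-around in AllReduceRMSNORMFP8Pattern).
    quant_out = torch.empty_like(input, dtype=current_platform.fp8_dtype())
    allreduce = auto_functionalized(
        torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm_fp8.default,
        allreduce_in=input,
        residual=residual,
        rms_gamma=weight,
        rms_eps=self.epsilon,
        quant_out=quant_out,
        scale=scale,
        **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
    )
    return allreduce[3], allreduce[2]
```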

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@ilmarkov

@omera-nv There is already a PR (#21069) that adds FP8/NVFP4 quantization.

@omera-nv omera-nv closed this Jul 28, 2025