[Misc] Integration of FlashInfer fused AR+RMS+FP8 #21741
Conversation
Signed-off-by: Omer Ullman Argov <[email protected]>
Code Review
This pull request integrates FlashInfer's fused All-Reduce, RMSNorm, and FP8 quantization kernel, which is a great step towards improving performance for FP8 models. The changes are well-structured, and the expansion of the test suite to cover FP8 and various sequence lengths is comprehensive.
My main concern is a potential bug in `AllReduceFusedAddRMSNormFP8Pattern`. It appears to be missing a necessary workaround for a graph topological sorting issue that is present in a similar pattern within the same file. This could be the root cause of the correctness issue mentioned in the pull request description. Addressing this should be the top priority.
Once the correctness issue is resolved, this will be a valuable contribution.
```python
def replacement(residual: torch.Tensor, input: torch.Tensor,
                weight: torch.Tensor, scale: torch.Tensor,
                rms_result: torch.Tensor):
    allreduce = auto_functionalized(
        torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm_fp8.default,
        allreduce_in=input,
        residual=residual,
        rms_gamma=weight,
        rms_eps=self.epsilon,
        quant_out=rms_result,
        scale=scale,
        **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
    )
    return allreduce[3], allreduce[2]
```
The `replacement` function in `AllReduceFusedAddRMSNormFP8Pattern` seems to have a potential bug related to graph topology that could be the source of the correctness issue mentioned in the PR description.

Specifically, it uses `rms_result` as the `quant_out` buffer for the fused operation. However, in the FX graph, the allocation of `rms_result` (the output of the original quantization op) likely occurs after the `all_reduce` operation. Since this pattern replaces the `all_reduce` node, the new fused op is inserted at a point in the graph where `rms_result` has not yet been allocated, leading to a topologically invalid graph.

This same issue is handled in `AllReduceRMSNORMFP8Pattern` (lines 463-469) by creating a new `empty` tensor as a workaround. A similar fix should be applied here to ensure the graph remains valid after the replacement.
```python
# The allocation of rms_fp8_result appears after the all_reduce.
# However, we are replacing the all_reduce call,
# so our whole fused op appears before the allocation of the result.
# This makes the graph topologically unsorted and causes errors,
# since our fused op references a tensor that is not yet allocated.
# Therefore we need to allocate a result tensor as a workaround.
empty = torch.empty_like(input, dtype=current_platform.fp8_dtype())
```
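For illustration, a minimal sketch of how that workaround could be applied inside `AllReduceFusedAddRMSNormFP8Pattern.replacement`, mirroring the snippets above. This is not the merged implementation; the locally allocated `quant_out` name is illustrative, and the function is assumed to be a closure inside the pattern class (so `self` is captured as in the original).

```python
def replacement(residual: torch.Tensor, input: torch.Tensor,
                weight: torch.Tensor, scale: torch.Tensor,
                rms_result: torch.Tensor):
    # rms_result stays in the signature to match the pattern inputs, but the
    # fused op writes into a freshly allocated buffer instead. Allocating it
    # here keeps the fused op from referencing rms_result, whose allocation
    # appears after the all_reduce node being replaced, so the graph stays
    # topologically sorted.
    quant_out = torch.empty_like(input, dtype=current_platform.fp8_dtype())
    allreduce = auto_functionalized(
        torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm_fp8.default,
        allreduce_in=input,
        residual=residual,
        rms_gamma=weight,
        rms_eps=self.epsilon,
        quant_out=quant_out,
        scale=scale,
        **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
    )
    return allreduce[3], allreduce[2]
```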
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀
Purpose
This PR builds on the great work of @ilmarkov in #20691 and integrates FlashInfer's allreduce + RMSNorm + FP8 Quant fusion.
This new fusion relies on the existing RMSNorm + FP8 Quant fusion in order not to collide with the non-FP8 fusion. Therefore, to enable this fusion we must also enable `fusion` and `noop`:

`--compilation-config='{"pass_config": {"enable_noop": true, "enable_fusion": true, "enable_flashinfer_allreduce_fusion": true}, "custom_ops": ["+rms_norm", "+quant_fp8"], "level":3}'`
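The same configuration can also be expressed through the offline API. This is a minimal sketch, assuming the `LLM` constructor accepts a `compilation_config` dict equivalent to the CLI flag above; the model id is a placeholder for the FP8 checkpoint used in the test plan.

```python
# Minimal sketch: enabling the fused allreduce + RMSNorm + FP8 passes via the
# offline API, mirroring the --compilation-config CLI example above.
from vllm import LLM

llm = LLM(
    model="Llama-3.3-70B-Instruct-FP8",  # placeholder: substitute the FP8 (modelopt) checkpoint you use
    tensor_parallel_size=4,
    compilation_config={
        "level": 3,
        "custom_ops": ["+rms_norm", "+quant_fp8"],
        "pass_config": {
            "enable_noop": True,
            "enable_fusion": True,
            "enable_flashinfer_allreduce_fusion": True,
        },
    },
)
```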
In addition, this PR modifies some of the guards that determine when the FlashInfer path should be taken, and matches them to the TRTLLM source.
Test Plan
`tests/compile/test_fusion_all_reduce.py` was expanded to also test modules with FP8 quant, as well as multiple sequence lengths, to cover all possible execution paths (FI oneshot, FI twoshot, non-FI).

Test Result
Llama-3.3-70B-Instruct-FP8 (modelopt) TP=4 on B200 GPUs.
FlashInfer commit 7253d74.

Benchmark E2E:
Client
Server
GSM8K
Client