-
-
Notifications
You must be signed in to change notification settings - Fork 10.5k
Fix PyTorch 2.8 API compatibility for symm_mem.fused_scaled_matmul_reduce_scatter #24393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix PyTorch 2.8 API compatibility for symm_mem.fused_scaled_matmul_reduce_scatter #24393
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run You ask your reviewers to trigger select CI tests on top of Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request aims to fix an API compatibility issue with PyTorch 2.8 by adding the orig_scatter_dim
parameter to fused_scaled_matmul_reduce_scatter
calls. While the changes are correct for PyTorch 2.8, they break backward compatibility with older versions of PyTorch, which will cause runtime errors. My review includes critical feedback on how to address this by adding the new parameter conditionally based on the PyTorch version, ensuring the fix works across different environments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While this change fixes compatibility with PyTorch 2.8, it introduces a backward compatibility issue with older PyTorch versions (e.g., 2.7) that do not have the orig_scatter_dim
argument. This will raise a TypeError
on older versions.
To ensure backward compatibility, this argument should be added conditionally. I recommend refactoring the call to use a kwargs
dictionary and add orig_scatter_dim
only if the PyTorch version is 2.8 or newer.
Example:
from vllm.utils import is_torch_equal_or_newer
kwargs = {
"input": input,
"mat2": mat2,
"scale_a": scale_a,
"scale_b": scale_b,
"reduce_op": "avg",
"scatter_dim": 0,
"out_dtype": self.dtype,
"group_name": self.tp.device_group.group_name,
}
if is_torch_equal_or_newer("2.8"):
kwargs["orig_scatter_dim"] = 0
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(**kwargs)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change, while fixing compatibility for PyTorch 2.8, breaks backward compatibility for older versions like 2.7. The orig_scatter_dim
argument does not exist in fused_scaled_matmul_reduce_scatter
in older PyTorch versions, which will lead to a TypeError
. This argument should be added conditionally based on the PyTorch version to maintain compatibility.
Example:
from vllm.utils import is_torch_equal_or_newer
kwargs = {
# ... other args
}
if is_torch_equal_or_newer("2.8"):
kwargs["orig_scatter_dim"] = 0
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(**kwargs)
Correct the function signature to match PyTorch 2.8 requirements. The issue was not a missing orig_scatter_dim parameter, but rather: 1. Wrong parameter order (using named parameters instead of positional) 2. Missing required parameters: bias_node, result_scale_node, use_fast_accum PyTorch 2.8 signature: fused_scaled_matmul_reduce_scatter(A, B, A_scale, B_scale, reduce_op, scatter_dim, group_name, bias_node, result_scale_node, out_dtype, use_fast_accum) Changes: - Convert scatter_dim=0 to positional argument 0 - Move group_name to correct position (7th parameter) - Add missing bias_node=None (8th parameter) - Add missing result_scale_node=None (9th parameter) - Move out_dtype to correct position (10th parameter) - Add missing use_fast_accum=False (11th parameter) Fixes vllm-project#24376
975b032
to
27798fb
Compare
Closing due to incorrect analysis of the issue |
Hey @louiswang524 why was the analysis incorrect? |
Summary
Fixes PyTorch 2.8 API compatibility issue in collective fusion patterns by adding the missing
orig_scatter_dim
parameter to
torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter
calls.Problem
PyTorch 2.8 introduced a breaking change where
fused_scaled_matmul_reduce_scatter
now requires anorig_scatter_dim
parameter. This was causing compilation test failures intests/compile/test_async_tp.py
.Solution
Added
orig_scatter_dim=0
parameter to both instances of the function call in:ScaledMMReduceScatterPattern.replacement()
(line 168)CutlassScaledMMReduceScatterPattern.replacement()
(line 281)Testing
The fix maintains backward compatibility and addresses the specific error mentioned in issue #24376.
Fixes #24376