Conversation


@louiswang524 louiswang524 commented Sep 7, 2025

Summary

Fixes a PyTorch 2.8 API compatibility issue in the collective fusion patterns by adding the missing orig_scatter_dim parameter to torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter calls.

Problem

PyTorch 2.8 introduced a breaking change where fused_scaled_matmul_reduce_scatter now requires an
orig_scatter_dim parameter. This was causing compilation test failures in tests/compile/test_async_tp.py.

Solution

Added the orig_scatter_dim=0 parameter to both instances of the function call (a sketch of the updated call follows this list) in:

  • ScaledMMReduceScatterPattern.replacement() (line 168)
  • CutlassScaledMMReduceScatterPattern.replacement() (line 281)
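
A minimal sketch of what the updated call could look like, assuming the keyword names and surrounding pattern context (input, mat2, scale_a, scale_b, self.dtype, self.tp) quoted in the review comments below; this is an illustration, not the exact diff:

gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
    input=input,
    mat2=mat2,
    scale_a=scale_a,
    scale_b=scale_b,
    reduce_op="avg",
    orig_scatter_dim=0,  # parameter this PR adds for PyTorch 2.8
    scatter_dim=0,
    out_dtype=self.dtype,
    group_name=self.tp.device_group.group_name,
)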

Testing

The fix maintains backward compatibility and addresses the specific error mentioned in issue #24376.

Fixes #24376

github-actions bot commented Sep 7, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small, essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@louiswang524 louiswang524 mentioned this pull request Sep 7, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request aims to fix an API compatibility issue with PyTorch 2.8 by adding the orig_scatter_dim parameter to fused_scaled_matmul_reduce_scatter calls. While the changes are correct for PyTorch 2.8, they break backward compatibility with older versions of PyTorch, which will cause runtime errors. My review includes critical feedback on how to address this by adding the new parameter conditionally based on the PyTorch version, ensuring the fix works across different environments.

Contributor


critical

While this change fixes compatibility with PyTorch 2.8, it introduces a backward compatibility issue with older PyTorch versions (e.g., 2.7) that do not have the orig_scatter_dim argument. This will raise a TypeError on older versions.

To ensure backward compatibility, this argument should be added conditionally. I recommend refactoring the call to use a kwargs dictionary and add orig_scatter_dim only if the PyTorch version is 2.8 or newer.

Example:

from vllm.utils import is_torch_equal_or_newer

kwargs = {
    "input": input,
    "mat2": mat2,
    "scale_a": scale_a,
    "scale_b": scale_b,
    "reduce_op": "avg",
    "scatter_dim": 0,
    "out_dtype": self.dtype,
    "group_name": self.tp.device_group.group_name,
}
if is_torch_equal_or_newer("2.8"):
    kwargs["orig_scatter_dim"] = 0

gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(**kwargs)

Contributor


critical

This change, while fixing compatibility for PyTorch 2.8, breaks backward compatibility for older versions like 2.7. The orig_scatter_dim argument does not exist in fused_scaled_matmul_reduce_scatter in older PyTorch versions, which will lead to a TypeError. This argument should be added conditionally based on the PyTorch version to maintain compatibility.

Example:

from vllm.utils import is_torch_equal_or_newer

kwargs = {
    # ... other args
}
if is_torch_equal_or_newer("2.8"):
    kwargs["orig_scatter_dim"] = 0

gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(**kwargs)

Correct the function signature to match PyTorch 2.8 requirements.

The issue was not a missing orig_scatter_dim parameter, but rather:
1. Wrong parameter order (using named parameters instead of positional)
2. Missing required parameters: bias_node, result_scale_node, use_fast_accum

PyTorch 2.8 signature:
fused_scaled_matmul_reduce_scatter(A, B, A_scale, B_scale, reduce_op,
                                   scatter_dim, group_name, bias_node,
                                   result_scale_node, out_dtype, use_fast_accum)

Changes (a sketch of the resulting call follows this commit message):
- Convert scatter_dim=0 to positional argument 0
- Move group_name to correct position (7th parameter)
- Add missing bias_node=None (8th parameter)
- Add missing result_scale_node=None (9th parameter)
- Move out_dtype to correct position (10th parameter)
- Add missing use_fast_accum=False (11th parameter)

Fixes vllm-project#24376
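
For reference, a minimal sketch of a call matching the positional signature described above, reusing the pattern variables (input, mat2, scale_a, scale_b, self.dtype, self.tp) assumed in the review examples; an illustration of the described signature, not the actual commit diff:

gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
    input,                            # A
    mat2,                             # B
    scale_a,                          # A_scale
    scale_b,                          # B_scale
    "avg",                            # reduce_op
    0,                                # scatter_dim
    self.tp.device_group.group_name,  # group_name
    None,                             # bias_node
    None,                             # result_scale_node
    self.dtype,                       # out_dtype
    False,                            # use_fast_accum
)
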
@louiswang524 louiswang524 force-pushed the fix-pytorch-28-symm-mem-api branch from 975b032 to 27798fb on September 7, 2025 at 08:23
@louiswang524
Author

Closing due to incorrect analysis of the issue

@louiswang524 louiswang524 deleted the fix-pytorch-28-symm-mem-api branch September 7, 2025 08:32
@andoorve
Collaborator

Hey @louiswang524, why was the analysis incorrect?

cc: @jasonlizhengjian
