Skip to content

Conversation

jasonlizhengjian
Copy link
Contributor

@jasonlizhengjian jasonlizhengjian commented Oct 1, 2025

Purpose

Fixes #24376 part 2.d

Updated torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter calls to match the new PyTorch API signature. The function signature changed from PyTorch 2.7.1 to require additional positional parameters.

Changes:

  • Added orig_scatter_dim and scatter_dim_after_maybe_reshape as positional parameters
  • Added output_shape calculation: [*input.shape[:-1], mat2.shape[1]]
  • Changed all optional parameters (bias, result_scale, out_dtype, use_fast_accum) from keyword arguments to positional arguments to match PyTorch's torch._inductor implementation

References:

Test Plan

tests/compile/test_async_tp.py

Test Result

pytest -s tests/compile/test_async_tp.py on 2 x B200

================== 16 passed, 9 warnings in 350.92s (0:05:50) ==================

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@jasonlizhengjian jasonlizhengjian changed the title [BugFix][torch.compile] Fix fused_scaled_matmul_reduce_scatter signature for PyTorch 2.8 for issue #24376 [BugFix][torch.compile] Fix fused_scaled_matmul_reduce_scatter signature for PyTorch 2.8 Oct 1, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses a breaking change in the torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter function signature introduced in PyTorch 2.8. The changes correctly update the function calls in ScaledMMReduceScatterPattern and CutlassScaledMMReduceScatterPattern to use the new positional arguments, ensuring compatibility with the latest PyTorch version. The implementation is a direct and accurate adaptation of the new API. The changes are correct and necessary. I have no further comments.

@jasonlizhengjian jasonlizhengjian force-pushed the fix/fused-scaled-matmul-signature branch from 2072f4e to 49a3b8a Compare October 1, 2025 18:10
@jasonlizhengjian
Copy link
Contributor Author

@cascade812 can you also review

Copy link
Collaborator

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add the test file to PyTorch Compilation Unit Tests in test-pipeline.yml?

@jasonlizhengjian
Copy link
Contributor Author

Can we add the test file to PyTorch Compilation Unit Tests in test-pipeline.yml?

@ProExpertProg

test_async_tp.py is already there but it's currently being skipped due to only being run with 1 GPU:

- label: PyTorch Compilation Unit Tests # 15min
  timeout_in_minutes: 30
  mirror_hardwares: [amdexperimental]
  torch_nightly: true
  source_file_dependencies:
    - vllm/
    - tests/compile
  commands:
    - pytest -v -s compile/test_pass_manager.py
    - pytest -v -s compile/test_fusion.py
    - pytest -v -s compile/test_fusion_attn.py
    - pytest -v -s compile/test_silu_mul_quant_fusion.py
    - pytest -v -s compile/test_sequence_parallelism.py
    - pytest -v -s compile/test_async_tp.py
    - pytest -v -s compile/test_fusion_all_reduce.py
    - pytest -v -s compile/test_decorator.py
    - pytest -v -s compile/test_noop_elimination.py

for example in this ci run https://buildkite.com/vllm/ci/builds/32809/steps/canvas?jid=019995c2-21ec-4c60-ba64-5052e2ada32f it outputs:

[2025-09-29T14:42:06Z] compile/test_async_tp.py::test_async_tp_pass_replace[dtype0-16-16-8-TestScaledMMRSModel] SKIPPED
[2025-09-29T14:42:06Z] compile/test_async_tp.py::test_async_tp_pass_replace[dtype0-16-16-8-TestAGScaledMMModel] SKIPPED
[2025-09-29T14:42:06Z] compile/test_async_tp.py::test_async_tp_pass_replace[dtype0-16-16-8-TestCutlassScaledMMRSModel] SKIPPED
[2025-09-29T14:42:06Z] compile/test_async_tp.py::test_async_tp_pass_replace[dtype0-16-16-8-TestAGCutlassScaledMMModel] SKIPPED
[2025-09-29T14:42:06Z] compile/test_async_tp.py::test_async_tp_pass_replace[dtype1-16-16-8-TestMMRSModel] SKIPPED
[2025-09-29T14:42:06Z] compile/test_async_tp.py::test_async_tp_pass_replace[dtype1-16-16-8-TestAGMMModel] SKIPPED
[2025-09-29T14:42:06Z] compile/test_async_tp.py::test_async_tp_pass_replace[dtype1-16-16-8-TestScaledMMRSModel] SKIPPED
[2025-09-29T14:42:06Z] compile/test_async_tp.py::test_async_tp_pass_replace[dtype1-16-16-8-TestAGScaledMMModel] SKIPPED
[2025-09-29T14:42:06Z] compile/test_async_tp.py::test_async_tp_pass_replace[dtype1-16-16-8-TestCutlassScaledMMRSModel] SKIPPED
[2025-09-29T14:42:06Z] compile/test_async_tp.py::test_async_tp_pass_replace[dtype1-16-16-8-TestAGCutlassScaledMMModel] SKIPPED
[2025-09-29T14:42:06Z] compile/test_async_tp.py::test_async_tp_pass_correctness[False-mp-True-2-meta-llama/Llama-3.2-1B-Instruct] Fork a new process to run a test 772
[2025-09-29T14:42:06Z] Fork a new process to run a test 0
[2025-09-29T14:42:06Z] Need at least 2 x 1 GPUs
[2025-09-29T14:42:07Z] PASSED
[2025-09-29T14:42:07Z] compile/test_async_tp.py::test_async_tp_pass_correctness[False-mp-True-2-RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8] Fork a new process to run a test 773
[2025-09-29T14:42:07Z] Fork a new process to run a test 0
[2025-09-29T14:42:07Z] Need at least 2 x 1 GPUs
[2025-09-29T14:42:07Z] PASSED
[2025-09-29T14:42:07Z] compile/test_async_tp.py::test_async_tp_pass_correctness[True-mp-True-2-meta-llama/Llama-3.2-1B-Instruct] Fork a new process to run a test 774
[2025-09-29T14:42:07Z] Fork a new process to run a test 0
[2025-09-29T14:42:07Z] Need at least 2 x 1 GPUs
[2025-09-29T14:42:07Z] PASSED
[2025-09-29T14:42:07Z] compile/test_async_tp.py::test_async_tp_pass_correctness[True-mp-True-2-RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8] Fork a new process to run a test 775
[2025-09-29T14:42:07Z] Fork a new process to run a test 0
[2025-09-29T14:42:07Z] Need at least 2 x 1 GPUs
[2025-09-29T14:42:07Z] PASSED

Not sure why some of these say PASSED despite having the output from the skip:

    if num_gpus_available < tp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")

@andoorve
Copy link
Collaborator

andoorve commented Oct 1, 2025

cc: @louiswang524 re: #24393

@cascade812
Copy link
Contributor

@jasonlizhengjian thanks, LGTM

@cascade812
Copy link
Contributor

Regarding the skipped test, it needs to be run on multiple GPUs. Can we move it to the Distributed Tests (2 GPUs) section? Alternatively, a new PyTorch Compilation Unit Tests (with 2 GPUs) section would also work.

@tejas-srikanth
Copy link

@jasonlizhengjian Thank you for taking care of this, LGTM!

@ProExpertProg
Copy link
Collaborator

Yeah please move the test to distributed tests for now, we can rearrange later

Updated torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter calls to match
the new PyTorch API signature. The function signature changed from PyTorch 2.7.1
to require additional positional parameters.

Changes:
- Added orig_scatter_dim and scatter_dim_after_maybe_reshape as positional parameters
- Added output_shape calculation: [*input.shape[:-1], mat2.shape[1]]
- Changed all optional parameters (bias, result_scale, out_dtype, use_fast_accum)
  from keyword arguments to positional arguments to match PyTorch's torch._inductor
  implementation

References:
- PyTorch function definition: torch/distributed/_symmetric_memory/__init__.py:454-461
- PyTorch test usage: test/distributed/test_symmetric_memory.py:579-590
- PyTorch inductor usage: torch/_inductor/fx_passes/micro_pipeline_tp.py:816-834

Signed-off-by: jasonlizhengjian <[email protected]>
Moved compile/test_async_tp.py from Compilation Tests to Distributed Tests
(2 GPUs) section as it requires 2 GPUs to run (@multi_gpu_test decorator).

Also added tests/compile/test_async_tp.py to source_file_dependencies.

Signed-off-by: jasonlizhengjian <[email protected]>
Signed-off-by: Jason Li <[email protected]>

Signed-off-by: jasonlizhengjian <[email protected]>
@jasonlizhengjian jasonlizhengjian force-pushed the fix/fused-scaled-matmul-signature branch from 26c0001 to 8fe845f Compare October 5, 2025 17:59
Moved compile/test_sequence_parallelism.py from Compilation Tests to
Distributed Tests (2 GPUs) section as it requires 2 GPUs to run
(@multi_gpu_test decorator).

Also added tests/compile/test_sequence_parallelism.py to
source_file_dependencies.

Signed-off-by: jasonlizhengjian <[email protected]>
@ProExpertProg ProExpertProg enabled auto-merge (squash) October 7, 2025 16:46
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 7, 2025
The test_async_tp.py uses PyTorch's symmetric memory operations
(torch.ops.symm_mem.fused_matmul_reduce_scatter and
fused_all_gather_matmul) which require Hopper architecture (H100/H200)
or newer GPUs with multicast and fabric support.

L4 GPUs (Ada Lovelace, compute capability 8.9) lack these features,
causing "invalid device ordinal" errors during symmetric memory
rendezvous when the test runs on the default L4 runners.

Changes:
- Removed test_async_tp.py from "Distributed Tests (2 GPUs)" section
  which runs on L4 GPUs
- Added test_async_tp.py to "Distributed Tests (H200)" section which
  runs on H200 GPUs with full symmetric memory support
- test_sequence_parallelism.py remains on L4 as it uses standard NCCL
  collectives that don't require symmetric memory

Signed-off-by: jasonlizhengjian <[email protected]>

Signed-off-by:  <>
auto-merge was automatically disabled October 7, 2025 20:34

Head branch was pushed to by a user without write access

@jasonlizhengjian
Copy link
Contributor Author

@ProExpertProg the Async TP test was failing in distributed tests https://buildkite.com/vllm/ci/builds/33859/steps/canvas?sid=0199bf91-e585-47ed-adff-0300604a3c12#0199bf91-e674-4667-b280-e1c472cebec7/165-4049
since it seemed like there was some issue with torch symmetric memory and L4 GPUs.
I moved it to the H200 section for now

@ProExpertProg ProExpertProg enabled auto-merge (squash) October 7, 2025 20:46
…ection

Both test_async_tp.py and test_sequence_parallelism.py use PyTorch's
symmetric memory operations which require Hopper architecture (H100/H200)
or newer GPUs with multicast and fabric support.

L4 GPUs (Ada Lovelace, compute capability 8.9) lack these features,
causing "invalid device ordinal" errors during symmetric memory
rendezvous when the tests run on the default L4 runners.

Changes:
- Removed test_async_tp.py from "Distributed Tests (2 GPUs)" section
  which runs on L4 GPUs
- Removed test_sequence_parallelism.py from "Distributed Tests (2 GPUs)"
  section which runs on L4 GPUs
- Added both tests to "Distributed Tests (H200)" section which runs on
  H200 GPUs with full symmetric memory support

Signed-off-by: jasonlizhengjian <[email protected]>

Signed-off-by:  <>
auto-merge was automatically disabled October 8, 2025 01:08

Head branch was pushed to by a user without write access

@jasonlizhengjian
Copy link
Contributor Author

@ProExpertProg I had to move seq parallel test to the H200 section as well due to the fp8 parts

@ProExpertProg ProExpertProg enabled auto-merge (squash) October 9, 2025 00:40
Signed-off-by: Luka Govedič <[email protected]>
@ProExpertProg
Copy link
Collaborator

I reordered tests because CP is failing - let's see that async tp and seq par tests pass ✅

@jasonlizhengjian
Copy link
Contributor Author

seems like the async tp and seq par tests are passing now

@vllm-bot vllm-bot merged commit f4ba206 into vllm-project:main Oct 10, 2025
45 of 47 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: tests/compile failures

6 participants