Skip to content

Conversation

cascade812
Copy link
Contributor

@cascade812 cascade812 commented Jun 26, 2025

This PR adds torch async tp using compilation pass for scaled mm.
It builds upon previous work to extend async tensor parallelism support to quantized models.

It requires below config to run

config = CompilationConfig(
    level=3,
    compile_sizes=[4, 8, 16],
    splitting_ops=[],
)
config.pass_config.enable_noop = True
config.pass_config.enable_async_tp= True

llm = LLM(
    model="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
    enforce_eager=False,
    tensor_parallel_size=2,
    compilation_config=config)

On H100x4, 70B model with async_tp enabled has 5% reduce in avg latency when input-len=8192.

python benchmarks/benchmark_latency.py --model RedHatAI/Meta-Llama-3.1-70B-Instruct-FP8 --output-len 1 --input-len 8192 --batch-size 1 --tensor-parallel-size 4 --load-format dummy --num_iters_warmup 5 --num_iters 15 -O '{"level":3, "compile_sizes": [8192], "pass_config": {"enable_async_tp": false}}' --no-enable-prefix-caching &> benchmark.log
Avg latency: 0.38038901427062227 seconds
10% percentile latency: 0.3801612737996038 seconds
25% percentile latency: 0.38024083150958177 seconds
50% percentile latency: 0.3802969330281485 seconds
75% percentile latency: 0.38051927399646956 seconds
90% percentile latency: 0.3806809635949321 seconds
99% percentile latency: 0.38095860722940417 seconds

python benchmarks/benchmark_latency.py --model RedHatAI/Meta-Llama-3.1-70B-Instruct-FP8 --output-len 1 --input-len 8192 --batch-size 1 --tensor-parallel-size 4 --load-format dummy --num_iters_warmup 5 --num_iters 15 -O '{"level":3, "compile_sizes": [8192], "pass_config": {"enable_async_tp": true}}' --no-enable-prefix-caching &> benchmark.log
Avg latency: 0.3639305369268792 seconds
10% percentile latency: 0.36311788139864803 seconds
25% percentile latency: 0.363301858495106 seconds
50% percentile latency: 0.3638122149859555 seconds
75% percentile latency: 0.3640824829781195 seconds
90% percentile latency: 0.3652374107914511 seconds
99% percentile latency: 0.3660336719558109 seconds

Signed-off-by: cascade812 <[email protected]>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @cascade812, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the asynchronous tensor parallelism (async TP) compilation pass by adding support for _scaled_mm operations. This is particularly important for optimizing the performance and memory efficiency of quantized models, especially those using FP8. The changes involve introducing new fusion patterns that combine _scaled_mm with collective operations like reduce_scatter and all_gather, and validating these new capabilities with dedicated test cases and an FP8 quantized model.

Highlights

  • Expanded Async TP Support: The asynchronous tensor parallelism (async TP) compilation pass now includes support for torch._scaled_mm operations, which are essential for efficient execution of quantized models.
  • New Fusion Patterns: Introduced ScaledMMReduceScatterPattern and AllGatherScaledMMPattern to fuse _scaled_mm with reduce_scatter and all_gather operations, respectively. These fusions optimize collective communication and computation in distributed settings.
  • FP8 Quantization Integration: Explicitly enables and tests the async TP fusions for models utilizing FP8 (Float8) data types, building upon previous work to extend async tensor parallelism to quantized models.
  • Comprehensive Testing: Added new test models (TestScaledMMRSModel, TestAGScaledMMModel) and updated existing test infrastructure to validate the new _scaled_mm fusions, including testing with a real-world FP8 quantized model (RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8).
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for asynchronous tensor parallelism for scaled matrix multiplication (scaled_mm), which is particularly useful for FP8 quantized models. The changes include adding new fusion patterns and extending the test suite. The implementation looks good overall, but I have identified a few areas for improvement regarding code duplication in tests, a potential reduction in test coverage, and a potential correctness issue in one of the new fusion patterns.

Signed-off-by: cascade812 <[email protected]>
Copy link
Collaborator

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good overall! If we can make it work with cutlass_scaled_mm that would be perfect

Copy link

mergify bot commented Jul 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @cascade812.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 14, 2025
Signed-off-by: cascade812 <[email protected]>
Signed-off-by: cascade812 <[email protected]>
Copy link
Collaborator

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Could we add an end-to-end test for cutlass_scaled_mm as well?

@cascade812
Copy link
Contributor Author

Looks good! Could we add an end-to-end test for cutlass_scaled_mm as well?

Sure! I've added the test using the RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8 model to test_async_tp_pass_correctness in test_async_tp.py.

def replacement(input: torch.Tensor, mat2: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor) -> torch.Tensor:
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
Copy link
Collaborator

@zou3519 zou3519 Jul 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a reference somewhere that says that torch._scaled_mm + reduce_scatter is equivalent to torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

@zou3519 zou3519 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good, thanks for the reference and the tests

@zou3519 zou3519 added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 25, 2025
@zou3519
Copy link
Collaborator

zou3519 commented Jul 29, 2025

@cascade812 can you rebase please? I think the test failures look unrelated

@cascade812
Copy link
Contributor Author

@zou3519 I've merged latest branch, can you pls help merge this PR?

@zou3519 zou3519 merged commit 287f527 into vllm-project:main Jul 30, 2025
66 checks passed
liuyumoye pushed a commit to liuyumoye/vllm that referenced this pull request Jul 31, 2025
vadiklyutiy pushed a commit to CentML/vllm that referenced this pull request Aug 5, 2025
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
noamgat pushed a commit to noamgat/vllm that referenced this pull request Aug 9, 2025
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants