Skip to content

Conversation

djmmoss
Copy link
Contributor

@djmmoss djmmoss commented Jul 10, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Currently the block scaled group gemm for SM100 doesn't support enable_expert_parallel. If the feature is enabled the it can result in device side assertions. This change defaults to back to Triton if enable_exper_parallel=True

Test Plan

N/A

Test Result

N/A

(Optional) Documentation Update

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @djmmoss, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical issue where the cutlass block scaled group gemm operation leads to device-side assertions when expert_parallel is enabled. The changes introduce a mechanism to detect the presence of expert_parallel and, if active, disable the problematic Cutlass implementation, ensuring a stable fallback to Triton for MoE operations.

Highlights

  • Bug Fix for Expert Parallelism: Implemented a fix to prevent device-side assertions that occur when cutlass block scaled group gemm is used in conjunction with expert_parallel enabled, particularly on SM100 architectures.
  • Conditional Disablement of Cutlass: Modified the _valid_cutlass_block_scaled_grouped_gemm function to accept an expert_map argument. If expert_map is present (indicating expert_parallel is active), the function now explicitly disables the Cutlass-based GEMM and logs a warning.
  • Fallback to Triton: Ensures that when the Cutlass implementation is disabled due to expert_parallel being active, the system gracefully falls back to a compatible Triton implementation for Mixture of Experts (MoE) operations, maintaining stability and preventing runtime errors.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request correctly fixes a device-side assertion by disabling the cutlass_block_scaled_grouped_gemm kernel when expert parallelism is enabled. The implementation is clean and directly addresses the issue. My only recommendation is to add a unit test to verify this new logic and prevent potential regressions.

Comment on lines 568 to 572
if expert_map is not None:
logger.warning(
"CutlassBlockScaledGroupedGemm disabled: expert_parallel is"
" not supported.")
return False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This change correctly disables the kernel when expert parallelism is active. However, the pull request lacks tests to verify this new behavior. Adding a unit test is crucial for bug fixes to prevent regressions and ensure the logic is sound.

A simple unit test could be added to tests/kernels/moe/test_cutlass_moe.py to confirm that _valid_cutlass_block_scaled_grouped_gemm returns False when an expert_map is provided and True otherwise (assuming other conditions pass).

Here's an example test case:

import pytest
import torch
from vllm.model_executor.layers.fused_moe.cutlass_moe import _valid_cutlass_block_scaled_grouped_gemm

def test_valid_cutlass_block_scaled_grouped_gemm_ep_logic():
    # Create tensors that would otherwise be valid for the kernel.
    # Shapes must be multiples of 128 for this kernel.
    N, K = 128, 128
    w1 = torch.empty(1, 2 * N, K, dtype=torch.float8_e4m3fn, device="cuda")
    w2 = torch.empty(1, K, N, dtype=torch.float8_e4m3fn, device="cuda")
    
    # When expert_map is provided, the kernel should be disabled.
    expert_map = torch.tensor([0], device="cuda")
    assert not _valid_cutlass_block_scaled_grouped_gemm(w1, w2, expert_map)
    
    # When expert_map is None, the kernel should be considered valid.
    assert _valid_cutlass_block_scaled_grouped_gemm(w1, w2, None)

)
elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
and _valid_cutlass_block_scaled_grouped_gemm(w1, w2)):
and _valid_cutlass_block_scaled_grouped_gemm(w1, w2, expert_map)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this function should check a few extra things actually: activation, apply_router_weight_on_input, expert_map and probably inplace

djmmoss added 2 commits July 10, 2025 20:28
Signed-off-by: Duncan Moss <[email protected]>
Signed-off-by: Duncan Moss <[email protected]>
@djmmoss djmmoss requested a review from aarnphm as a code owner July 10, 2025 21:34
@mergify mergify bot added the frontend label Jul 10, 2025
Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, LGTM!

@mgoin mgoin enabled auto-merge (squash) July 10, 2025 21:36
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 10, 2025
@mgoin mgoin merged commit 5923ab9 into vllm-project:main Jul 11, 2025
105 of 107 checks passed
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
" apply_router_weight_on_input is not supported.")
return False

if inplace:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @djmmoss , why do we need to disable when inplace is True ?

epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 27, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
frontend ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants