Skip to content

Conversation

varun-sundar-rabindranath
Copy link
Contributor

@varun-sundar-rabindranath varun-sundar-rabindranath commented Sep 10, 2025

Purpose

Enable DP/EP for gpt-oss model via DeepEPHightThroughput and triton_kernel's matmul_ogs

Test Plan

Run gpt-oss eval from link

Test Result

<style type="text/css"></style>

GPT-OSS-120B  
   
Reasoning Effort GPQA
Low 0.66
Medium 0.72
High 0.8

Benchmark

Please find benchmark results here
TLDR - Good TTFT numbers. Bad TPOT numbers as there is no cuda graphs for decode.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@varun-sundar-rabindranath varun-sundar-rabindranath marked this pull request as draft September 10, 2025 13:36
@mergify mergify bot added ci/build gpt-oss Related to GPT-OSS models rocm Related to AMD ROCm labels Sep 10, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for Triton matmul-ogs kernels for GPT-OSS models with Data Parallelism and Expert Parallelism, specifically for mxfp4 quantization. This is done by adding a new OAITritonExperts class. The PR also includes substantial changes to dependency management, removing pinned versions of PyTorch and related packages across various requirement files. My review identified some leftover debugging print statements that should be removed before merging.

@varun-sundar-rabindranath
Copy link
Contributor Author

Status:

command : CUDA_LAUNCH_BLOCKING=1 VLLM_ALL2ALL_BACKEND=deepep_high_throughput canhazgpu run -g2 -- vllm serve openai/gpt-oss-120b --tensor_parallel_size 1 --data-parallel-size 2 --enable-expert-parallel --no-enable-prefix-caching

getting an IMA during model startup

...
(EngineCore_0 pid=1309549)     return self._do_fused_experts(
(EngineCore_0 pid=1309549)   File "/home/varun/code/vllm/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 575, in _do_fused_experts
(EngineCore_0 pid=1309549)     self.fused_experts.apply(
(EngineCore_0 pid=1309549)   File "/home/varun/code/vllm/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py", line 336, in apply
(EngineCore_0 pid=1309549)     experts_output = triton_kernel_fused_experts(
(EngineCore_0 pid=1309549)   File "/home/varun/code/vllm/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py", line 171, in triton_kernel_fused_experts
(EngineCore_0 pid=1309549)     intermediate_cache1 = matmul_ogs(
(EngineCore_0 pid=1309549)   File "/home/varun/code/vllm/vllm-test/lib/python3.10/site-packages/triton_kernels/matmul_ogs.py", line 550, in matmul_ogs
(EngineCore_0 pid=1309549)     (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)](
(EngineCore_0 pid=1309549)   File "/home/varun/code/vllm/.deps/triton/python/triton/runtime/jit.py", line 420, in <lambda>
(EngineCore_0 pid=1309549)     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
(EngineCore_0 pid=1309549)   File "/home/varun/code/vllm/.deps/triton/python/triton/runtime/jit.py", line 742, in run
(EngineCore_0 pid=1309549)     kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
(EngineCore_0 pid=1309549)   File "/home/varun/code/vllm/.deps/triton/python/triton/backends/nvidia/driver.py", line 717, in __call__
(EngineCore_0 pid=1309549)     self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
(EngineCore_0 pid=1309549) RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
terminate called after throwing an instance of 'c10::AcceleratorError'
  what():  CUDA error: an illegal memory access was encountered
Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


@varun-sundar-rabindranath
Copy link
Contributor Author

Status:

command : CUDA_LAUNCH_BLOCKING=1 VLLM_ALL2ALL_BACKEND=deepep_high_throughput canhazgpu run -g2 -- vllm serve openai/gpt-oss-120b --tensor_parallel_size 1 --data-parallel-size 2 --enable-expert-parallel --no-enable-prefix-caching

getting an IMA during model startup

...
(EngineCore_0 pid=1309549)     return self._do_fused_experts(
(EngineCore_0 pid=1309549)   File "/home/varun/code/vllm/vllm/model_executor/layers/fused_moe/modular_kernel.py", line 575, in _do_fused_experts
(EngineCore_0 pid=1309549)     self.fused_experts.apply(
(EngineCore_0 pid=1309549)   File "/home/varun/code/vllm/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py", line 336, in apply
(EngineCore_0 pid=1309549)     experts_output = triton_kernel_fused_experts(
(EngineCore_0 pid=1309549)   File "/home/varun/code/vllm/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py", line 171, in triton_kernel_fused_experts
(EngineCore_0 pid=1309549)     intermediate_cache1 = matmul_ogs(
(EngineCore_0 pid=1309549)   File "/home/varun/code/vllm/vllm-test/lib/python3.10/site-packages/triton_kernels/matmul_ogs.py", line 550, in matmul_ogs
(EngineCore_0 pid=1309549)     (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)](
(EngineCore_0 pid=1309549)   File "/home/varun/code/vllm/.deps/triton/python/triton/runtime/jit.py", line 420, in <lambda>
(EngineCore_0 pid=1309549)     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
(EngineCore_0 pid=1309549)   File "/home/varun/code/vllm/.deps/triton/python/triton/runtime/jit.py", line 742, in run
(EngineCore_0 pid=1309549)     kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
(EngineCore_0 pid=1309549)   File "/home/varun/code/vllm/.deps/triton/python/triton/backends/nvidia/driver.py", line 717, in __call__
(EngineCore_0 pid=1309549)     self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
(EngineCore_0 pid=1309549) RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
terminate called after throwing an instance of 'c10::AcceleratorError'
  what():  CUDA error: an illegal memory access was encountered
Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

commit e1e55e09d68eaab088a0433a2328f5eb8675ff4d fixes this.

@varun-sundar-rabindranath
Copy link
Contributor Author

Update. It appears DeepEP kernels dont support the gpt-oss hidden-size (2880). I have a PR here in the DeepEP repo to fix this deepseek-ai/DeepEP#408 . I see correctness with this.

@varun-sundar-rabindranath varun-sundar-rabindranath marked this pull request as ready for review September 17, 2025 19:10
@varun-sundar-rabindranath varun-sundar-rabindranath force-pushed the varun/gptoss-triton branch 2 times, most recently from 7bb98dc to 584fdf2 Compare September 17, 2025 19:28
@zyongye
Copy link
Member

zyongye commented Sep 17, 2025

You can add padding here so we don't need to change DeepEP side.

if quant_config and quant_config.get_name() == "mxfp4":
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend, get_mxfp4_backend)
current_mxfp4_backend = get_mxfp4_backend()
if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
hidden_size = round_up(hidden_size, 128)
elif (current_platform.is_rocm() or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or
current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
hidden_size = round_up(hidden_size, 256)

Copy link

mergify bot commented Sep 18, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @varun-sundar-rabindranath.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@varun-sundar-rabindranath
Copy link
Contributor Author

You can add padding here so we don't need to change DeepEP side.

if quant_config and quant_config.get_name() == "mxfp4":
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend, get_mxfp4_backend)
current_mxfp4_backend = get_mxfp4_backend()
if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
hidden_size = round_up(hidden_size, 128)
elif (current_platform.is_rocm() or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or
current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
hidden_size = round_up(hidden_size, 256)

Nice Idea 🙌 I have updated the code to add this padding. PTAL! Thanks 👍

Comment on lines 33 to 37
if hidden_size_bytes % 512 == 0:
return hidden_size

hidden_size_bytes = round_up(hidden_size_bytes, xfer_atom_size)
return hidden_size_bytes // dtype.itemsize
Copy link
Contributor

@bnellnm bnellnm Sep 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be replaced with cdiv(hidden_size_bytes, xfer_atom_size)?

Copy link

mergify bot commented Sep 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @varun-sundar-rabindranath.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot removed the needs-rebase label Sep 19, 2025
@tlrmchlsmth tlrmchlsmth added this to the v0.10.3 milestone Sep 22, 2025
else:
# TODO (bnell): This is a hack to get test_mixtral_moe to work
# since model_config is not set in the pytest test.
moe_in_dtype = params_dtype
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need a better way to infer the activation datatype entering the MoE layers. At the moment, all of the MoE code interprets model_config.dtype as the activation datatype as can be seen in the construction of FusedMoEConfig object. I think is very brittle and prone to errors.

cc @mgoin @bnellnm is there a better way that you know of ? Thanks.

Copy link
Contributor

@bnellnm bnellnm Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I wrote this bit of code. I don't know of a better way to get the type but I agree that it would be good to make this more robust. The else clause is literally only for one test. I wouldn't worry too much about whether or not that is brittle.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue I think is it really depends on the output dtype of the layer before the MoE layer (usually attention). If that layer chooses to quantize the activations, then this would be wrong 😟

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have little clue beyond this

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 22, 2025
else:
# TODO (bnell): This is a hack to get test_mixtral_moe to work
# since model_config is not set in the pytest test.
moe_in_dtype = params_dtype
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have little clue beyond this

@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Sep 22, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
@vllm-bot vllm-bot merged commit e8db44f into vllm-project:main Sep 23, 2025
48 of 50 checks passed
DarkLight1337 pushed a commit that referenced this pull request Sep 23, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…project#24588)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
…project#24588)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: charlifu <[email protected]>
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: yewentao256 <[email protected]>
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: yewentao256 <[email protected]>
gjc0824 pushed a commit to gjc0824/vllm that referenced this pull request Oct 10, 2025
…project#24588)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: gaojc <[email protected]>
gjc0824 pushed a commit to gjc0824/vllm that referenced this pull request Oct 10, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
…project#24588)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
…project#24588)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build gpt-oss Related to GPT-OSS models ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

6 participants