
Conversation

gshtras (Collaborator) commented Mar 7, 2025

Optimization ported over from ROCm/vllm.
This applies weight padding for MoE weights.
The principle and rationale are similar to the FP8 padding in #13231, except here it is applied to the half-precision types.
The optimization is more experimental and does not apply to every MoE model, so it is disabled by default.
Expanded unit tests to cover the padding case.

Performance-wise, up to a 10% latency improvement can be observed with this feature enabled on mistralai/Mixtral-8x22B-Instruct-v0.1 in the following configuration: bs=64; in=256; out=256; tp=8.
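(For illustration, here is a minimal PyTorch sketch of the padding idea; the function name, pad size, and tensor shapes are assumptions for this example, not the PR's exact code.)

import torch
import torch.nn.functional as F

MOE_PADDING = 128  # assumed pad size per row; the real value is a tuning choice

def pad_expert_weight(weight: torch.Tensor, num_pad: int = MOE_PADDING) -> torch.Tensor:
    # Physically widen the last dimension, then slice the padding back off.
    # The logical shape is unchanged, but each row is stored with extra
    # trailing elements, which can keep the MoE GEMMs on a faster path on ROCm.
    padded = F.pad(weight, (0, num_pad), "constant", 0)
    return padded[..., :-num_pad]

# Example: half-precision expert weights shaped [num_experts, N, K]
w = torch.randn(4, 256, 512, dtype=torch.float16)
w_padded = pad_expert_weight(w)
assert w_padded.shape == w.shape      # same logical shape
assert not w_padded.is_contiguous()   # but physically padded storage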


github-actions bot commented Mar 7, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

robertgshaw2-redhat (Collaborator) commented Mar 7, 2025

QQ - would we ever not want to do this if we are on ROCm for MoE?

gshtras (Collaborator, Author) commented Mar 8, 2025

> QQ - would we ever not want to do this if we are on ROCm for MoE?

It has mostly been tested on Mixtral; other MoE models, especially those with custom MoE implementations, may fail due to improper padding handling.

mgoin (Member) commented Mar 10, 2025

> does not apply to every MoE model

I think this feature should be improved so it generally satisfies the FusedMoE interface. It seems like a footgun if it will fail on common MoEs other than Mixtral. Could you give an example of a custom MoE impl that would fail with this?

divakar-amd (Contributor) commented

Hi, this feature should work for any model which extends the FusedMoE class. However, if you are only importing the fused_moe kernel to plug it into a custom layer, then it requires some caution.
Here's an example: in PR #8518 we fixed this for the Dbrx model by extending the entire FusedMoE class/layer instead of just importing the fused_moe kernel and defining its own layer. That allowed padding to also work for Dbrx.
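(To make the distinction concrete, a toy sketch follows; these are deliberately not vLLM's real classes or signatures. A model built on the shared MoE layer gets padding applied in the layer's own weight post-processing, while a model that only borrows the kernel and manages its own weights never runs that step.)

import torch
import torch.nn.functional as F

class SharedMoELayer(torch.nn.Module):
    # Stand-in for a shared MoE layer that owns its expert weights and pads
    # them in one central post-load step (hypothetical, for illustration).
    def __init__(self, num_experts: int, n: int, k: int, num_pad: int = 128):
        super().__init__()
        self.w = torch.nn.Parameter(
            torch.randn(num_experts, n, k, dtype=torch.float16),
            requires_grad=False)
        self.num_pad = num_pad

    def process_weights_after_loading(self) -> None:
        padded = F.pad(self.w.data, (0, self.num_pad), "constant", 0)
        self.w = torch.nn.Parameter(padded[..., :-self.num_pad],
                                    requires_grad=False)

# Extending the shared layer means the padding step runs automatically:
layer = SharedMoELayer(num_experts=4, n=256, k=512)
layer.process_weights_after_loading()
assert not layer.w.is_contiguous()

# A layer that only imports the low-level kernel and keeps plain, unpadded
# nn.Parameter weights skips this step, so a kernel configured to expect
# padded strides would be fed weights it was not tuned for.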

charlifu (Contributor) commented Mar 11, 2025

> QQ - would we ever not want to do this if we are on ROCm for MoE?

We could do the same condition check as for FP8 padding:

if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
        and weight.stride(-1) == 1
        and (weight.stride(-2) * weight.element_size()) % 512 == 0):

charlifu (Contributor) commented Mar 11, 2025

> Hi, this feature should work for any model which extends the FusedMoE class. However, if you are only importing the fused_moe kernel to plug it into a custom layer, then it requires some caution. Here's an example: in PR #8518 we fixed this for the Dbrx model by extending the entire FusedMoE class/layer instead of just importing the fused_moe kernel and defining its own layer. That allowed padding to also work for Dbrx.

There is a way to avoid this. We can also pad the weight tensor and then take a slice of it, just like what we did in the FP8 padding PR #13231:

weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]

If we do so, there is no need for padding_size in fused_moe.py, but we have to remove the requirement that the weight be contiguous.
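(A quick check of what that one-liner does to the layout, with illustrative sizes, which is also why the contiguity requirement has to be relaxed:)

import torch
import torch.nn.functional as F

num_pad = 128
weight = torch.randn(8, 512, 512, dtype=torch.float16)

padded = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]

print(padded.shape == weight.shape)  # True: logical shape is unchanged
print(padded.stride(-2))             # 640: each row is physically 512 + 128 wide
print(padded.is_contiguous())        # False: hence the contiguity check has to go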

charlifu force-pushed the moe_padding_upstream branch from bddc6c3 to fa2b8d1 on March 12, 2025, 15:50
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)

def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor:
Collaborator review comment:
Maybe call maybe_pad_weight?
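(For context, a minimal sketch of what such a helper could look like, mirroring the FP8 path; the env flag name, pad size, and guard conditions here are assumptions, not necessarily the PR's exact code.)

import os
import torch
import torch.nn.functional as F

MOE_PADDING_SIZE = 128  # assumed pad size, analogous to the FP8 padding path

def maybe_pad_weight(weight: torch.Tensor,
                     num_pad: int = MOE_PADDING_SIZE) -> torch.Tensor:
    # Only pad when the feature is switched on and the weight is row-major.
    enabled = bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "0")))  # assumed flag
    if enabled and num_pad > 0 and weight.stride(-1) == 1:
        weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
        torch.cuda.empty_cache()  # drop the unpadded copy promptly
    return weight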

"VLLM_ROCM_FP8_PADDING":
lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
# Divisor for dynamic key scale factor calculation for FP8 KV Cache

Collaborator review comment:
Why is this not enabled by default?

Contributor reply:
It used to be enabled by default.


mergify bot commented Mar 20, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @gshtras.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Mar 20, 2025
mergify bot removed the needs-rebase label Mar 20, 2025
Signed-off-by: charlifu <[email protected]>
robertgshaw2-redhat enabled auto-merge (squash) March 24, 2025 21:05
github-actions bot added the ready label Mar 24, 2025
robertgshaw2-redhat merged commit f533b58 into vllm-project:main Mar 24, 2025
48 checks passed
erictang000 pushed a commit to erictang000/vllm that referenced this pull request Mar 25, 2025
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: charlifu <[email protected]>
Co-authored-by: charlifu <[email protected]>
wrmedford pushed a commit to wrmedford/vllm that referenced this pull request Mar 26, 2025
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: charlifu <[email protected]>
Co-authored-by: charlifu <[email protected]>
Signed-off-by: Wes Medford <[email protected]>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: charlifu <[email protected]>
Co-authored-by: charlifu <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
gshtras deleted the moe_padding_upstream branch April 7, 2025 14:59
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: charlifu <[email protected]>
Co-authored-by: charlifu <[email protected]>
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: charlifu <[email protected]>
Co-authored-by: charlifu <[email protected]>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: charlifu <[email protected]>
Co-authored-by: charlifu <[email protected]>
Signed-off-by: Mu Huai <[email protected]>