[Performance] Reapply Performance improvements in non-blockwise fp8 CUTLASS MoE #23001
Conversation
…oE (vllm-project#20762) Signed-off-by: ElizaWszola <[email protected]> Co-authored-by: Chih-Chieh-Yang <[email protected]> Signed-off-by: Chih-Chieh-Yang <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request reapplies performance improvements for non-blockwise FP8 CUTLASS MoE by pre-computing stride tensors and introducing a custom CUDA kernel for permutation. The changes are well-structured and align with the performance goals. However, I've identified a critical issue in the new shuffle_rows CUDA operation within csrc/moe/moe_permute_unpermute_op.cu: the alignment check is incorrect, which could lead to memory access errors. I've also suggested a refactoring to reduce code duplication in the same function.
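For context, shuffle_rows is a row gather: each destination row copies the source row selected by an index map. Below is a minimal sketch of that idea; the names and launch shape are illustrative only, not the PR's actual kernel, which additionally has a vectorized fast path for 128-bit-aligned rows. The block flagged by the review follows after it.

#include <cstdint>

// Minimal sketch of a row-gather ("shuffle rows") kernel, for context only.
template <typename scalar_t>
__global__ void shuffleRowsSketch(const scalar_t* __restrict__ src,
                                  const int32_t* __restrict__ dst2src_map,
                                  scalar_t* __restrict__ dst,
                                  int64_t num_dest_rows, int64_t num_cols) {
  // One block per destination row; threads stride across its columns.
  int64_t dst_row = blockIdx.x;
  if (dst_row >= num_dest_rows) return;
  const scalar_t* src_row = src + int64_t(dst2src_map[dst_row]) * num_cols;
  scalar_t* dst_row_ptr = dst + dst_row * num_cols;
  for (int64_t col = threadIdx.x; col < num_cols; col += blockDim.x) {
    dst_row_ptr[col] = src_row[col];
  }
}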
if (num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)) {
  // use slow kernel if num_cols can't be aligned to 128 bits
  MOE_DISPATCH(input_tensor.scalar_type(), [&] {
    shuffleInputRowsKernelSlow<scalar_t><<<blocks, threads, 0, stream>>>(
        reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
        dst2src_map.data_ptr<int32_t>(),
        reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
        num_dest_rows, num_cols);
  });
} else {
  MOE_DISPATCH(input_tensor.scalar_type(), [&] {
    shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
        reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
        dst2src_map.data_ptr<int32_t>(),
        reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
        num_dest_rows, num_cols);
  });
}
There are two issues in this block:

1. Critical Bug: The expression sizeof(input_tensor.scalar_type()) is incorrect for getting the size of the tensor's elements. input_tensor.scalar_type() returns a c10::ScalarType enum, and sizeof on it returns the size of the enum type itself (e.g., 1 or 4 bytes), not the size of the data type it represents. This leads to an incorrect alignment check, which could cause the fast kernel path to be taken for unaligned inputs, producing memory access errors or incorrect results. The correct way to get the element size is c10::elementSize(input_tensor.scalar_type()) or, inside the MOE_DISPATCH macro, sizeof(scalar_t).

2. Code Duplication: The MOE_DISPATCH call is duplicated in the if and else branches. This can be simplified by moving the if/else logic inside the MOE_DISPATCH lambda, which improves maintainability and reduces duplication.
Here is a suggested change that addresses both issues:
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
  if (num_cols % (128 / sizeof(scalar_t) / 8)) {
    // use slow kernel if num_cols can't be aligned to 128 bits
    shuffleInputRowsKernelSlow<scalar_t><<<blocks, threads, 0, stream>>>(
        reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
        dst2src_map.data_ptr<int32_t>(),
        reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
        num_dest_rows, num_cols);
  } else {
    shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
        reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
        dst2src_map.data_ptr<int32_t>(),
        reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
        num_dest_rows, num_cols);
  }
});
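A small standalone snippet makes the pitfall concrete. This is an illustrative sketch, assuming PyTorch's c10 headers are available; the printed values reflect c10::ScalarType being backed by a one-byte integer in current PyTorch.

#include <c10/core/ScalarType.h>
#include <iostream>

int main() {
  // sizeof measures the enum type itself (c10::ScalarType is backed by
  // int8_t in current PyTorch), not the dtype that the enum value names.
  std::cout << sizeof(c10::ScalarType) << "\n";                  // prints 1
  // c10::elementSize returns the size of one element of the given dtype.
  std::cout << c10::elementSize(c10::ScalarType::Half) << "\n";  // prints 2
  // For half precision, the alignment divisor should be
  // 128 / sizeof(scalar_t) / 8 = 128 / 2 / 8 = 8 elements per 128-bit chunk;
  // with the buggy sizeof(...) == 1, the code computes 128 / 1 / 8 = 16.
  return 0;
}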
Closing in favor of #23045.
…oE (#20762)
Purpose
Re-apply #20762 with changes needed for the latest main. The bug that was triggered by the original PR was fixed in #21426
Test Plan
Need some suggestions for tests. Might need to run with Maverick.
Test Result
(Optional) Documentation Update
Essential Elements of an Effective PR Description Checklist
Update supported_models.md and examples for a new model.