[NVIDIA] Support SiluMul + NVFP4 quant fusion #23671
Conversation
Code Review
This pull request introduces support for fusing SiLU+Mul with NVFP4 quantization, which is a valuable performance optimization for models running on NVIDIA GPUs with FP4 support. The changes are well-structured, including a new CUDA kernel, updates to the compilation passes for fusion, and comprehensive tests. The refactoring of the fusion pass and tests to accommodate the new pattern is clean. My review found one potential issue with pointer casting in the CUDA kernel wrapper that could lead to undefined behavior, and I've provided a suggestion to fix it. Overall, this is a solid contribution.
The pointer casts for `output_ptr` and `sf_out` are unsafe and overly complex, and the subsequent `reinterpret_cast` in the kernel launch can be avoided.

- `static_cast<int64_t*>(output.data_ptr())` is unsafe. The `output` tensor is of type `torch.uint8`, so its data buffer is not guaranteed to have the 8-byte alignment required for `int64_t*`. This can lead to undefined behavior.
- The kernel expects `uint32_t*` for both `out` and `SFout`. It's cleaner to cast directly to this type using `reinterpret_cast`.

By casting directly to `uint32_t*` when defining `output_ptr` and `sf_out`, you can simplify the kernel launch call by removing the `reinterpret_cast` there.
void silu_and_mul_nvfp4_quant(torch::Tensor& output,     // [..., d]
                              torch::Tensor& output_sf,
                              torch::Tensor& input,      // [..., 2 * d]
                              torch::Tensor& input_sf) {
  TORCH_CHECK(input.dtype() == torch::kFloat16 ||
              input.dtype() == torch::kBFloat16);
  int32_t m = input.size(0);
  int32_t n = input.size(1) / 2;
  TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");

  int multiProcessorCount =
      get_device_attribute(cudaDevAttrMultiProcessorCount, -1);

  auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
  auto sf_out = reinterpret_cast<uint32_t*>(output_sf.data_ptr());
  auto output_ptr = reinterpret_cast<uint32_t*>(output.data_ptr());

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  auto stream = at::cuda::getCurrentCUDAStream(input.get_device());

  dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024));
  int const numBlocksPerSM = 2048 / block.x;
  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

  VLLM_DISPATCH_HALF_TYPES(
      input.scalar_type(), "act_and_mul_quant_kernel", [&] {
        auto input_ptr = reinterpret_cast<scalar_t const*>(input.data_ptr());
        VLLM_DISPATCH_BYTE_TYPES(
            output.scalar_type(), "fused_act_and_mul_quant_kernel_nvfp4_type",
            [&] {
              vllm::silu_and_cvt_fp16_to_fp4<scalar_t>
                  <<<grid, block, 0, stream>>>(m, n, input_ptr, input_sf_ptr,
                                               output_ptr, sf_out);
            });
      });
}
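If extra robustness is wanted around the cast discussed above, an explicit alignment check could be added before reinterpreting the buffer. This is only an optional sketch, not part of the suggested change, assuming `TORCH_CHECK` and `<cstdint>` are available in this translation unit:

  // Optional sanity check (illustrative, not in the PR): the kernel stores
  // packed FP4 values as 32-bit words, so the destination buffer should be
  // at least 4-byte aligned before being treated as uint32_t*.
  TORCH_CHECK(
      reinterpret_cast<uintptr_t>(output.data_ptr()) % alignof(uint32_t) == 0,
      "output buffer must be 4-byte aligned for packed FP4 stores");

In practice, buffers from the PyTorch CUDA caching allocator are aligned far more strictly than this, so such a check would mainly guard against views or slices of a uint8 tensor that start at an odd byte offset.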
Looks nice and clean, thanks for the refactoring! A few final comments; please create an issue to follow up on the kernel comments.
Could you add a `FUSED_OPS` array here as well?
Could this reference `FUSED_OPS` and `QUANT_OPS` instead?
Also, this could use `ops_in_model_before` (see other tests for how that's checked).
Force-pushed from ed4f126 to d7831c6.
Update: fixed by conflict between
Force-pushed from d7831c6 to 3968b5b.
Looks like it failed to create a tensor on L4:
I got an L4 locally and tried creating tensors, and it worked. Is the failure related to the driver or something else with the L4 in CI? cc @ProExpertProg @mgoin
Signed-off-by: jindih <[email protected]>
fix review comment
Signed-off-by: jindih <[email protected]>
revise silu+nvfp4q pattern matching part
Signed-off-by: jindih <[email protected]>
Signed-off-by: elvischenv <[email protected]>
Signed-off-by: elvischenv <[email protected]>
Signed-off-by: elvischenv <[email protected]>
Signed-off-by: elvischenv <[email protected]>
Head branch was pushed to by a user without write access
Force-pushed from 3968b5b to 8b479b2.
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: elvischenv <[email protected]>
…3671)" Fixes vllm-project#23925 This reverts commit 16a45b3.
Signed-off-by: jindih <[email protected]> Signed-off-by: elvischenv <[email protected]> Co-authored-by: jindih <[email protected]> Co-authored-by: Michael Goin <[email protected]> Co-authored-by: Luka Govedic <[email protected]>
Signed-off-by: jindih <[email protected]> Signed-off-by: elvischenv <[email protected]> Co-authored-by: jindih <[email protected]> Co-authored-by: Michael Goin <[email protected]> Co-authored-by: Luka Govedic <[email protected]>
Purpose
Support SiluMul + NVFP4 quant fusion (following up #22448).
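For context, the computation being fused is the SiLU-gated multiplication (SiLU of the first half of the activation times the second half) followed by NVFP4 quantization of the result. A minimal per-element sketch with illustrative names only; the actual kernel, silu_and_cvt_fp16_to_fp4, vectorizes this and packs eight FP4 values into each 32-bit output word:

  __device__ __forceinline__ float silu(float x) {
    // SiLU(x) = x * sigmoid(x)
    return x / (1.0f + expf(-x));
  }

  // Illustrative only: the fused op applies SiLU to the gate half, multiplies
  // by the up-projection half, and the result is then scaled and quantized to
  // NVFP4 (packed 4-bit values plus per-group scale factors).
  __device__ __forceinline__ float silu_mul(float gate, float up) {
    return silu(gate) * up;
  }

Fusing the two steps lets the quantization consume the SiLU*Mul result directly, instead of writing the half-precision intermediate to global memory and reading it back for a separate quant kernel.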
Add these compilation flags to enable the fusion:
Test Plan && Test Result
Kernel functional: tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py
Fusion unit test: tests/compile/test_silu_mul_quant_fusion.py
lm_eval && benchmarking:
main:
PR: