CUDA: Optimize `rms_norm_f32` kernel and its fused variants, giving 1-6% perf E2E #15715
Conversation
Fastdiv is a much faster way to do integer division, which was identified as a bottleneck in `rms_norm_f32`.
This makes us more flexible in selecting the optimal number of threads w.r.t. parallelizing across a column vs. the launch overheads of threads and MIO throttles.
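For context, a minimal sketch of the fastdiv technique being referenced here, presumably the Granlund–Montgomery division-by-invariant-integers method: division by a runtime-constant `d` is replaced by a multiply-high plus shift with precomputed constants. This is an illustration under stated assumptions, not the merged code, and the loose-parameter signatures are hypothetical:

```cuda
// Fast division of n by a fixed divisor d using precomputed constants
// (Granlund & Montgomery):
//   L  = ceil(log2(d))
//   mp = floor(2^32 * (2^L - d) / d) + 1
// Then n / d == (umulhi(n, mp) + n) >> L for n < 2^31 (the 32-bit sum
// hi + n must not wrap; hi < n guarantees that below 2^31).
static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) {
    const uint32_t hi = __umulhi(n, mp);  // high 32 bits of n * mp
    return (hi + n) >> L;                 // one mul.hi, one add, one shift -- no divide
}

// Modulo falls out of the division: n % d == n - (n / d) * d.
static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, uint32_t mp, uint32_t L, uint32_t d) {
    return n - fastdiv(n, mp, L) * d;
}
```

Passing `mp`, `L`, and `d` as loose values like this is exactly the error surface discussed further down in the thread, which is what motivates the packed interface.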
Thank you for the PR, this is potentially useful for other kernels as well. I'll read the paper, test the performance on my hardware, and then get back to you.
FYI, you can use `scripts/compare-commits.sh` to automatically create a table comparing the performance of 2 commits (or manually use `llama-bench` and `scripts/compare-llama-bench.py`).
Co-authored-by: Johannes Gäßler <[email protected]>
Will file a separate PR to adjust the .clang-format file
Agreed, one kernel that came to my mind was …
I have some concerns about usability, I think it's very easy to accidentally use the wrong values and get incorrect results. Do you think it would make sense to package the original value, `mp`, and `L` as a `uint3`? If you were to pass the `uint3` to `fastdiv` and `fastmodulo`, that would at least ensure that the correct values are being used together. (I think the CUDA compiler is smart enough not to copy unused kernel arguments, so this shouldn't increase register use.)
This seems to correspond with what we want to do, see [here](#15715 (comment)) and [clang-format docs](https://clang.llvm.org/docs/ClangFormatStyleOptions.html#binpackarguments)
Co-authored-by: Johannes Gäßler <[email protected]>
My preferred interface would be to return the current `uint3` result of `init_fastmodulo_values` as `init_fastdiv_values` and to pass it to both `fastdiv` and `fastmodulo`. That's why I was talking about whether or not the compiler would optimize out unused values; then it would optimize out `d` if only `fastdiv` is used.
The compiler seems to reliably optimize away the unused `.z` component in the `fastdiv` use-case, see https://godbolt.org/z/rx8KPrKr3
It's actually optimized away in the SASS, see https://godbolt.org/z/rx8KPrKr3 (godbolt is my go-to tool when trying to pick at the compiler). While I try not to rely on the compiler too much (for example, PTX has 2 loads for `uint3` and a vectorized load for `uint2` in the linked godbolt), I'm fine with doing so in this case.
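Putting that together, a sketch of the packed interface the thread converges on, assuming the `<mp, L, d>` layout in `<x, y, z>`; names are illustrative, and the merged helpers in `ggml/src/ggml-cuda/common.cuh` may differ in detail:

```cuda
#include <cstdint>

// Single host-side init for both operations: packing <mp, L, d> into one uint3
// means fastdiv/fastmodulo can never be fed constants from different divisors.
static uint3 init_fastdiv_values(uint32_t d) {
    uint32_t L = 0;  // L = ceil(log2(d))
    while (L < 32 && (uint32_t{1} << L) < d) {
        ++L;
    }
    // mp = floor(2^32 * (2^L - d) / d) + 1, computed in 64 bits to avoid overflow
    const uint32_t mp = (uint32_t) (((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);
    return make_uint3(mp, L, d);
}

static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 div_consts) {
    // Reads only .x (mp) and .y (L); if a kernel never calls fastmodulo, the
    // unused .z is eliminated in the SASS, as in the godbolt link above.
    return (__umulhi(n, div_consts.x) + n) >> div_consts.y;
}

static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 div_consts) {
    // Reuses fastdiv plus the packed divisor in .z: n % d == n - (n / d) * d.
    return n - fastdiv(n, div_consts) * div_consts.z;
}
```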
Co-authored-by: Johannes Gäßler <[email protected]>
As suggested by @JohannesGaessler, this increases clarity of the intended use.
Thank you, I'll merge this once the CI is done. The ggml matrix multiplication / FlashAttention kernels need to do some integer divisions to determine which data to work on, but register pressure is also a major concern. I'll investigate the use of `fastdiv` there.
For the cases relevant for GEMM/FA I think it will be possible to store …
The benefit seems to be larger for fast GPUs:
*(benchmark table not preserved)*
By packing constants to be used together into a struct, we are less likely to make errors.
`modulo_consts` is more fitting/descriptive
Indeed, but only because for them, latency-limited kernels such as …
Thanks for merging! Regarding register pressure: you have to additionally factor in the registers required by the naive integer division to handle its intermediates. Playing around in godbolt (so static analysis only), I couldn't find a use-case where fastdiv used more registers for the SMs I compiled SASS for.
…upport

* origin/master: (72 commits)
  metal : Add template specialization for mul_mm_id w/ ne20 == 10 (ggml-org#15799)
  llama : set n_outputs to 1 to avoid 0 outputs mean-pooling (ggml-org#15791)
  CANN: Refactor ND to NZ workspace to be per-device (ggml-org#15763)
  server: add exceed_context_size_error type (ggml-org#15780)
  Document the new max GPU layers default in help (ggml-org#15771)
  ggml: add ops for WAN video model (cuda && cpu) (ggml-org#15669)
  CANN: Fix precision issue on 310I DUO multi-devices (ggml-org#15784)
  opencl: add hs=40 to FA (ggml-org#15758)
  CANN: fix acl_rstd allocation size in ggml_cann_rms_norm (ggml-org#15760)
  vulkan: fix mmv subgroup16 selection (ggml-org#15775)
  vulkan: don't use std::string in load_shaders, to improve compile time (ggml-org#15724)
  vulkan : update ggml_vk_instance_validation_ext_available (ggml-org#15666)
  ggml vulkan: add hardsigmoid and hardswish operations (ggml-org#15762)
  CUDA: Optimize `rms_norm_f32` kernel and its fused variants, giving 1-6% perf E2E (ggml-org#15715)
  model-conversion : fix pyright errors (ggml-org#15770)
  sampling : optimize dist sampler (ggml-org#15704)
  llama : fix incorrect model type for Gemma 270M (ggml-org#15764)
  model-conversion : remove hardcoded /bin/bash shebangs [no ci] (ggml-org#15765)
  CANN: Add RoPE contiguous check for 310I DUP device (ggml-org#15735)
  ggml-cpu : optimize RVV kernels (ggml-org#15720)
  ...
…g#15744)

This seems to correspond with what we want to do, see [here](ggml-org#15715 (comment)) and [clang-format docs](https://clang.llvm.org/docs/ClangFormatStyleOptions.html#binpackarguments)
CUDA: Optimize `rms_norm_f32` kernel and its fused variants, giving 1-6% perf E2E (ggml-org#15715)

* Add fastdiv, use it in modulo and use modulo in rms_norm_f32

  Fastdiv is a much faster way to do integer division, which was identified as a bottleneck in rms_norm_f32

* Support more `block_size` values in `rms_norm_f32`

  This makes us more flexible in selecting the optimal threads w.r.t. parallelizing across a col vs. launch overheads of threads and MIO throttles

* Update ggml/src/ggml-cuda/common.cuh

  Co-authored-by: Johannes Gäßler <[email protected]>

* Replace modulo with fastmodulo in `rms_norm_f32`

* Use `BinPackArguments=true` for formatting function calls

  Will file a separate PR to adjust the .clang-format file

* Update ggml/src/ggml-cuda/common.cuh

  Co-authored-by: Johannes Gäßler <[email protected]>

* Use uint3 for both `fastdiv` and `fastmodulo`

  The compiler seems to reliably optimize away the unused .z component in the fastdiv use-case, see https://godbolt.org/z/rx8KPrKr3

* More constrained type declarations

  Co-authored-by: Johannes Gäßler <[email protected]>

* Rename fastdiv and fastmodulo variables to shared variable name

  As suggested by JohannesGaessler, this increases clarity of the intended use

* Pack fastdiv/fastmodulo constants into uint2/uint3 objects

  By packing constants to be used together into a struct, we are less likely to make errors.

* Rename function parameter of fastmodulo

  `modulo_consts` is more fitting/descriptive

---------

Co-authored-by: Johannes Gäßler <[email protected]>
This PR optimizes the `rms_norm_f32` kernel and its fused variants with the following 2 changes:

1. Replace the integer division/modulo in the kernel with fastdiv/fastmodulo, which had been identified as the bottleneck.
2. Support any `blockDim.x` in the interval `[1, 1024]`. This allows us to schedule 2 full warps on a SM for small vector-lengths, better hiding latencies there (a condensed sketch follows the benchmarks below).

Together, this leads to 1-6 % perf gains across the following benched models:
master (b66df9d)
Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes
This PR:
Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes
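As referenced in the description above, a hypothetical, condensed sketch of how change 2 might look; the merged kernel also handles the fused variants and reduces with warp shuffles rather than a shared-memory tree:

```cuda
#include <algorithm>
#include <cstdint>

// One block normalizes one row of ncols floats; blockDim.x may be anything in [1, 1024].
static __global__ void rms_norm_f32_sketch(const float * x, float * dst, const int ncols,
                                           const float eps) {
    const int64_t row = blockIdx.x;
    x   += row * ncols;
    dst += row * ncols;

    // A strided loop over the row works for any block size.
    float sum_sq = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        sum_sq += x[col] * x[col];
    }

    // Block-wide reduction; zero-padding the tail keeps the fixed-width tree
    // correct even when blockDim.x is not a power of two.
    __shared__ float s_sum[1024];
    s_sum[threadIdx.x] = sum_sq;
    for (int i = threadIdx.x + blockDim.x; i < 1024; i += blockDim.x) {
        s_sum[i] = 0.0f;
    }
    __syncthreads();
    for (int offset = 512; offset > 0; offset /= 2) {
        if (threadIdx.x < offset) {
            s_sum[threadIdx.x] += s_sum[threadIdx.x + offset];
        }
        __syncthreads();
    }

    const float scale = rsqrtf(s_sum[0] / ncols + eps);
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        dst[col] = scale * x[col];
    }
}

// Host side: derive the block size from the row length so that short rows get
// e.g. two full warps instead of a mostly idle 1024-thread block. The actual
// heuristic in the PR may differ; this only illustrates the flexibility.
static void rms_norm_f32_cuda_sketch(const float * x, float * dst, const int nrows,
                                     const int ncols, const float eps, cudaStream_t stream) {
    const int block_size = ncols >= 1024 ? 1024 : std::max(64, (ncols + 31) / 32 * 32);
    rms_norm_f32_sketch<<<nrows, block_size, 0, stream>>>(x, dst, ncols, eps);
}
```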