
Conversation

Contributor

@ORippler ORippler commented Sep 1, 2025

This PR optimizes the rms_norm_f32 kernel and its fused variants through the following 2 changes:

  1. Addition of fast integer division, and its use within the modulo operator. Integer division by a run-time constant can be implemented as a single multiplication plus bit-shift, which is significantly faster than the naive implementation. By pre-computing the required constants on the host/CPU side, we thus gain a lot of perf, especially when the same division is performed multiple times within a kernel. This technique is already used in our Vulkan backend and can be applied to other CUDA kernels that use the same modulo operation multiple times in the future.
  2. Support for arbitrary values of blockDim.x in the interval [1, 1024]. This allows us to schedule 2 full warps on an SM for small vector lengths, better hiding latencies there.
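For reference, the multiply-plus-shift scheme from point 1 can be sketched on the host side like this (a minimal C++ sketch of the Granlund-Montgomery round-up method; the struct and function names are illustrative, not the exact llama.cpp API, which packs the same three values into a uint3):

```cpp
#include <cassert>
#include <cstdint>

// Precomputed once on the host for a runtime-constant divisor d:
//   mp: magic multiplier, L: shift amount, d: original divisor (for modulo).
struct fastdiv_consts {
    uint32_t mp;
    uint32_t L;
    uint32_t d;
};

static fastdiv_consts init_fastdiv_consts(uint32_t d) {
    assert(d >= 1 && d <= (uint32_t{1} << 31)); // keep the 64-bit math below overflow-free
    uint32_t L = 0;
    while (L < 32 && (uint64_t{1} << L) < d) {
        ++L;
    }
    // mp = 2^32 * (2^L - d) / d + 1: the low 32 bits of the full magic number.
    const uint32_t mp = uint32_t(((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);
    return {mp, L, d};
}

// On the device this would be (__umulhi(n, c.mp) + n) >> c.L; here the high
// 32 bits of the product are emulated with a 64-bit multiply. Exact for
// n < 2^31, which comfortably covers tensor indices.
static uint32_t fastdiv(uint32_t n, fastdiv_consts c) {
    const uint32_t hi = uint32_t((uint64_t(n) * c.mp) >> 32);
    return (hi + n) >> c.L;
}

static uint32_t fastmodulo(uint32_t n, fastdiv_consts c) {
    return n - fastdiv(n, c) * c.d;
}
```

Since mp and L depend only on the divisor, they are computed once on the CPU and passed as kernel arguments, turning every division and modulo inside the kernel into a multiply, an add, and a shift.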

Together, this leads to 1-6% perf gains across the following benched models:

master (b66df9d)
Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| gemma3n E2B Q8_0 | 4.45 GiB | 4.46 B | CUDA | 99 | pp100 @ d100 | 3545.35 ± 41.07 |
| gemma3n E2B Q8_0 | 4.45 GiB | 4.46 B | CUDA | 99 | tg100 @ d100 | 165.46 ± 0.69 |
| gemma3n E4B Q8_0 | 6.84 GiB | 6.87 B | CUDA | 99 | pp100 @ d100 | 2587.48 ± 76.47 |
| gemma3n E4B Q8_0 | 6.84 GiB | 6.87 B | CUDA | 99 | tg100 @ d100 | 124.24 ± 0.48 |
| gemma3 12B Q8_0 | 11.64 GiB | 11.77 B | CUDA | 99 | pp100 @ d100 | 2899.78 ± 18.41 |
| gemma3 12B Q8_0 | 11.64 GiB | 11.77 B | CUDA | 99 | tg100 @ d100 | 86.09 ± 0.13 |
| llama 3B Q4_K - Medium | 1.87 GiB | 3.21 B | CUDA | 99 | pp100 @ d100 | 6281.45 ± 174.32 |
| llama 3B Q4_K - Medium | 1.87 GiB | 3.21 B | CUDA | 99 | tg100 @ d100 | 327.53 ± 1.58 |
| qwen3 4B Q4_K - Medium | 2.44 GiB | 4.02 B | CUDA | 99 | pp100 @ d100 | 5155.90 ± 51.79 |
| qwen3 4B Q4_K - Medium | 2.44 GiB | 4.02 B | CUDA | 99 | tg100 @ d100 | 242.80 ± 2.91 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | pp100 @ d100 | 3293.58 ± 64.60 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | tg100 @ d100 | 242.56 ± 0.72 |
| qwen3moe 30B.A3B Q3_K - Small | 12.37 GiB | 30.53 B | CUDA | 99 | pp100 @ d100 | 1862.04 ± 45.24 |
| qwen3moe 30B.A3B Q3_K - Small | 12.37 GiB | 30.53 B | CUDA | 99 | tg100 @ d100 | 165.02 ± 0.84 |

This PR:
Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes

| model | size | params | backend | ngl | test | t/s |
| --- | --- | --- | --- | --- | --- | --- |
| gemma3n E2B Q8_0 | 4.45 GiB | 4.46 B | CUDA | 99 | pp100 @ d100 | 3649.65 ± 30.64 |
| gemma3n E2B Q8_0 | 4.45 GiB | 4.46 B | CUDA | 99 | tg100 @ d100 | 175.12 ± 1.08 |
| gemma3n E4B Q8_0 | 6.84 GiB | 6.87 B | CUDA | 99 | pp100 @ d100 | 2871.36 ± 137.86 |
| gemma3n E4B Q8_0 | 6.84 GiB | 6.87 B | CUDA | 99 | tg100 @ d100 | 130.39 ± 0.33 |
| gemma3 12B Q8_0 | 11.64 GiB | 11.77 B | CUDA | 99 | pp100 @ d100 | 3013.22 ± 15.47 |
| gemma3 12B Q8_0 | 11.64 GiB | 11.77 B | CUDA | 99 | tg100 @ d100 | 89.35 ± 0.12 |
| llama 3B Q4_K - Medium | 1.87 GiB | 3.21 B | CUDA | 99 | pp100 @ d100 | 6458.34 ± 26.40 |
| llama 3B Q4_K - Medium | 1.87 GiB | 3.21 B | CUDA | 99 | tg100 @ d100 | 330.64 ± 11.34 |
| qwen3 4B Q4_K - Medium | 2.44 GiB | 4.02 B | CUDA | 99 | pp100 @ d100 | 5200.88 ± 38.18 |
| qwen3 4B Q4_K - Medium | 2.44 GiB | 4.02 B | CUDA | 99 | tg100 @ d100 | 253.41 ± 0.61 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | pp100 @ d100 | 3244.66 ± 102.53 |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | CUDA | 99 | tg100 @ d100 | 245.18 ± 0.66 |
| qwen3moe 30B.A3B Q3_K - Small | 12.37 GiB | 30.53 B | CUDA | 99 | pp100 @ d100 | 1908.57 ± 27.90 |
| qwen3moe 30B.A3B Q3_K - Small | 12.37 GiB | 30.53 B | CUDA | 99 | tg100 @ d100 | 170.90 ± 0.80 |

@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Sep 1, 2025
Fastdiv is a much faster way to do integer division, which was identified
as a bottleneck in rms_norm_f32
This makes us more flexible in selecting the optimal number of threads w.r.t.
parallelizing across a col vs. the launch overheads of threads and MIO
throttles
@ORippler ORippler changed the title CUDA: Optimize rms_norm_f32 kernel and its fused variants CUDA: Optimize rms_norm_f32 kernel and its fused variants, giving 1-6% perf E2E Sep 2, 2025
Collaborator

@JohannesGaessler JohannesGaessler left a comment


Thank you for the PR, this is potentially useful for other kernels as well. I'll read the paper, test the performance on my hardware as well, and then get back to you.

FYI, you can use scripts/compare-commits.sh to automatically create a table comparing the performance of 2 commits (or manually use llama-bench and scripts/compare-llama-bench.py).

ORippler added a commit to ORippler/llama.cpp that referenced this pull request Sep 2, 2025
@ORippler
Contributor Author

ORippler commented Sep 2, 2025

> Thank you for the PR, this is potentially useful for other kernels as well.

Agreed, one kernel that came to my mind was k_bin_bcast.

Collaborator

@JohannesGaessler JohannesGaessler left a comment


I have some concerns about usability, I think it's very easy to accidentally use the wrong values and get incorrect results. Do you think it would make sense to package the original value, mp, and L as uint3? If you were to pass uint3 to fastdiv and fastmodulo that would at least ensure that the correct values are being used together. (I think the CUDA compiler is smart enough not to copy unused kernel arguments, so this shouldn't increase register use.)

taronaeo pushed a commit that referenced this pull request Sep 2, 2025
Co-authored-by: Johannes Gäßler <[email protected]>
Collaborator

@JohannesGaessler JohannesGaessler left a comment


My preferred interface would be to return the current uint3 result of init_fastmodulo_values as init_fastdiv_values and to pass it to both fastdiv and fastmodulo. That's why I was talking about whether or not the compiler would optimize out unused values: then it would optimize out d if only fastdiv is used.

The compiler seems to reliably optimize away the unused .z component in
the fastdiv use-case, see https://godbolt.org/z/rx8KPrKr3
@ORippler
Contributor Author

ORippler commented Sep 3, 2025

> My preferred interface would be to return the current uint3 result of init_fastmodulo_values as init_fastdiv_values and to pass it to both fastdiv and fastmodulo. That's why I was talking about whether or not the compiler would optimize out unused values: then it would optimize out d if only fastdiv is used.

It's actually optimized away in the SASS, see https://godbolt.org/z/rx8KPrKr3 (godbolt is my go-to tool when trying to poke at the compiler). While I try not to rely on the compiler too much (for example, PTX has 2 loads for uint3 and a vectorized load for uint2 in the linked godbolt example), I'm fine with doing so in this case.

ORippler and others added 2 commits September 3, 2025 17:35
Co-authored-by: Johannes Gäßler <[email protected]>
As suggested by @JohannesGaessler, this increases the clarity of the intended
use
@JohannesGaessler
Copy link
Collaborator

Thank you, I'll merge this once the CI is done. The ggml matrix multiplication / FlashAttention kernels need to do some integer divisions to determine which data to work on, but register pressure is also a major concern. I'll investigate the use of fastdiv for those cases (unless you want to do it).

@JohannesGaessler
Copy link
Collaborator

For the cases relevant for GEMM/FA I think it will be possible to store mp with 16 bits, L with 4 bits, and the original value with 12 bits. That way you would still need only a single register to store the data.
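A sketch of what that single-register packing could look like (purely illustrative: the 16/4/12 bit split follows the comment above, and whether a 16-bit magic multiplier suffices depends on the divisors that actually occur in GEMM/FA):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical packing of fastdiv constants into one 32-bit register:
//   bits  0..15: mp (magic multiplier, assumed to fit in 16 bits)
//   bits 16..19: L  (shift amount, 0..15)
//   bits 20..31: d  (original divisor, < 4096)
static uint32_t pack_fastdiv(uint32_t mp, uint32_t L, uint32_t d) {
    assert(mp < (1u << 16) && L < (1u << 4) && d < (1u << 12));
    return mp | (L << 16) | (d << 20);
}

// Cheap unpacking on the device side: one AND or shift each.
static uint32_t packed_mp(uint32_t p) { return p & 0xFFFFu; }
static uint32_t packed_L(uint32_t p)  { return (p >> 16) & 0xFu; }
static uint32_t packed_d(uint32_t p)  { return p >> 20; }
```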

@JohannesGaessler JohannesGaessler merged commit 661ae31 into ggml-org:master Sep 3, 2025
48 checks passed
@JohannesGaessler
Collaborator

The benefit seems to be larger for fast GPUs:

| GPU | Model | Microbatch size | Test | t/s b6341 | t/s 8bde72b | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| P40 | llama 8B Q4_0 | 1 | pp512 | 56.22 | 56.32 | 1.00 |
| P40 | llama 8B Q4_0 | 2 | pp512 | 111.05 | 111.59 | 1.00 |
| P40 | llama 8B Q4_0 | 4 | pp512 | 159.18 | 159.70 | 1.00 |
| P40 | llama 8B Q4_0 | 8 | pp512 | 199.42 | 201.18 | 1.01 |
| P40 | llama 8B Q4_0 | 16 | pp512 | 473.35 | 475.96 | 1.01 |
| P40 | llama 8B Q4_0 | 32 | pp512 | 577.08 | 578.60 | 1.00 |
| P40 | llama 8B Q4_0 | 64 | pp512 | 779.56 | 782.74 | 1.00 |
| P40 | llama 8B Q4_0 | 128 | pp512 | 901.20 | 900.61 | 1.00 |
| P40 | llama 8B Q4_0 | 256 | pp512 | 989.08 | 990.53 | 1.00 |
| P40 | llama 8B Q4_0 | 512 | pp512 | 1033.35 | 1034.60 | 1.00 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 1 | pp512 | 45.52 | 46.18 | 1.01 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 2 | pp512 | 68.14 | 68.68 | 1.01 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 4 | pp512 | 103.04 | 103.29 | 1.00 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 8 | pp512 | 155.04 | 155.08 | 1.00 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 16 | pp512 | 262.75 | 263.64 | 1.00 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 32 | pp512 | 334.52 | 336.49 | 1.01 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 64 | pp512 | 367.55 | 368.87 | 1.00 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 128 | pp512 | 562.84 | 565.18 | 1.00 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 256 | pp512 | 782.93 | 783.50 | 1.00 |
| P40 | qwen3moe 30B.A3B Q3_K_S | 512 | pp512 | 1003.04 | 1005.36 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 1 | pp512 | 188.47 | 189.43 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | 2 | pp512 | 334.45 | 335.15 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 4 | pp512 | 653.64 | 653.80 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 8 | pp512 | 1085.25 | 1085.87 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 16 | pp512 | 1842.49 | 1847.15 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 32 | pp512 | 3332.12 | 3351.60 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | 64 | pp512 | 5814.07 | 5821.98 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 128 | pp512 | 8733.03 | 8777.78 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | 256 | pp512 | 11730.38 | 11773.61 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 512 | pp512 | 13258.03 | 13247.95 | 1.00 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 1 | pp512 | 224.25 | 229.86 | 1.03 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 2 | pp512 | 216.71 | 219.06 | 1.01 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 4 | pp512 | 374.11 | 377.80 | 1.01 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 8 | pp512 | 611.66 | 614.01 | 1.00 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 16 | pp512 | 976.81 | 982.17 | 1.01 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 32 | pp512 | 1557.01 | 1567.95 | 1.01 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 64 | pp512 | 2546.91 | 2548.34 | 1.00 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 128 | pp512 | 3464.85 | 3469.72 | 1.00 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 256 | pp512 | 5594.30 | 5582.37 | 1.00 |
| RTX 4090 | qwen3moe 30B.A3B Q4_0 | 512 | pp512 | 7856.42 | 7843.87 | 1.00 |

By packing constants that are used together into a struct, we are less
likely to make errors.
`modulo_consts` is more fitting/descriptive
@ORippler
Contributor Author

ORippler commented Sep 4, 2025

> The benefit seems to be larger for fast GPUs:

Indeed, but only because latency-limited kernels such as rms_norm_f32 take up a larger relative share of the overall workload on fast GPUs.

> Thank you, I'll merge this once the CI is done. The ggml matrix multiplication / FlashAttention kernels need to do some integer divisions to determine which data to work on, but register pressure is also a major concern. I'll investigate the use of fastdiv for those cases (unless you want to do it).

Thanks for merging! I still think the fastdiv pattern is promising and should be applied to more kernels, as all GPUs will benefit from it perf-wise. I'll personally be available to work on this around mid-October at the earliest, as I'm going to take a leave, but I'll gladly assist however I can until then.

Regarding register pressure: you also have to factor in the registers required by the naive integer division to handle its intermediates. Playing around in godbolt (so static analysis only), I couldn't find a use case where fastdiv used more registers for the SMs I compiled SASS for.

@ORippler ORippler deleted the osimons/optimize_fused_rms_norm_f32 branch September 4, 2025 16:09
gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request Sep 4, 2025
…upport

* origin/master: (72 commits)
metal : Add template specialization for mul_mm_id w/ ne20 == 10 (ggml-org#15799)
llama : set n_outputs to 1 to avoid 0 outputs mean-pooling (ggml-org#15791)
CANN: Refactor ND to NZ workspace to be per-device (ggml-org#15763)
server: add exceed_context_size_error type (ggml-org#15780)
Document the new max GPU layers default in help (ggml-org#15771)
ggml: add ops for WAN video model (cuda && cpu) (ggml-org#15669)
CANN: Fix precision issue on 310I DUO multi-devices (ggml-org#15784)
opencl: add hs=40 to FA (ggml-org#15758)
CANN: fix acl_rstd allocation size in ggml_cann_rms_norm (ggml-org#15760)
vulkan: fix mmv subgroup16 selection (ggml-org#15775)
vulkan: don't use std::string in load_shaders, to improve compile time (ggml-org#15724)
vulkan : update ggml_vk_instance_validation_ext_available (ggml-org#15666)
ggml vulkan: add hardsigmoid and hardswish operations (ggml-org#15762)
CUDA: Optimize `rms_norm_f32` kernel and its fused variants, giving 1-6% perf E2E (ggml-org#15715)
model-conversion : fix pyright errors (ggml-org#15770)
sampling : optimize dist sampler (ggml-org#15704)
llama : fix incorrect model type for Gemma 270M (ggml-org#15764)
model-conversion : remove hardcoded /bin/bash shebangs [no ci] (ggml-org#15765)
CANN: Add RoPE contiguous check for 310I DUP device (ggml-org#15735)
ggml-cpu : optimize RVV kernels (ggml-org#15720)
...
@JohannesGaessler
Collaborator

Using fastdiv for mul_mat_vec_q and quantize_q8_1 shaves off ~1% for token generation on RTX 3090/4090: #15802

walidbr pushed a commit to walidbr/llama.cpp that referenced this pull request Sep 7, 2025
walidbr pushed a commit to walidbr/llama.cpp that referenced this pull request Sep 7, 2025
…-6% perf E2E (ggml-org#15715)

* Add fastdiv, use it in modulo and use modulo in rms_norm_f32

Fastdiv is a much faster way to do integer division, which was identified
as a bottleneck in rms_norm_f32

* Support more `block_size` values in `rms_norm_f32`

This makes us more flexible in selecting the optimal number of threads w.r.t.
parallelizing across a col vs. the launch overheads of threads and MIO
throttles

* Update ggml/src/ggml-cuda/common.cuh

Co-authored-by: Johannes Gäßler <[email protected]>

* Replace modulo with fastmodulo in `rms_norm_f32`

* Use `BinPackArguments=true` for formatting function calls

Will file a separate PR to adjust .clang-format file

* Update ggml/src/ggml-cuda/common.cuh

Co-authored-by: Johannes Gäßler <[email protected]>

* Use uint3 for both `fastdiv` and `fastmodulo`

The compiler seems to reliably optimize away the unused .z component in
the fastdiv use-case, see https://godbolt.org/z/rx8KPrKr3

* More constrained type declarations

Co-authored-by: Johannes Gäßler <[email protected]>

* Rename fastdiv and fastmodulo variables to shared variable name

As suggested by JohannesGaessler, this increases the clarity of the intended
use

* Pack fastdiv/fastmodulo constants into uint2/uint3 objects

By packing constants to be used together into a struct, we are less
likely to make errors.

* Rename function parameter of fastmodulo

`modulo_consts` is more fitting/descriptive

---------

Co-authored-by: Johannes Gäßler <[email protected]>