CUDA: Optimize reduce_rows_f32 kernel, leading up to 25x perf improvement on kernel-level and 10% perf increase for Gemma3n #15132

Open · wants to merge 10 commits into master
20 changes: 0 additions & 20 deletions ggml/src/ggml-cuda/common.cuh
@@ -412,26 +412,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#endif // FP16_AVAILABLE
}

// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
template<bool norm>
static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;

float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
}

sum = warp_reduce_sum(sum);

if (col != 0) {
return;
}

dst[row] = norm ? sum / ncols : sum;
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_all(int x) {
#ifdef GGML_USE_HIP
13 changes: 11 additions & 2 deletions ggml/src/ggml-cuda/mean.cu
@@ -1,4 +1,5 @@
#include "mean.cuh"
#include "reduce_rows.cuh"

void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
@@ -13,7 +14,15 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);

const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);

const int id = ggml_cuda_get_device();
const int nsm = ggml_cuda_info().devices[id].nsm;
if ((nrows / nsm) < 2) {
const dim3 block_dims(512, 1, 1);
reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
} else {
const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
}
}
53 changes: 53 additions & 0 deletions ggml/src/ggml-cuda/reduce_rows.cuh
Collaborator:

I would have written this kernel differently. I would have made the CUDA block size a template parameter and increased it as long as it reduces the number of iterations needed (as is done in e.g. mmv.cu/mmvf.cu).

Contributor Author:

Templating the kernel also crossed our minds (see the general PR description). However, templating would have led to larger generated binaries and was therefore not our preferred option, given that it did not yield significant speed-ups in internal tests.
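
For context, a minimal sketch of the templated-block-size variant discussed in this thread might look roughly as follows (hypothetical, not part of this PR; it assumes `WARP_SIZE` and `warp_reduce_sum` from common.cuh and a block size that is a multiple of `WARP_SIZE`):

```cuda
// Hypothetical sketch of the reviewer's suggestion: the block size as a
// template parameter, so the stride is known at compile time. Not part of this PR.
template <bool norm, int block_size>
static __global__ void reduce_rows_f32_tpl(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
    const int row = blockIdx.x;

    float sum = 0.0f;
    for (int i = threadIdx.x; i < ncols; i += block_size) {
        sum += x[row * ncols + i];
    }

    // Reduce within each warp, then across warps via shared memory.
    sum = warp_reduce_sum(sum);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[block_size / WARP_SIZE];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = sum;
        }
        __syncthreads();
        sum = lane_id < block_size / WARP_SIZE ? s_sum[lane_id] : 0.0f;
        sum = warp_reduce_sum(sum);
    }

    if (threadIdx.x == 0) {
        dst[row] = norm ? sum / ncols : sum;
    }
}

// The host side would then pick the largest instantiation that still reduces
// the number of loop iterations, e.g.:
//   if      (ncols > 256) { reduce_rows_f32_tpl<norm, 512><<<nrows, 512, 0, stream>>>(x, dst, ncols); }
//   else if (ncols >  64) { reduce_rows_f32_tpl<norm, 128><<<nrows, 128, 0, stream>>>(x, dst, ncols); }
//   else                  { reduce_rows_f32_tpl<norm,  32><<<nrows,  32, 0, stream>>>(x, dst, ncols); }
```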

@@ -0,0 +1,53 @@
#include "common.cuh"

// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
template <bool norm>
static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;

float sum = 0.0f;
const int num_unroll = 8;
float temp[num_unroll];
float sum_temp[num_unroll] = { 0.0f };
for (int i = col; i < ncols;) {
for (int j = 0; j < num_unroll; ++j) {
if (i < ncols) {
Collaborator:

My intuition would have been that it is faster not to add the inner loop due to this conditional statement. Just to be sure: did you test both versions?

Contributor Author (@ORippler, Aug 7, 2025):

> My intuition would have been that it is faster not to add the inner loop due to this conditional statement. Just to be sure: did you test both versions?

We shared that intuition and, as mentioned in the PR description, one of the first things we tried was hinting the compiler to unroll the outer loop with #pragma unroll. Unfortunately, the compiler did not comply, and we were still seeing a lot of long scoreboard stalls caused by sequential iteration through the for loop (see the following two screenshots).

[Screenshot 2025-08-07 13:28:59: profiler output showing the long scoreboard stalls before manual unrolling]

Only by explicitly unrolling the loop did we get the compiler to comply and prefetch the data, effectively hiding the memory latency (see the 8 sequential FADDs preceded by 8 sequential LDGs):
[Screenshot 2025-08-07 13:28:39: SASS after manual unrolling, with 8 sequential LDGs followed by 8 FADDs]
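
For reference, the earlier `#pragma unroll` attempt described above presumably looked something like this (reconstructed for illustration only; it assumes `warp_reduce_sum` from common.cuh and one warp per block, as in the pre-PR kernel):

```cuda
// Reconstruction of the rejected #pragma unroll attempt (illustrative only):
// per the discussion above, the hint alone did not get nvcc to batch the loads,
// so the long scoreboard stalls remained.
template <bool norm>
static __global__ void reduce_rows_f32_hint(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
    const int row = blockIdx.x;
    const int col = threadIdx.x;

    float sum = 0.0f;
#pragma unroll 8
    for (int i = col; i < ncols; i += blockDim.x) {
        sum += x[row * ncols + i];
    }

    sum = warp_reduce_sum(sum);
    if (col == 0) {
        dst[row] = norm ? sum / ncols : sum;
    }
}
```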

Collaborator:

I mean, the reason the outer loop cannot be unrolled is simply that the number of iterations isn't known at compile time, right? The inner loop has a fixed size and can therefore be unrolled.

Contributor Author:

> I mean, the reason the outer loop cannot be unrolled is simply that the number of iterations isn't known at compile time, right? The inner loop has a fixed size and can therefore be unrolled.

In this case, the only prerequisite for loop unrolling followed by instruction reordering is an unaliased pointer, which we declare via __restrict__. nvcc did unroll the loop, but it did not reorder the instructions or batch the LDGs. We manually nudged it in the right direction by unrolling the loop ourselves, which makes the path to optimize clearer to the compiler. Knowing the number of iterations at compile time is another example of such a nudge 😃

temp[j] = x[row * ncols + i];
} else {
temp[j] = 0;
}
i += blockDim.x;
}
for (int j = 0; j < num_unroll; ++j) {
sum_temp[j] += temp[j];
}
}
for (int j = 0; j < num_unroll; ++j) {
sum += sum_temp[j];
}

// sum up partial sums
sum = warp_reduce_sum(sum);
if (blockDim.x > WARP_SIZE) {
assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = sum;
}
__syncthreads();
sum = 0.0f;
if (lane_id < (blockDim.x / WARP_SIZE)) {
sum = s_sum[lane_id];
}
sum = warp_reduce_sum(sum);
}

if (col != 0) {
return;
}

dst[row] = norm ? sum / ncols : sum;
}
25 changes: 21 additions & 4 deletions ggml/src/ggml-cuda/sumrows.cu
@@ -1,9 +1,17 @@
#include "reduce_rows.cuh"
#include "sumrows.cuh"

void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const int id = ggml_cuda_get_device();
const int nsm = ggml_cuda_info().devices[id].nsm;
const dim3 block_nums(nrows, 1, 1);
reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
if ((nrows / nsm) < 2) {
const dim3 block_dims(512, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
} else {
const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}
}

void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -19,8 +27,17 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);

const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);

reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
const int id = ggml_cuda_get_device();
const int nsm = ggml_cuda_info().devices[id].nsm;
if ((nrows / nsm) < 2) {
// Increase num threads to 512 for small nrows to better hide the latency
const dim3 block_dims(512, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
} else {
// Enough active SMs to hide latency, use smaller blocks to allow better scheduling
const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
}
}
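
For illustration, the launch heuristic now duplicated across mean.cu and sumrows.cu could be factored into a small shared helper. The sketch below is hypothetical (the helper name is invented here), using `ggml_cuda_get_device` and `ggml_cuda_info` exactly as the diff above does:

```cuda
// Hypothetical helper, not part of this PR: centralizes the block-size choice
// used in ggml_cuda_op_mean, sum_rows_f32_cuda and ggml_cuda_op_sum_rows.
static dim3 reduce_rows_block_dims(const int64_t nrows, const int64_t ncols) {
    const int id  = ggml_cuda_get_device();
    const int nsm = ggml_cuda_info().devices[id].nsm;
    if (nrows / nsm < 2) {
        // Few rows per SM: 512 threads per block to better hide memory latency.
        return dim3(512, 1, 1);
    }
    // Enough rows to keep all SMs busy: smaller blocks allow better scheduling.
    return dim3(ncols < 1024 ? 32 : 128, 1, 1);
}

// Usage would mirror the call sites above, e.g.:
//   const dim3 block_dims = reduce_rows_block_dims(nrows, ncols);
//   reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
```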
20 changes: 20 additions & 0 deletions tests/test-backend-ops.cpp
@@ -5998,6 +5998,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_sum());
test_cases.emplace_back(new test_sum_rows());
test_cases.emplace_back(new test_mean());
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
test_cases.emplace_back(new test_acc());
@@ -6179,6 +6187,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 32, 4, n_token));
}

std::vector<std::array<int64_t, 4>> reduce_rows_cases = {
{ 8192, 1, 1, 1 },
{ 8192, 8192, 1, 1 },
{ 128, 8192, 1, 1 },
};

for (auto it: reduce_rows_cases){
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, it));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, it));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, it));
}

return test_cases;
}
