CUDA: Optimize reduce_rows_f32 kernel, leading up to 25x perf improvement on kernel-level and 10% perf increase for Gemma3n (#15132)
Changes from all commits: 3deb3b1, c270ffe, ece608a, 9070af8, 80de672, 8e04242, 8fc2c03, 9296d1f, a6fe4dd, 4a1c5bc
@@ -0,0 +1,53 @@

```cuda
#include "common.cuh"

// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
template <bool norm>
static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
    const int row = blockIdx.x;
    const int col = threadIdx.x;

    float sum = 0.0f;
    const int num_unroll = 8;
    float temp[num_unroll];
    float sum_temp[num_unroll] = { 0.0f };
    for (int i = col; i < ncols;) {
        for (int j = 0; j < num_unroll; ++j) {
            if (i < ncols) {
                temp[j] = x[row * ncols + i];
            } else {
                temp[j] = 0;
            }
            i += blockDim.x;
        }
        for (int j = 0; j < num_unroll; ++j) {
            sum_temp[j] += temp[j];
        }
    }
    for (int j = 0; j < num_unroll; ++j) {
        sum += sum_temp[j];
    }

    // sum up partial sums
    sum = warp_reduce_sum(sum);
    if (blockDim.x > WARP_SIZE) {
        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
        __shared__ float s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = sum;
        }
        __syncthreads();
        sum = 0.0f;
        if (lane_id < (blockDim.x / WARP_SIZE)) {
            sum = s_sum[lane_id];
        }
        sum = warp_reduce_sum(sum);
    }

    if (col != 0) {
        return;
    }

    dst[row] = norm ? sum / ncols : sum;
}
```

Review thread on the `if (i < ncols)` bounds check inside the unrolled inner loop:

Reviewer: My intuition would have been that it is faster not to add the inner loop, due to this conditional statement. Just to be sure: did you test both versions?

Author: We shared that intuition and, as mentioned in the PR description, one of the first things we tried was hinting the compiler to unroll the outer loop with `#pragma unroll` (screenshots omitted here). Only by explicitly unrolling the loop did we get the compiler to comply and prefetch the data, effectively hiding the memory latency (see the 8 sequential …).

Reviewer: I mean, the reason the outer loop cannot be unrolled is simply that the number of iterations isn't known at compile time, right? The inner loop has a fixed size and can therefore be unrolled.

Author: In this case, the only prerequisite for loop unrolling followed by instruction reordering is an unaliased pointer, which we declare via `__restrict__`.
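For context on the exchange above, here is a minimal sketch (not part of the PR; the name `reduce_rows_f32_plain` is hypothetical) of the simpler, non-unrolled variant the thread compares against: a plain strided accumulation loop carrying only an unroll hint. Because the trip count depends on the runtime values `ncols` and `blockDim.x`, the hint alone did not get the compiler to batch the loads in the authors' tests, which is what motivated the explicit 8-wide unroll above.

```cuda
// Sketch for comparison only (not from this PR): strided accumulation with just an
// unroll hint. The loop bound depends on runtime values, so the compiler cannot fully
// unroll it and keeps one load per iteration instead of batching/prefetching loads.
// Assumes a single-warp block (blockDim.x == WARP_SIZE); for larger blocks the
// shared-memory block reduction from the kernel above would be needed as well.
template <bool norm>
static __global__ void reduce_rows_f32_plain(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
    const int row = blockIdx.x;
    const int col = threadIdx.x;

    float sum = 0.0f;
#pragma unroll
    for (int i = col; i < ncols; i += blockDim.x) {
        sum += x[row * ncols + i];
    }

    // Reduce the per-thread partial sums within the warp (helper from common.cuh).
    sum = warp_reduce_sum(sum);

    if (col == 0) {
        dst[row] = norm ? sum / ncols : sum;
    }
}
```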
Review comment on the kernel as a whole:

Reviewer: I would have written this kernel differently: I would have made the CUDA block size a template parameter and increased it as long as it reduces the number of iterations needed (as is done in e.g. mmv.cu / mmvf.cu).

Author: Templating the kernel also crossed our minds (see the general PR description). However, templating would have led to an increased size of the generated binaries and was thus not our preferred option, given that it did not yield significant speed-ups in internal tests.
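For illustration, a rough sketch of what the suggested block-size-templated variant could look like; this is not code from the PR, the name `reduce_rows_f32_tpl` is hypothetical, and `warp_reduce_sum` / `WARP_SIZE` are assumed to come from common.cuh as above. A compile-time `block_size` lets the compiler size shared memory exactly and unroll the strided loop, at the cost of one kernel instantiation per supported block size, which is the binary-size trade-off the author mentions.

```cuda
// Rough sketch (not from this PR) with the block size as a template parameter, in the
// spirit of mmv.cu/mmvf.cu. Assumes the kernel is launched with exactly block_size
// threads per block and that block_size is a multiple of WARP_SIZE.
template <bool norm, int block_size>
static __global__ void reduce_rows_f32_tpl(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
    const int row = blockIdx.x;
    const int col = threadIdx.x;

    float sum = 0.0f;
    for (int i = col; i < ncols; i += block_size) { // compile-time stride
        sum += x[row * ncols + i];
    }

    sum = warp_reduce_sum(sum);
    if (block_size > WARP_SIZE) {
        // Shared memory can be sized exactly because block_size is known at compile time.
        __shared__ float s_sum[block_size / WARP_SIZE];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = sum;
        }
        __syncthreads();
        sum = lane_id < block_size / WARP_SIZE ? s_sum[lane_id] : 0.0f;
        sum = warp_reduce_sum(sum);
    }

    if (col == 0) {
        dst[row] = norm ? sum / ncols : sum;
    }
}

// On the host side, a hypothetical launcher would then dispatch one instantiation per
// supported block size, e.g.:
//   switch (block_size) {
//       case 128: reduce_rows_f32_tpl<norm, 128><<<nrows, 128, 0, stream>>>(x, dst, ncols); break;
//       case 256: reduce_rows_f32_tpl<norm, 256><<<nrows, 256, 0, stream>>>(x, dst, ncols); break;
//       // ...
//   }
```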