Skip to content

Commit 7414652

Browse files
committed
Pack fastdiv/fastmodulo constants into uint2/uint3 objects
By packing constants to be used together into a struct, we are less likely to make errors.
1 parent 48afab4 commit 7414652

File tree

2 files changed

+69
-95
lines changed

2 files changed

+69
-95
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -569,25 +569,33 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
569569
// and a shift:
570570
//
571571
// n/d = (mulhi(n, mp) + n) >> L;
572-
static void init_fastdiv_values(uint32_t d, uint32_t & mp, uint32_t & L) {
572+
static const uint2 init_fastdiv_values(uint32_t d) {
573573
// compute L = ceil(log2(d));
574-
L = 0;
574+
uint32_t L = 0;
575575
while (L < 32 && (uint32_t{ 1 } << L) < d) {
576576
L++;
577577
}
578578

579-
mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
579+
uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
580+
return make_uint2(mp, L);
580581
}
581582

582-
static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) {
583+
static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint2 div_consts) {
583584
// Compute high 32 bits of n * mp
584-
const uint32_t hi = __umulhi(n, mp);
585+
const uint32_t hi = __umulhi(n, div_consts.x);
585586
// Apply the formula
586-
return (hi + n) >> L;
587+
return (hi + n) >> div_consts.y;
587588
}
588589

589-
static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, uint32_t divisor, int mp, uint32_t L) {
590-
return n - fastdiv(n, mp, L) * divisor;
590+
static const uint3 init_fastmodulo_values(uint32_t d) {
591+
// uint3 contains <mp, L, divisor> in <x, y, z>
592+
const uint2 fastdiv = init_fastdiv_values(d);
593+
return make_uint3(fastdiv.x, fastdiv.y, d);
594+
}
595+
596+
static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 div_consts_divisor) {
597+
// expects div_consts_divisor to contain <mp, L, divisor> in <x, y, z>
598+
return n - fastdiv(n, make_uint2(div_consts_divisor.x, div_consts_divisor.y)) * div_consts_divisor.z;
591599
}
592600

593601
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);

ggml/src/ggml-cuda/norm.cu

Lines changed: 53 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -105,45 +105,29 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
105105
}
106106

107107
template <int block_size, bool do_multiply = false, bool do_add = false>
108-
static __global__ void rms_norm_f32(const float * x,
109-
float * dst,
110-
const int ncols,
111-
const int64_t stride_row,
112-
const int64_t stride_channel,
113-
const int64_t stride_sample,
114-
const float eps,
115-
const float * mul = nullptr,
116-
const int64_t mul_stride_row = 0,
117-
const int64_t mul_stride_channel = 0,
118-
const int64_t mul_stride_sample = 0,
119-
const uint32_t mul_ncols = 0,
120-
const uint32_t mul_nrows = 0,
121-
const uint32_t mul_nchannels = 0,
122-
const uint32_t mul_nsamples = 0,
123-
const uint32_t mp_mul_cols = 0,
124-
const uint32_t L_mul_cols = 0,
125-
const uint32_t mp_mul_rows = 0,
126-
const uint32_t L_mul_rows = 0,
127-
const uint32_t mp_mul_channels = 0,
128-
const uint32_t L_mul_channels = 0,
129-
const uint32_t mp_mul_samples = 0,
130-
const uint32_t L_mul_samples = 0,
131-
const float * add = nullptr,
132-
const int64_t add_stride_row = 0,
133-
const int64_t add_stride_channel = 0,
134-
const int64_t add_stride_sample = 0,
135-
const uint32_t add_ncols = 0,
136-
const uint32_t add_nrows = 0,
137-
const uint32_t add_nchannels = 0,
138-
const uint32_t add_nsamples = 0,
139-
const uint32_t mp_add_cols = 0,
140-
const uint32_t L_add_cols = 0,
141-
const uint32_t mp_add_rows = 0,
142-
const uint32_t L_add_rows = 0,
143-
const uint32_t mp_add_channels = 0,
144-
const uint32_t L_add_channels = 0,
145-
const uint32_t mp_add_samples = 0,
146-
const uint32_t L_add_samples = 0) {
108+
static __global__ void rms_norm_f32(const float * x,
109+
float * dst,
110+
const int ncols,
111+
const int64_t stride_row,
112+
const int64_t stride_channel,
113+
const int64_t stride_sample,
114+
const float eps,
115+
const float * mul = nullptr,
116+
const int64_t mul_stride_row = 0,
117+
const int64_t mul_stride_channel = 0,
118+
const int64_t mul_stride_sample = 0,
119+
const uint3 mul_ncols_packed = make_uint3(0, 0, 0),
120+
const uint3 mul_nrows_packed = make_uint3(0, 0, 0),
121+
const uint3 mul_nchannels_packed = make_uint3(0, 0, 0),
122+
const uint3 mul_nsamples_packed = make_uint3(0, 0, 0),
123+
const float * add = nullptr,
124+
const int64_t add_stride_row = 0,
125+
const int64_t add_stride_channel = 0,
126+
const int64_t add_stride_sample = 0,
127+
const uint3 add_ncols_packed = make_uint3(0, 0, 0),
128+
const uint3 add_nrows_packed = make_uint3(0, 0, 0),
129+
const uint3 add_nchannels_packed = make_uint3(0, 0, 0),
130+
const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) {
147131
const int nrows = gridDim.x;
148132
const int nchannels = gridDim.y;
149133

@@ -158,16 +142,16 @@ static __global__ void rms_norm_f32(const float * x,
158142
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
159143

160144
if constexpr (do_multiply) {
161-
const uint32_t mul_row = fastmodulo(row, mul_nrows, mp_mul_rows, L_mul_rows);
162-
const uint32_t mul_channel = fastmodulo(channel, mul_nchannels, mp_mul_channels, L_mul_channels);
163-
const uint32_t mul_sample = fastmodulo(sample, mul_nsamples, mp_mul_samples, L_mul_samples);
145+
const uint32_t mul_row = fastmodulo(row, mul_nrows_packed);
146+
const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
147+
const uint32_t mul_sample = fastmodulo(sample, mul_nsamples_packed);
164148
mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
165149
}
166150

167151
if constexpr (do_add) {
168-
const int add_row = fastmodulo(row, add_nrows, mp_add_rows, L_add_rows);
169-
const int add_channel = fastmodulo(channel, add_nchannels, mp_add_channels, L_add_channels);
170-
const int add_sample = fastmodulo(sample, add_nsamples, mp_add_samples, L_add_samples);
152+
const int add_row = fastmodulo(row, add_nrows_packed);
153+
const int add_channel = fastmodulo(channel, add_nchannels_packed);
154+
const int add_sample = fastmodulo(sample, add_nsamples_packed);
171155
add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
172156
}
173157

@@ -201,11 +185,11 @@ static __global__ void rms_norm_f32(const float * x,
201185

202186
for (int col = tid; col < ncols; col += block_size) {
203187
if constexpr (do_multiply && do_add) {
204-
const int mul_col = fastmodulo(col, mul_ncols, mp_mul_cols, L_mul_cols);
205-
const int add_col = fastmodulo(col, add_ncols, mp_add_cols, L_add_cols);
188+
const int mul_col = fastmodulo(col, mul_ncols_packed);
189+
const int add_col = fastmodulo(col, add_ncols_packed);
206190
dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
207191
} else if constexpr (do_multiply) {
208-
const int mul_col = fastmodulo(col, mul_ncols, mp_mul_cols, L_mul_cols);
192+
const int mul_col = fastmodulo(col, mul_ncols_packed);
209193
dst[col] = scale * x[col] * mul[mul_col];
210194
} else {
211195
dst[col] = scale * x[col];
@@ -414,63 +398,45 @@ static void rms_norm_mul_f32_cuda(const float * x,
414398
return;
415399
}
416400
if (add == nullptr) {
417-
uint32_t mp_mul_cols, L_mul_cols;
418-
init_fastdiv_values(mul_ncols, mp_mul_cols, L_mul_cols);
419-
uint32_t mp_mul_rows, L_mul_rows;
420-
init_fastdiv_values(mul_nrows, mp_mul_rows, L_mul_rows);
421-
uint32_t mp_mul_channels, L_mul_channels;
422-
init_fastdiv_values(mul_nchannels, mp_mul_channels, L_mul_channels);
423-
uint32_t mp_mul_samples, L_mul_samples;
424-
init_fastdiv_values(mul_nsamples, mp_mul_samples, L_mul_samples);
401+
uint3 mul_ncols_packed = init_fastmodulo_values(mul_ncols);
402+
uint3 mul_nrows_packed = init_fastmodulo_values(mul_nrows);
403+
uint3 mul_nchannels_packed = init_fastmodulo_values(mul_nchannels);
404+
uint3 mul_nsamples_packed = init_fastmodulo_values(mul_nsamples);
425405
if (ncols < 1024) {
426406
const dim3 block_dims(256, 1, 1);
427407
rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
428408
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
429-
mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols,
430-
mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples);
409+
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
431410
} else {
432411
const dim3 block_dims(1024, 1, 1);
433412
rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
434413
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
435-
mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols,
436-
mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples);
414+
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
437415
}
438416
} else {
439-
uint32_t mp_mul_cols, L_mul_cols;
440-
init_fastdiv_values(mul_ncols, mp_mul_cols, L_mul_cols);
441-
uint32_t mp_mul_rows, L_mul_rows;
442-
init_fastdiv_values(mul_nrows, mp_mul_rows, L_mul_rows);
443-
uint32_t mp_mul_channels, L_mul_channels;
444-
init_fastdiv_values(mul_nchannels, mp_mul_channels, L_mul_channels);
445-
uint32_t mp_mul_samples, L_mul_samples;
446-
init_fastdiv_values(mul_nsamples, mp_mul_samples, L_mul_samples);
447-
448-
uint32_t mp_add_cols, L_add_cols;
449-
init_fastdiv_values(add_ncols, mp_add_cols, L_add_cols);
450-
uint32_t mp_add_rows, L_add_rows;
451-
init_fastdiv_values(add_nrows, mp_add_rows, L_add_rows);
452-
uint32_t mp_add_channels, L_add_channels;
453-
init_fastdiv_values(add_nchannels, mp_add_channels, L_add_channels);
454-
uint32_t mp_add_samples, L_add_samples;
455-
init_fastdiv_values(add_nsamples, mp_add_samples, L_add_samples);
417+
uint3 mul_ncols_packed = init_fastmodulo_values(mul_ncols);
418+
uint3 mul_nrows_packed = init_fastmodulo_values(mul_nrows);
419+
uint3 mul_nchannels_packed = init_fastmodulo_values(mul_nchannels);
420+
uint3 mul_nsamples_packed = init_fastmodulo_values(mul_nsamples);
421+
422+
uint3 add_ncols_packed = init_fastmodulo_values(add_ncols);
423+
uint3 add_nrows_packed = init_fastmodulo_values(add_nrows);
424+
uint3 add_nchannels_packed = init_fastmodulo_values(add_nchannels);
425+
uint3 add_nsamples_packed = init_fastmodulo_values(add_nsamples);
456426
if (ncols < 1024) {
457427
const dim3 block_dims(256, 1, 1);
458428
rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
459429
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
460-
mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols,
461-
mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples, add,
462-
add_stride_row, add_stride_channel, add_stride_sample, add_ncols, add_nrows, add_nchannels,
463-
add_nsamples, mp_add_cols, L_add_cols, mp_add_rows, L_add_rows, mp_add_channels, L_add_channels,
464-
mp_add_samples, L_add_samples);
430+
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
431+
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
432+
add_nchannels_packed, add_nsamples_packed);
465433
} else {
466434
const dim3 block_dims(1024, 1, 1);
467435
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
468436
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
469-
mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, mp_mul_cols, L_mul_cols,
470-
mp_mul_rows, L_mul_rows, mp_mul_channels, L_mul_channels, mp_mul_samples, L_mul_samples, add,
471-
add_stride_row, add_stride_channel, add_stride_sample, add_ncols, add_nrows, add_nchannels,
472-
add_nsamples, mp_add_cols, L_add_cols, mp_add_rows, L_add_rows, mp_add_channels, L_add_channels,
473-
mp_add_samples, L_add_samples);
437+
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
438+
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
439+
add_nchannels_packed, add_nsamples_packed);
474440
}
475441
}
476442
}

0 commit comments

Comments
 (0)