Skip to content

Commit 8b1e937

Browse files
committed
Use uint3 for both fastdiv and fastmodulo
The compiler seems to reliably optimize away the unused .z component in the fastdiv use case; see https://godbolt.org/z/rx8KPrKr3.
1 parent 0a76b11 commit 8b1e937

File tree

2 files changed: 22 additions (+22) and 25 deletions (-25).

ggml/src/ggml-cuda/common.cuh

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -569,33 +569,30 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
569569
// and a shift:
570570
//
571571
// n/d = (mulhi(n, mp) + n) >> L;
572-
static const uint2 init_fastdiv_values(uint32_t d) {
572+
static const uint3 init_fastdiv_values(uint32_t d) {
573573
// compute L = ceil(log2(d));
574574
uint32_t L = 0;
575575
while (L < 32 && (uint32_t{ 1 } << L) < d) {
576576
L++;
577577
}
578578

579579
uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
580-
return make_uint2(mp, L);
580+
// also pack the divisor (in .z) so the same constants serve both fastdiv and fastmodulo; this reduces error surface
581+
return make_uint3(mp, L, d);
581582
}
582583

583-
static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint2 div_consts) {
584+
static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 div_consts) {
585+
// expects div_consts to contain <mp, L, divisor> in <x, y, z>
586+
// div_consts.z is unused and optimized away by the compiler.
584587
// Compute high 32 bits of n * mp
585588
const uint32_t hi = __umulhi(n, div_consts.x);
586-
// Apply the formula
589+
// add n, apply bit shift
587590
return (hi + n) >> div_consts.y;
588591
}
589592

590-
static const uint3 init_fastmodulo_values(uint32_t d) {
591-
// uint3 contains <mp, L, divisor> in <x, y, z>
592-
const uint2 fastdiv = init_fastdiv_values(d);
593-
return make_uint3(fastdiv.x, fastdiv.y, d);
594-
}
595-
596593
static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 modulo_consts) {
597-
// expects modulo_consts to contain <mp, L, divisor> in <x, y, z> (see init_fastmodulo_values function)
598-
return n - fastdiv(n, make_uint2(modulo_consts.x, modulo_consts.y)) * modulo_consts.z;
594+
// expects modulo_consts to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
595+
return n - fastdiv(n, modulo_consts) * modulo_consts.z;
599596
}
600597

601598
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);

ggml/src/ggml-cuda/norm.cu

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -398,10 +398,10 @@ static void rms_norm_mul_f32_cuda(const float * x,
398398
return;
399399
}
400400
if (add == nullptr) {
401-
uint3 mul_ncols_packed = init_fastmodulo_values(mul_ncols);
402-
uint3 mul_nrows_packed = init_fastmodulo_values(mul_nrows);
403-
uint3 mul_nchannels_packed = init_fastmodulo_values(mul_nchannels);
404-
uint3 mul_nsamples_packed = init_fastmodulo_values(mul_nsamples);
401+
uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
402+
uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
403+
uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
404+
uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
405405
if (ncols < 1024) {
406406
const dim3 block_dims(256, 1, 1);
407407
rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
@@ -414,15 +414,15 @@ static void rms_norm_mul_f32_cuda(const float * x,
414414
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
415415
}
416416
} else {
417-
uint3 mul_ncols_packed = init_fastmodulo_values(mul_ncols);
418-
uint3 mul_nrows_packed = init_fastmodulo_values(mul_nrows);
419-
uint3 mul_nchannels_packed = init_fastmodulo_values(mul_nchannels);
420-
uint3 mul_nsamples_packed = init_fastmodulo_values(mul_nsamples);
421-
422-
uint3 add_ncols_packed = init_fastmodulo_values(add_ncols);
423-
uint3 add_nrows_packed = init_fastmodulo_values(add_nrows);
424-
uint3 add_nchannels_packed = init_fastmodulo_values(add_nchannels);
425-
uint3 add_nsamples_packed = init_fastmodulo_values(add_nsamples);
417+
uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
418+
uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
419+
uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
420+
uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
421+
422+
uint3 add_ncols_packed = init_fastdiv_values(add_ncols);
423+
uint3 add_nrows_packed = init_fastdiv_values(add_nrows);
424+
uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels);
425+
uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
426426
if (ncols < 1024) {
427427
const dim3 block_dims(256, 1, 1);
428428
rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(

0 commit comments

Comments
 (0)