@@ -569,33 +569,30 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
569
569
// and a shift:
570
570
//
571
571
// n/d = (mulhi(n, mp) + n) >> L;
572
- static const uint2 init_fastdiv_values (uint32_t d) {
572
+ static const uint3 init_fastdiv_values (uint32_t d) {
573
573
// compute L = ceil(log2(d));
574
574
uint32_t L = 0 ;
575
575
while (L < 32 && (uint32_t { 1 } << L) < d) {
576
576
L++;
577
577
}
578
578
579
579
uint32_t mp = (uint32_t ) ((uint64_t { 1 } << 32 ) * ((uint64_t { 1 } << L) - d) / d + 1 );
580
- return make_uint2 (mp, L);
580
+ // pack divisor as well to reduce error surface
581
+ return make_uint3 (mp, L, d);
581
582
}
582
583
583
- static __device__ __forceinline__ uint32_t fastdiv (uint32_t n, const uint2 div_consts) {
584
+ static __device__ __forceinline__ uint32_t fastdiv (uint32_t n, const uint3 div_consts) {
585
+ // expects div_consts to contain <mp, L, divisor> in <x, y, z>
586
+ // div_consts.z is unused and optimized away by the compiler.
584
587
// Compute high 32 bits of n * mp
585
588
const uint32_t hi = __umulhi (n, div_consts.x );
586
- // Apply the formula
589
+ // add n, apply bit shift
587
590
return (hi + n) >> div_consts.y ;
588
591
}
589
592
590
- static const uint3 init_fastmodulo_values (uint32_t d) {
591
- // uint3 contains <mp, L, divisor> in <x, y, z>
592
- const uint2 fastdiv = init_fastdiv_values (d);
593
- return make_uint3 (fastdiv.x , fastdiv.y , d);
594
- }
595
-
596
593
static __device__ __forceinline__ uint32_t fastmodulo (uint32_t n, const uint3 modulo_consts) {
597
- // expects modulo_consts to contain <mp, L, divisor> in <x, y, z> (see init_fastmodulo_values function )
598
- return n - fastdiv (n, make_uint2 ( modulo_consts. x , modulo_consts. y ) ) * modulo_consts.z ;
594
+ // expects modulo_consts to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values )
595
+ return n - fastdiv (n, modulo_consts) * modulo_consts.z ;
599
596
}
600
597
601
598
typedef void (*dequantize_kernel_t )(const void * vx, const int64_t ib, const int iqs, float2 & v);
0 commit comments