Skip to content

Commit 0c94c41

Browse files
committed
iq2_xxs: quantized CUDA dot product (MMVQ)
We get TG-128 = 153.1 t/s
1 parent b393dd1 commit 0c94c41

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

ggml-cuda.cu

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3950,6 +3950,35 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
39503950
return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
39513951
}
39523952

3953+
static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
3954+
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
3955+
#if QK_K == 256
3956+
const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
3957+
3958+
// iqs is 0...15
3959+
const int ib32 = iqs/2;
3960+
const int il = iqs%2;
3961+
const uint16_t * q2 = bq2->qs + 4*ib32;
3962+
const uint8_t * aux8 = (const uint8_t *)q2;
3963+
const uint8_t * grid1 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
3964+
const uint8_t * grid2 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
3965+
const uint32_t aux32 = q2[2] | (q2[3] << 16);
3966+
const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
3967+
const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
3968+
const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
3969+
const int8_t * q8 = bq8_1[ib32].qs + 16*il;
3970+
int sumi1 = 0, sumi2 = 0;
3971+
for (int j = 0; j < 8; ++j) {
3972+
sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
3973+
sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
3974+
}
3975+
return d * (sumi1 + sumi2);
3976+
#else
3977+
assert(false);
3978+
return 0.f;
3979+
#endif
3980+
}
3981+
39533982
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
39543983
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
39553984
static __device__ __forceinline__ void mul_mat_q(
@@ -6044,6 +6073,15 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
60446073
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
60456074
}
60466075

6076+
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
6077+
GGML_ASSERT(ncols % QK_K == 0);
6078+
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
6079+
const dim3 block_nums(block_num_y, 1, 1);
6080+
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
6081+
mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
6082+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
6083+
}
6084+
60476085
static void ggml_mul_mat_q4_0_q8_1_cuda(
60486086
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
60496087
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -7608,6 +7646,9 @@ static void ggml_cuda_op_mul_mat_vec_q(
76087646
case GGML_TYPE_Q6_K:
76097647
mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
76107648
break;
7649+
case GGML_TYPE_IQ2_XXS:
7650+
mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
7651+
break;
76117652
default:
76127653
GGML_ASSERT(false);
76137654
break;

0 commit comments

Comments
 (0)