@@ -3950,6 +3950,35 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
     return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
 }
 
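+// dot product of one 16-weight slice of an iq2_xxs block against the matching
+// slice of a q8_1 block; block_iq2_xxs stores an fp16 super-block scale d
+// followed by qs[QK_K/8] uint16 words, four words per group of 32 weights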
+static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if QK_K == 256
+    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
+
+    // iqs is 0...15
+    const int ib32 = iqs/2;
+    const int il = iqs%2;
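+    // the first two words of the group hold four byte-sized indices into the
+    // kgrid_iq2xxs codebook (8 weight magnitudes per entry); the last two pack
+    // four 7-bit sign-pattern indices plus a 4-bit group scale in the top bits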
+    const uint16_t * q2 = bq2->qs + 4*ib32;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const uint8_t  * grid1 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
+    const uint8_t  * grid2 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
+    const uint32_t aux32 = q2[2] | (q2[3] << 16);
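+    // fold the fp16 block scale, the 4-bit group scale and the q8_1 activation
+    // scale into one factor; 0.25f compensates for the grid magnitudes being
+    // stored at 4x scale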
+    const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
+    const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+    const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
+    const int8_t * q8 = bq8_1[ib32].qs + 16*il;
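+    // integer dot product over the two 8-weight halves; kmask_iq2xs picks the
+    // per-weight bit out of the sign pattern to flip the product's sign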
+    int sumi1 = 0, sumi2 = 0;
+    for (int j = 0; j < 8; ++j) {
+        sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
+        sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
+    }
+    return d * (sumi1 + sumi2);
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
           allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __device__ __forceinline__ void mul_mat_q(
@@ -6044,6 +6073,15 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
+static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
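+    // same launch shape as the other mul_mat_vec_*_q8_1_cuda helpers: one
+    // block of WARP_SIZE x GGML_CUDA_MMV_Y threads per GGML_CUDA_MMV_Y rows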
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
 static void ggml_mul_mat_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -7608,6 +7646,9 @@ static void ggml_cuda_op_mul_mat_vec_q(
         case GGML_TYPE_Q6_K:
             mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
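+        // 2.0625 bpw IQ2_XXS reuses the generic mul_mat_vec_q kernel with its
+        // own vec_dot callback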
+        case GGML_TYPE_IQ2_XXS:
+            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;