@@ -1636,6 +1636,37 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
     return vec_dot_q8_0_q8_1_impl(vi, ui, bq8_0->d, bq8_1->ds);
 }
 
+static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int8_t ** x_sc) {
+
+    __shared__ int   tile_x_qs[(2*WARP_SIZE) * (WARP_SIZE + 1)];
+    __shared__ half2 tile_x_d[(2*WARP_SIZE) * (WARP_SIZE/QI8_0)];
+
+    *x_ql = tile_x_qs;
+    *x_dm = tile_x_d;
+}
+
+static __device__ __forceinline__ void load_tiles_q8_0(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int8_t * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
+
+    const int kbx  = k / QI8_0;
+    const int kqsx = k % QI8_0;
+
+    const block_q8_0 * bx = ((block_q8_0 *) vx) + i*blocks_per_row + kbx;
+
+    x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bx->qs, kqsx);
+    x_dm[i * (WARP_SIZE / QI8_0) + kbx].x = bx->d;
+}
+
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int8_t * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+    return vec_dot_q8_0_q8_1_impl(
+        x_ql[i * (WARP_SIZE + 1) + k], y_qs[j*WARP_SIZE + k],
+        x_dm[i * (WARP_SIZE/QI8_0) + k/QI8_0].x, y_ds[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
+}
+
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
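Aside (not part of the patch): the x tile above uses a row stride of WARP_SIZE + 1 ints. Padding each tile row by one element is the usual trick for avoiding shared-memory bank conflicts when a warp walks through the tile with a row-sized stride; that is presumably the intent here as well. A minimal host-side sketch, assuming 32 banks of 4-byte words as on current NVIDIA GPUs:

// Standalone illustration: bank index of tile[i][k] for several consecutive rows i,
// with and without the +1 padding. With pad = 0 every row maps to the same bank
// (a conflict if threads walk down a column); with pad = 1 the rows spread out.
#include <cstdio>

int main() {
    const int WARP_SIZE = 32;   // assumed warp size / number of shared-memory banks
    const int k = 5;            // arbitrary column within the tile

    for (int pad = 0; pad <= 1; ++pad) {
        const int row_stride = WARP_SIZE + pad;
        printf("pad=%d banks:", pad);
        for (int i = 0; i < 8; ++i) {
            printf(" %2d", (i*row_stride + k) % 32); // bank of the 4-byte word tile[i][k]
        }
        printf("\n");
    }
    return 0;
}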
@@ -1849,7 +1880,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
-template <int qk, int qi, typename block_q_t,
+template <int qk, int qr, int qi, typename block_q_t,
           allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __global__ void mul_mat_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
@@ -1880,8 +1911,8 @@ static __global__ void mul_mat_q(
 
     allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
 
-    __shared__ int   tile_y_qs[(WARP_SIZE) * (2*WARP_SIZE)];
-    __shared__ half2 tile_y_ds[(WARP_SIZE) * (2*WARP_SIZE/QI8_1)];
+    __shared__ int   tile_y_qs[(WARP_SIZE) * (qr*WARP_SIZE)];
+    __shared__ half2 tile_y_ds[(WARP_SIZE) * (qr*WARP_SIZE/QI8_1)];
 
     float sum[2][4] = {0.0f};
 
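Aside (not part of the patch): the new qr template parameter lets the y tiles shrink when the x type only needs one q8_1 block per tile column. In ggml, QR8_0 is 1 while QR4_0, QR4_1, QR5_0 and QR5_1 are 2, so the previously hardcoded 2*WARP_SIZE factor was oversized for q8_0. A rough sketch of the resulting shared-memory footprint, assuming WARP_SIZE = 32 and QI8_1 = 8 as in ggml-cuda.cu:

// Standalone illustration: bytes of shared memory used by the y tiles for qr = 1
// (q8_0) versus qr = 2 (q4_0/q4_1/q5_0/q5_1), with sizeof(int) = sizeof(half2) = 4.
#include <cstdio>

int main() {
    const int WARP_SIZE = 32;
    const int QI8_1     = 8;

    for (int qr = 1; qr <= 2; ++qr) {
        const size_t qs_bytes = 4u * WARP_SIZE * (qr*WARP_SIZE);        // tile_y_qs
        const size_t ds_bytes = 4u * WARP_SIZE * (qr*WARP_SIZE/QI8_1);  // tile_y_ds
        printf("qr=%d: tile_y_qs = %zu bytes, tile_y_ds = %zu bytes\n", qr, qs_bytes, ds_bytes);
    }
    return 0;
}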
@@ -1892,22 +1923,20 @@ static __global__ void mul_mat_q(
                        i + tid_y, tid_x, blocks_per_row);
         }
 
-        const int iby0 = tid_x / QI8_1;
-        const int iby1 = iby0 + WARP_SIZE / QI8_1;
         const int iqsy = sizeof(int) * (tid_x % QI8_1);
 
-        for (int i = 0; i < WARP_SIZE; i += 8) {
-            const int col_y_eff = min(col_y_0 + tid_y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
+        for (int ir = 0; ir < qr; ++ir) {
+            const int kqs = ir*WARP_SIZE + tid_x;
+            const int kby = kqs / QI8_1;
 
-            const block_q8_1 * __restrict__ by0 = &y[col_y_eff*blocks_per_row + ib0 + iby0];
+            for (int i = 0; i < WARP_SIZE; i += 8) {
+                const int col_y_eff = min(col_y_0 + tid_y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
 
-            tile_y_qs[(tid_y + i) * (2*WARP_SIZE) + tid_x] = *((int *) &by0->qs[iqsy]);
-            tile_y_ds[(tid_y + i) * (2*WARP_SIZE/QI8_1) + iby0] = by0->ds;
+                const block_q8_1 * by0 = &y[col_y_eff*blocks_per_row + ib0 + kby];
 
-            const block_q8_1 * __restrict__ by1 = &y[col_y_eff*blocks_per_row + ib0 + iby1];
-
-            tile_y_qs[(tid_y + i) * (2*WARP_SIZE) + tid_x + WARP_SIZE] = *((int *) &by1->qs[iqsy]);
-            tile_y_ds[(tid_y + i) * (2*WARP_SIZE/QI8_1) + iby1] = by1->ds;
+                tile_y_qs[(tid_y + i) * (qr*WARP_SIZE) + kqs] = *((int *) &by0->qs[iqsy]);
+                tile_y_ds[(tid_y + i) * (qr*WARP_SIZE/QI8_1) + kby] = by0->ds;
+            }
         }
 
         __syncthreads();
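Aside (not part of the patch): for qr = 2 the rewritten load loop reproduces the old indexing exactly, with ir = 0 covering the old iby0 path and ir = 1 the old iby1 / tid_x + WARP_SIZE path; for qr = 1 only the first WARP_SIZE columns of the y tile are filled. A host-side sketch of the index mapping, assuming WARP_SIZE = 32 and QI8_1 = 8:

// Standalone illustration: the (ir, tid_x) -> (kqs, kby) mapping of the new loop,
// compared against the old iby0/iby1 indices, for a few sample lanes.
#include <cstdio>

int main() {
    const int WARP_SIZE = 32, QI8_1 = 8;
    const int qr = 2;                                  // q4_0/q4_1/q5_0/q5_1 case

    for (int tid_x = 0; tid_x < WARP_SIZE; tid_x += 9) {
        const int iby0 = tid_x / QI8_1;                // old code, first y block
        const int iby1 = iby0 + WARP_SIZE / QI8_1;     // old code, second y block
        for (int ir = 0; ir < qr; ++ir) {
            const int kqs = ir*WARP_SIZE + tid_x;      // new code: column in tile_y_qs
            const int kby = kqs / QI8_1;               // new code: q8_1 block index
            printf("tid_x=%2d ir=%d -> kqs=%2d kby=%d (old: %d)\n",
                   tid_x, ir, kqs, kby, ir == 0 ? iby0 : iby1);
        }
    }
    return 0;
}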
@@ -2633,31 +2662,39 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(const void * vx, const void * vy, float
     const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(block_num_x, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-    mul_mat_q<QK4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0, vec_dot_q4_0_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+    mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0, vec_dot_q4_0_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
 }
 
 static void ggml_mul_mat_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
     const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
     const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(block_num_x, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-    mul_mat_q<QK4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1, vec_dot_q4_1_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+    mul_mat_q<QK4_1, QR4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1, vec_dot_q4_1_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
 }
 
 static void ggml_mul_mat_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
     const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
     const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(block_num_x, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-    mul_mat_q<QK5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0, vec_dot_q5_0_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+    mul_mat_q<QK5_0, QR5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0, vec_dot_q5_0_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
 }
 
 static void ggml_mul_mat_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
     const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
     const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(block_num_x, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-    mul_mat_q<QK5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1, vec_dot_q5_1_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+    mul_mat_q<QK5_1, QR5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1, vec_dot_q5_1_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
+    const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+    mul_mat_q<QK8_0, QR8_0, QI8_0, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0, vec_dot_q8_0_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
 }
 
 static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
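Aside (not part of the patch): the new q8_0 launcher uses the same launch geometry as the existing ones, so each block covers 2*WARP_SIZE rows of x and WARP_SIZE columns of y. A sketch of the resulting grid for hypothetical shapes (4096-row weight matrix, 512-column activation batch):

// Standalone illustration: grid/block dimensions produced by the sizing code in the
// launchers above, for the hypothetical shapes nrows_x = 4096 and ncols_y = 512.
#include <cstdio>

int main() {
    const int WARP_SIZE = 32;
    const int nrows_x = 4096, ncols_y = 512;            // hypothetical example shapes

    const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE); // 64 rows per block
    const int block_num_y = (ncols_y +   WARP_SIZE - 1) /    WARP_SIZE;  // 32 cols per block
    printf("grid = (%d, %d, 1), block = (%d, %d, 1)\n",
           block_num_x, block_num_y, WARP_SIZE, WARP_SIZE/4);            // (64, 16, 1), (32, 8, 1)
    return 0;
}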
@@ -3123,6 +3160,9 @@ inline void ggml_cuda_op_mul_mat_q(
         case GGML_TYPE_Q5_1:
             ggml_mul_mat_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
             break;
+        case GGML_TYPE_Q8_0:
+            ggml_mul_mat_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
+            break;
         default:
             GGML_ASSERT(false);
             break;
@@ -3873,7 +3913,7 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
         ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
     } else {
         if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 ||
-            src0->type == GGML_TYPE_Q5_1) {
+            src0->type == GGML_TYPE_Q5_1 || src0->type == GGML_TYPE_Q8_0) {
             ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
         } else {
             ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);