Skip to content

Commit cf0a505

Browse files
q8_0 works
1 parent 0c053fd commit cf0a505

File tree

1 file changed

+59
-19
lines changed

1 file changed

+59
-19
lines changed

ggml-cuda.cu

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1636,6 +1636,37 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
16361636
return vec_dot_q8_0_q8_1_impl(vi, ui, bq8_0->d, bq8_1->ds);
16371637
}
16381638

1639+
// Hands out pointers to the per-thread-block shared-memory tiles used by the
// mul_mat_q kernel for q8_0 data. For q8_0 only the quantized-value tile
// (*x_ql) and the scale tile (*x_dm) are needed; x_qh and x_sc exist to match
// the common allocate_tiles interface and are intentionally left unset.
static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int8_t ** x_sc) {

    // Inner dimension padded by +1 int so that column-wise accesses from a warp
    // hit different shared-memory banks (classic bank-conflict padding).
    __shared__ int tile_x_qs[(2*WARP_SIZE) * (WARP_SIZE + 1)];
    // One half2 per QI8_0-sized group of ints; only the .x component is written
    // by load_tiles_q8_0 / read by vec_dot_q8_0_q8_1_mul_mat.
    __shared__ half2 tile_x_d[(2*WARP_SIZE) * (WARP_SIZE/QI8_0)];

    *x_ql = tile_x_qs;
    *x_dm = tile_x_d;
}
1647+
1648+
// Loads one packed int of q8_0 quants plus the matching block scale from
// global memory (vx, row i, position k) into the shared-memory tiles.
// x_qh and x_sc are unused for q8_0 but keep the common load_tiles signature.
static __device__ __forceinline__ void load_tiles_q8_0(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int8_t * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {

    const int ib  = k / QI8_0; // which q8_0 block within the row
    const int iqs = k % QI8_0; // which int within that block

    const block_q8_0 * bxi = (const block_q8_0 *) vx + i*blocks_per_row + ib;

    // quant tile row is padded to WARP_SIZE + 1 ints (bank-conflict avoidance)
    x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, iqs);
    // only .x of the half2 is used for q8_0 scales
    x_dm[i * (WARP_SIZE / QI8_0) + ib].x = bxi->d;
}
1660+
1661+
// Tile-based dot-product step for q8_0 x q8_1: reads one packed int from the
// x tile (row i) and one from the y tile (column j) at position k, plus the
// corresponding scales, and defers to vec_dot_q8_0_q8_1_impl.
static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int8_t * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {

    const int kx = i * (WARP_SIZE + 1) + k; // x tile rows are padded by +1 int
    const int ky = j*WARP_SIZE + k;         // y tile rows are qr*WARP_SIZE wide; qr == 1 for q8_0

    return vec_dot_q8_0_q8_1_impl(
        x_ql[kx], y_qs[ky],
        x_dm[i * (WARP_SIZE/QI8_0) + k/QI8_0].x, y_ds[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
}
1669+
16391670
static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
16401671
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
16411672

@@ -1849,7 +1880,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
18491880
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
18501881
}
18511882

1852-
template <int qk, int qi, typename block_q_t,
1883+
template <int qk, int qr, int qi, typename block_q_t,
18531884
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, vec_dot_q_mul_mat_cuda_t vec_dot>
18541885
static __global__ void mul_mat_q(
18551886
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
@@ -1880,8 +1911,8 @@ static __global__ void mul_mat_q(
18801911

18811912
allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
18821913

1883-
__shared__ int tile_y_qs[(WARP_SIZE) * (2*WARP_SIZE)];
1884-
__shared__ half2 tile_y_ds[(WARP_SIZE) * (2*WARP_SIZE/QI8_1)];
1914+
__shared__ int tile_y_qs[(WARP_SIZE) * (qr*WARP_SIZE)];
1915+
__shared__ half2 tile_y_ds[(WARP_SIZE) * (qr*WARP_SIZE/QI8_1)];
18851916

18861917
float sum[2][4] = {0.0f};
18871918

@@ -1892,22 +1923,20 @@ static __global__ void mul_mat_q(
18921923
i + tid_y, tid_x, blocks_per_row);
18931924
}
18941925

1895-
const int iby0 = tid_x / QI8_1;
1896-
const int iby1 = iby0 + WARP_SIZE / QI8_1;
18971926
const int iqsy = sizeof(int) * (tid_x % QI8_1);
18981927

1899-
for (int i = 0; i < WARP_SIZE; i += 8) {
1900-
const int col_y_eff = min(col_y_0 + tid_y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
1928+
for (int ir = 0; ir < qr; ++ir) {
1929+
const int kqs = ir*WARP_SIZE + tid_x;
1930+
const int kby = kqs / QI8_1;
19011931

1902-
const block_q8_1 * __restrict__ by0 = &y[col_y_eff*blocks_per_row + ib0 + iby0];
1932+
for (int i = 0; i < WARP_SIZE; i += 8) {
1933+
const int col_y_eff = min(col_y_0 + tid_y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
19031934

1904-
tile_y_qs[(tid_y + i) * (2*WARP_SIZE) + tid_x] = *((int *) &by0->qs[iqsy]);
1905-
tile_y_ds[(tid_y + i) * (2*WARP_SIZE/QI8_1) + iby0] = by0->ds;
1935+
const block_q8_1 * by0 = &y[col_y_eff*blocks_per_row + ib0 + kby];
19061936

1907-
const block_q8_1 * __restrict__ by1 = &y[col_y_eff*blocks_per_row + ib0 + iby1];
1908-
1909-
tile_y_qs[(tid_y + i) * (2*WARP_SIZE) + tid_x + WARP_SIZE] = *((int *) &by1->qs[iqsy]);
1910-
tile_y_ds[(tid_y + i) * (2*WARP_SIZE/QI8_1) + iby1] = by1->ds;
1937+
tile_y_qs[(tid_y + i) * (qr*WARP_SIZE) + kqs] = *((int *) &by0->qs[iqsy]);
1938+
tile_y_ds[(tid_y + i) * (qr*WARP_SIZE/QI8_1) + kby] = by0->ds;
1939+
}
19111940
}
19121941

19131942
__syncthreads();
@@ -2633,31 +2662,39 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(const void * vx, const void * vy, float
26332662
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
26342663
const dim3 block_nums(block_num_x, block_num_y, 1);
26352664
const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
2636-
mul_mat_q<QK4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0, vec_dot_q4_0_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
2665+
mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0, vec_dot_q4_0_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
26372666
}
26382667

26392668
static void ggml_mul_mat_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
26402669
const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
26412670
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
26422671
const dim3 block_nums(block_num_x, block_num_y, 1);
26432672
const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
2644-
mul_mat_q<QK4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1, vec_dot_q4_1_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
2673+
mul_mat_q<QK4_1, QR4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1, vec_dot_q4_1_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
26452674
}
26462675

26472676
static void ggml_mul_mat_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
26482677
const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
26492678
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
26502679
const dim3 block_nums(block_num_x, block_num_y, 1);
26512680
const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
2652-
mul_mat_q<QK5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0, vec_dot_q5_0_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
2681+
mul_mat_q<QK5_0, QR5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0, vec_dot_q5_0_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
26532682
}
26542683

26552684
static void ggml_mul_mat_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
26562685
const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
26572686
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
26582687
const dim3 block_nums(block_num_x, block_num_y, 1);
26592688
const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
2660-
mul_mat_q<QK5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1, vec_dot_q5_1_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
2689+
mul_mat_q<QK5_1, QR5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1, vec_dot_q5_1_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
2690+
}
2691+
2692+
static void ggml_mul_mat_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
    // Launch config for the tiled q8_0 x q8_1 matrix multiplication:
    // one (2*WARP_SIZE) x WARP_SIZE output tile per block, grid sized by
    // ceiling division so partial edge tiles are covered.
    const int grid_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
    const int grid_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
    const dim3 block_nums(grid_x, grid_y, 1);
    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
    mul_mat_q<QK8_0, QR8_0, QI8_0, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0, vec_dot_q8_0_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
}
26622699

26632700
static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
@@ -3123,6 +3160,9 @@ inline void ggml_cuda_op_mul_mat_q(
31233160
case GGML_TYPE_Q5_1:
31243161
ggml_mul_mat_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
31253162
break;
3163+
case GGML_TYPE_Q8_0:
3164+
ggml_mul_mat_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
3165+
break;
31263166
default:
31273167
GGML_ASSERT(false);
31283168
break;
@@ -3873,7 +3913,7 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
38733913
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
38743914
} else {
38753915
if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 ||
3876-
src0->type == GGML_TYPE_Q5_1) {
3916+
src0->type == GGML_TYPE_Q5_1 || src0->type == GGML_TYPE_Q8_0) {
38773917
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
38783918
} else {
38793919
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);

0 commit comments

Comments
 (0)