Skip to content

Commit efc4672

Browse files
template for check
1 parent c6a933c commit efc4672

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

ggml-cuda.cu

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1404,7 +1404,7 @@ static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 **
14041404
*x_dm = tile_x_d;
14051405
}
14061406

1407-
static __device__ __forceinline__ void load_tiles_q4_0(
1407+
template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
14081408
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
14091409
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
14101410

@@ -1420,7 +1420,11 @@ static __device__ __forceinline__ void load_tiles_q4_0(
14201420

14211421
#pragma unroll
14221422
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
1423-
const int i = min(i0 + i_offset, i_max);
1423+
int i = i0 + i_offset;
1424+
1425+
if (need_check) {
1426+
i = min(i, i_max);
1427+
}
14241428

14251429
const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
14261430

@@ -3609,8 +3613,14 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
36093613
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
36103614
const dim3 block_nums(block_num_x, block_num_y, 1);
36113615
const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
3612-
mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0, VDR_q4_0_q8_1, vec_dot_q4_0_q8_1_mul_mat>
3613-
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
3616+
3617+
if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
3618+
mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<false>, VDR_q4_0_q8_1, vec_dot_q4_0_q8_1_mul_mat>
3619+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
3620+
} else {
3621+
mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<true>, VDR_q4_0_q8_1, vec_dot_q4_0_q8_1_mul_mat>
3622+
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
3623+
}
36143624
}
36153625

36163626
static void ggml_mul_mat_q4_1_q8_1_cuda(

0 commit comments

Comments
 (0)