@@ -1404,7 +1404,7 @@ static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 **
1404
1404
*x_dm = tile_x_d;
1405
1405
}
1406
1406
1407
- static __device__ __forceinline__ void load_tiles_q4_0 (
1407
+ template < bool need_check> static __device__ __forceinline__ void load_tiles_q4_0 (
1408
1408
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
1409
1409
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
1410
1410
@@ -1420,7 +1420,11 @@ static __device__ __forceinline__ void load_tiles_q4_0(
1420
1420
1421
1421
#pragma unroll
1422
1422
for (int i0 = 0 ; i0 < GGML_CUDA_MMQ_Y; i0 += 8 ) {
1423
- const int i = min (i0 + i_offset, i_max);
1423
+ int i = i0 + i_offset;
1424
+
1425
+ if (need_check) {
1426
+ i = min (i, i_max);
1427
+ }
1424
1428
1425
1429
const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
1426
1430
@@ -3609,8 +3613,14 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
3609
3613
const int block_num_y = (ncols_y + WARP_SIZE - 1 ) / WARP_SIZE;
3610
3614
const dim3 block_nums (block_num_x, block_num_y, 1 );
3611
3615
const dim3 block_dims (WARP_SIZE, WARP_SIZE/4 , 1 );
3612
- mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0, VDR_q4_0_q8_1, vec_dot_q4_0_q8_1_mul_mat>
3613
- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
3616
+
3617
+ if (nrows_x % GGML_CUDA_MMQ_Y == 0 ) {
3618
+ mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<false >, VDR_q4_0_q8_1, vec_dot_q4_0_q8_1_mul_mat>
3619
+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
3620
+ } else {
3621
+ mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<true >, VDR_q4_0_q8_1, vec_dot_q4_0_q8_1_mul_mat>
3622
+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
3623
+ }
3614
3624
}
3615
3625
3616
3626
static void ggml_mul_mat_q4_1_q8_1_cuda (
0 commit comments