Skip to content

Commit f437f6a

Browse files
Fixed OpenLLaMA 3b CUDA mul_mat_vec_q
1 parent 061f5f8 commit f437f6a

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

ggml-cuda.cu

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
216216
#define CUDA_SCALE_BLOCK_SIZE 256
217217
#define CUDA_ROPE_BLOCK_SIZE 256
218218
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
219-
#define CUDA_QUANTIZE_BLOCK_SIZE 256
219+
#define CUDA_QUANTIZE_BLOCK_SIZE 128
220220
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
221221

222222
// dmmv = dequantize_mul_mat_vec
@@ -1174,16 +1174,12 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
11741174
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int k) {
11751175
const int i = blockDim.x*blockIdx.x + threadIdx.x;
11761176

1177-
if (i >= k) {
1178-
return;
1179-
}
1180-
11811177
block_q8_1 * y = (block_q8_1 *) vy;
11821178

1183-
const int ib = i / QK8_0; // block index
1184-
const int iqs = i % QK8_0; // quant index
1179+
const int ib = i / QK8_1; // block index
1180+
const int iqs = i % QK8_1; // quant index
11851181

1186-
const float xi = x[i];
1182+
const float xi = i < k ? x[i] : 0.0f;
11871183
float amax = fabsf(xi);
11881184
float sum = xi;
11891185

@@ -2359,8 +2355,10 @@ inline void ggml_cuda_op_mul_mat_vec(
23592355
#endif
23602356

23612357
if (use_mul_mat_vec_q) {
2358+
int64_t padded_row_size = ne00 + CUDA_QUANTIZE_BLOCK_SIZE - 1;
2359+
padded_row_size -= padded_row_size % CUDA_QUANTIZE_BLOCK_SIZE;
23622360
size_t as;
2363-
void * src1_q8_1 = ggml_cuda_pool_malloc(ne00*sizeof(block_q8_1)/QK8_1, &as);
2361+
void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as);
23642362
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, cudaStream_main);
23652363

23662364
switch (src0->type) {
@@ -3105,7 +3103,11 @@ void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
31053103

31063104
void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
31073105
int nrows = ggml_nrows(tensor);
3106+
3107+
const int64_t ne0 = tensor->ne[0];
3108+
31083109
const size_t nb1 = tensor->nb[1];
3110+
31093111
ggml_backend backend = tensor->backend;
31103112
struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
31113113
memset(extra, 0, sizeof(*extra));
@@ -3134,7 +3136,11 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
31343136
int64_t nrows_split = row_high - row_low;
31353137

31363138
const size_t offset_split = row_low*nb1;
3137-
const size_t size = ggml_nbytes_split(tensor, nrows_split);
3139+
size_t size = ggml_nbytes_split(tensor, nrows_split);
3140+
if (ne0 % CUDA_QUANTIZE_BLOCK_SIZE != 0) {
3141+
size += (CUDA_QUANTIZE_BLOCK_SIZE - ne0 % CUDA_QUANTIZE_BLOCK_SIZE)
3142+
* ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
3143+
}
31383144

31393145
void * buf;
31403146
CUDA_CHECK(cudaMalloc(&buf, size));

0 commit comments

Comments
 (0)