
Commit a6d1189

ikawrakow and Kawrakow authored
k_quants tuning for Falcon-7b (#2816)
* Make ggml-cuda.cu build with QK_K = 64

  Using LLAMA_CUDA_FORCE_DMMV = ON and -nommq it runs and produces a meaningful result.

* k_quants tuning for Falcon-7b

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent c48c5bb commit a6d1189

2 files changed: +51 -17 lines changed


ggml-cuda.cu

Lines changed: 17 additions & 8 deletions
@@ -306,11 +306,11 @@ typedef struct {
 #define QI4_K (QK_K / (4*QR4_K))
 #ifdef GGML_QKK_64
 typedef struct {
-    half    d[2];       // super-block scales/mins
+    half    dm[2];      // super-block scales/mins
     uint8_t scales[2];  // 4-bit block scales/mins
     uint8_t qs[QK_K/2]; // 4--bit quants
 } block_q4_K;
-static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
+static_assert(sizeof(block_q4_K) == sizeof(half2) + QK_K/2 + 2, "wrong q4_K block size/padding");
 #else
 typedef struct {
     half2 dm;           // super-block scale for quantized scales/mins
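For reference, the renamed field and the updated static_assert line up as follows. A condensed restatement of the QK_K == 64 layout from the hunk above, with the byte accounting spelled out (half is the 16-bit float type, so sizeof(half2) equals 2*sizeof(ggml_fp16_t)):

typedef struct {
    half    dm[2];       // super-block scale and min (was d[2])   ->  4 bytes
    uint8_t scales[2];   // 4-bit block scales/mins                ->  2 bytes
    uint8_t qs[QK_K/2];  // 4-bit quants, QK_K/2 == 32             -> 32 bytes
} block_q4_K;            // total: 38 bytes, matching the new static_assert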
@@ -737,8 +737,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
     const int tid = threadIdx.x;
     const uint8_t * q = x[i].qs;
     float * y = yy + i*QK_K;
-    const float d = (float)x[i].d[0];
-    const float m = (float)x[i].d[1];
+    const float d = (float)x[i].dm[0];
+    const float m = (float)x[i].dm[1];
     y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
     y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4);
 #endif
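Stated as plain host-side C++, the QK_K == 64 dequantization above does the following per super-block. A minimal sketch, assuming the dm[2]/scales[2]/qs[QK_K/2] layout from the previous hunk and a half type convertible to float; in the CUDA kernel the loop body runs with one thread per tid:

static void dequantize_q4_K_qkk64_ref(const block_q4_K & x, float * y) {
    const float d = (float) x.dm[0];   // super-block scale
    const float m = (float) x.dm[1];   // super-block min
    const uint8_t * q = x.qs;
    for (int tid = 0; tid < 32; ++tid) {
        // low nibbles -> first 32 values, using block scale/min 0
        y[tid +  0] = d * (x.scales[0] & 0xF) * (q[tid] & 0xF) - m * (x.scales[0] >> 4);
        // high nibbles -> last 32 values, using block scale/min 1
        y[tid + 32] = d * (x.scales[1] & 0xF) * (q[tid] >>  4) - m * (x.scales[1] >> 4);
    }
}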
@@ -1155,8 +1155,8 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
     const uint16_t * a = (const uint16_t *)x[i].scales;
     aux16[0] = a[0] & 0x0f0f;
     aux16[1] = (a[0] >> 4) & 0x0f0f;
-    const float d = (float)x[i].d[0];
-    const float m = (float)x[i].d[1];
+    const float d = (float)x[i].dm[0];
+    const float m = (float)x[i].dm[1];
     float sum = 0.f;
     for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
         sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
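The aux16 lines above are a compact bit trick: one 16-bit load covers both packed scale bytes, a[0] & 0x0f0f keeps the low nibbles (the two scales) and (a[0] >> 4) & 0x0f0f the high nibbles (the two mins), so reading the result byte-wise gives s[0], s[1] as scales and s[2], s[3] as mins, which is how the sum below uses them. A standalone check with hypothetical byte values (assumes a little-endian host):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    const uint8_t scales[2] = { 0x7A, 0x3C };   // hypothetical packed scale/min bytes
    uint16_t a;
    std::memcpy(&a, scales, sizeof(a));         // the kernel simply reinterprets the pointer
    const uint16_t aux16[2] = { uint16_t(a & 0x0f0f), uint16_t((a >> 4) & 0x0f0f) };
    const uint8_t * s = (const uint8_t *) aux16;
    std::printf("scales %u %u, mins %u %u\n", s[0], s[1], s[2], s[3]);   // scales 10 12, mins 7 3
}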
@@ -2845,8 +2845,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     aux16[0] = a[0] & 0x0f0f;
     aux16[1] = (a[0] >> 4) & 0x0f0f;

-    const float dall = bq4_K->d[0];
-    const float dmin = bq4_K->d[1];
+    const float dall = bq4_K->dm[0];
+    const float dmin = bq4_K->dm[1];

     const float d8_1 = __low2float(bq8_1[0].ds);
     const float d8_2 = __low2float(bq8_1[1].ds);
@@ -2929,7 +2929,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;

+#if QK_K == 256
         x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
+#else
+        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]};
+#endif
     }

 #pragma unroll
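The #else branch is needed because with QK_K == 64 the super-block scale/min is stored as two separate halfs rather than one packed half2, so the tile entry has to be assembled element-wise. A host-side analogue with stand-in types (not CUDA code), just to show the difference between the two loads:

#include <cstdint>

struct half_t  { uint16_t bits; };   // stand-in for CUDA's half
struct half2_t { half_t x, y; };     // stand-in for CUDA's half2

struct blk_q4_K_256 { half2_t dm; };     // QK_K == 256: packed scale/min
struct blk_q4_K_64  { half_t  dm[2]; };  // QK_K == 64: two separate halfs

static half2_t load_dm(const blk_q4_K_256 & b) { return b.dm; }                   // one copy
static half2_t load_dm(const blk_q4_K_64  & b) { return { b.dm[0], b.dm[1] }; }   // element-wise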
@@ -3119,7 +3123,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;

+#if QK_K == 256
         x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
+#endif
     }

 #pragma unroll
@@ -4709,6 +4715,8 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

+#if QK_K == 256
+
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
     const int compute_capability = g_compute_capabilities[id];
@@ -4740,6 +4748,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
         mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
+#endif
 }

 static void ggml_mul_mat_q4_K_q8_1_cuda(
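The #if QK_K == 256 guard compiles the whole mul_mat_q3_K launch out when the small super-block size is selected, which is why the commit message falls back to LLAMA_CUDA_FORCE_DMMV = ON and -nommq for the QK_K = 64 build. For context, the super-block size is chosen at compile time roughly like this (an assumption about the k_quants headers of this era, shown only for orientation, not part of the diff):

#ifdef GGML_QKK_64
#define QK_K 64     // small super-blocks (the GGML_QKK_64 build option)
#else
#define QK_K 256    // default; the mmq kernels above assume this size
#endif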

llama.cpp

Lines changed: 34 additions & 9 deletions
@@ -4776,7 +4776,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s

     if (name == tn(LLM_TENSOR_OUTPUT, "weight")) {
         int nx = tensor->ne[0];
-        if (nx % QK_K == 0) {
+        if (model.arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
+            new_type = GGML_TYPE_Q8_0;
+        }
+        else if (new_type != GGML_TYPE_Q8_0) {
             new_type = GGML_TYPE_Q6_K;
         }
     } else if (name.find("attn_v.weight") != std::string::npos) {
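The new condition folds two cases into one: Falcon's output tensor is always routed to Q8_0, as is any output tensor whose row size nx is not a multiple of QK_K. Falcon-7B's 4544 columns are divisible by 64 but not by 256, so with the default super-block size the old Q6_K choice could not apply to it anyway. A standalone restatement (hypothetical helper, simplified from the real branch, which additionally leaves an already-chosen Q8_0 untouched):

#include <cstdio>

enum class qtype { Q6_K, Q8_0 };

// Hypothetical restatement of the output.weight rule above.
static qtype output_weight_type(bool is_falcon, int nx, int qk_k) {
    return (is_falcon || nx % qk_k != 0) ? qtype::Q8_0 : qtype::Q6_K;
}

int main() {
    std::printf("falcon-7b (nx=4544): %s\n",
                output_weight_type(true,  4544, 256) == qtype::Q8_0 ? "Q8_0" : "Q6_K");
    std::printf("llama-7b  (nx=4096): %s\n",
                output_weight_type(false, 4096, 256) == qtype::Q8_0 ? "Q8_0" : "Q6_K");
}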
@@ -4800,17 +4803,39 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     } else if (name.find("ffn_down.weight") != std::string::npos) {
         if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
         else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
-            new_type = i_feed_forward_w2 < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
+            new_type = i_feed_forward_w2 < 2 ? GGML_TYPE_Q5_K
+                     : model.arch != LLM_ARCH_FALCON || use_more_bits(i_feed_forward_w2, n_feed_forward_w2) ? GGML_TYPE_Q4_K
+                     : GGML_TYPE_Q3_K;
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) {
+            new_type = model.arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K;
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
+            if (model.arch == LLM_ARCH_FALCON) {
+                new_type = i_feed_forward_w2 < 2 ? GGML_TYPE_Q6_K :
+                           use_more_bits(i_feed_forward_w2, n_feed_forward_w2) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
+            } else {
+                if (use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
+            }
+        }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && model.arch != LLM_ARCH_FALCON && i_feed_forward_w2 < 4) {
+            new_type = GGML_TYPE_Q5_K;
         }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
-        else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
-                 use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < 4) new_type = GGML_TYPE_Q5_K;
         ++i_feed_forward_w2;
     } else if (name.find("attn_output.weight") != std::string::npos) {
-        if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K;
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) new_type = GGML_TYPE_Q4_K;
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
+        if (model.arch != LLM_ARCH_FALCON) {
+            if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K;
+            else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) new_type = GGML_TYPE_Q4_K;
+            else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
+        } else {
+            if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K;
+        }
+    }
+    else if (name.find("attn_qkv.weight") != std::string::npos) {
+        if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
+        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
     }
     else if (name.find("ffn_gate.weight") != std::string::npos || name.find("ffn_up.weight") != std::string::npos) {
         if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
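For readability, the nested ternary in the Q3_K_M ffn_down branch above, unrolled into the equivalent if/else chain (same symbols as the surrounding llama.cpp code, shown only as a restatement, not as new logic):

if (i_feed_forward_w2 < 2) {
    new_type = GGML_TYPE_Q5_K;   // first two ffn_down tensors get more bits
} else if (model.arch != LLM_ARCH_FALCON || use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) {
    new_type = GGML_TYPE_Q4_K;   // non-Falcon default, or selected Falcon layers
} else {
    new_type = GGML_TYPE_Q3_K;   // remaining Falcon ffn_down tensors stay at 3 bits
}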
