From d2f12ac354552bcfba1dbc9c8593296d81b70452 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 21 Jun 2023 12:43:44 +0300 Subject: [PATCH 01/35] k_quants: WIP super-blocks with 64 weights --- CMakeLists.txt | 4 + Makefile | 4 + k_quants.c | 243 +++++++++++++++++++++++++++++++++++++++++++++---- k_quants.h | 39 ++++++-- 4 files changed, 268 insertions(+), 22 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index cc7560a7ae54e..6aae8e1661f3d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -75,6 +75,7 @@ set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_METAL "llama: use Metal" OFF) option(LLAMA_K_QUANTS "llama: use k-quants" ON) +option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF) option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) @@ -292,6 +293,9 @@ endif() if (LLAMA_K_QUANTS) set(GGML_SOURCES_EXTRA ${GGML_SOURCES_EXTRA} k_quants.c k_quants.h) add_compile_definitions(GGML_USE_K_QUANTS) + if (LLAMA_QKK_64) + add_compile_definitions(GGML_QKK_64) + endif() endif() if (LLAMA_CLBLAST) diff --git a/Makefile b/Makefile index 5dd676fada417..a163bde40f66d 100644 --- a/Makefile +++ b/Makefile @@ -131,6 +131,10 @@ ifndef LLAMA_NO_K_QUANTS CFLAGS += -DGGML_USE_K_QUANTS CXXFLAGS += -DGGML_USE_K_QUANTS OBJS += k_quants.o +ifdef LLAMA_QKK_64 + CFLAGS += -DGGML_QKK_64 + CXXFLAGS += -DGGML_QKK_64 +endif endif ifndef LLAMA_NO_ACCELERATE diff --git a/k_quants.c b/k_quants.c index a48c821710cbb..e3313219ef0e4 100644 --- a/k_quants.c +++ b/k_quants.c @@ -330,11 +330,17 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict } } +#if QK_K == 256 for (int j = 0; j < QK_K; j += 128) { for (int l = 0; l < 32; ++l) { y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); } } +#else + for (int l = 0; l < 16; ++l) { + y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6); + } +#endif x += QK_K; @@ -352,6 +358,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int const uint8_t * q = x[i].qs; +#if QK_K == 256 int is = 0; float dl, ml; for (int n = 0; n < QK_K; n += 128) { @@ -370,7 +377,19 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int } q += 32; } - +#else + float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4); + float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4); + float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4); + float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4); + for (int l = 0; l < 16; ++l) { + y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1; + y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2; + y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3; + y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4; + } + y += QK_K; +#endif } } @@ -412,6 +431,7 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict } } +#if QK_K == 256 memset(y[i].scales, 0, 12); if (max_scale) { float iscale = -32.f/max_scale; @@ -445,9 +465,36 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict L[16*j + ii] = l + 4; } } +#else + if (max_scale) { + float iscale = -128.f/max_scale; + for (int j = 0; j < QK_K/16; ++j) { + int l = nearest_int(iscale*scales[j]); + l = MAX(-128, MIN(127, l)); + y[i].scales[j] = l; + } + 
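+            // QK_K == 64: the four 16-weight scales fit directly into the
+            // int8_t scales[4] of block_q3_K (iscale = -128/max_scale puts
+            // the largest magnitude at -128), so d stores the reciprocal: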
y[i].d = ggml_fp32_to_fp16(1/iscale); + } else { + for (int j = 0; j < QK_K/16; ++j) { + y[i].scales[j] = 0; + } + y[i].d = ggml_fp32_to_fp16(0.f); + } + for (int j = 0; j < QK_K/16; ++j) { + float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j]; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-4, MIN(3, l)); + L[16*j + ii] = l + 4; + } + } +#endif memset(y[i].hmask, 0, QK_K/8); - // We put the high-bit for the 1st 32 quants into bit 0, the next 32 into bit 1, etc. + // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc. int m = 0; uint8_t hm = 1; for (int j = 0; j < QK_K; ++j) { @@ -459,19 +506,25 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict m = 0; hm <<= 1; } } +#if QK_K == 256 for (int j = 0; j < QK_K; j += 128) { for (int l = 0; l < 32; ++l) { y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); } } +#else + for (int l = 0; l < 16; ++l) { + y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6); + } +#endif x += QK_K; } } +#if QK_K == 256 void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { assert(k % QK_K == 0); - assert(QK_K == 256); const int nb = k / QK_K; const uint32_t kmask1 = 0x03030303; @@ -519,6 +572,39 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int } } +#else +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + assert(QK_K == 64); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d_all = ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + + const float d1 = d_all * x[i].scales[0]; + const float d2 = d_all * x[i].scales[1]; + const float d3 = d_all * x[i].scales[2]; + const float d4 = d_all * x[i].scales[3]; + + for (int l=0; l<8; ++l) { + uint8_t h = hm[l]; + y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4)); + y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4)); + y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4)); + y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4)); + y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4)); + y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4)); + y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4)); + y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 
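+                               // hmask bit k is the high (4s) bit of quant l + 8*k:
+                               // when set, the 3-bit value needs no correction;
+                               // when clear, the stored +4 bias is subtracted: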
0 : 4)); + } + y += QK_K; + } +} +#endif void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) { quantize_row_q3_K_reference(x, vy, k); @@ -544,11 +630,14 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict const int nb = k / QK_K; uint8_t L[QK_K]; +#if QK_K == 256 float mins[QK_K/32]; float scales[QK_K/32]; +#endif for (int i = 0; i < nb; i++) { +#if QK_K == 256 float max_scale = 0; // as we are deducting the min, scales are always positive float max_min = 0; for (int j = 0; j < QK_K/32; ++j) { @@ -594,9 +683,28 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict L[32*j + ii] = l; } } +#else + for (int j = 0; j < QK_K/32; ++j) { + float min; + float scale = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &min, 5); + y[i].d[2*j+0] = ggml_fp32_to_fp16(scale); + y[i].d[2*j+1] = ggml_fp32_to_fp16(min); + } + + for (int j = 0; j < QK_K/32; ++j) { + const float d = ggml_fp16_to_fp32(y[i].d[2*j+0]); + if (!d) continue; + const float dm = ggml_fp16_to_fp32(y[i].d[2*j+1]); + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(15, l)); + L[32*j + ii] = l; + } + } +#endif uint8_t * q = y[i].qs; for (int j = 0; j < QK_K; j += 64) { - for (int l = 0; l < 32; ++l) *q++ = L[j + l] | (L[j + l + 32] << 4); + for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4); } x += QK_K; @@ -610,11 +718,13 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int for (int i = 0; i < nb; i++) { - const float d = ggml_fp16_to_fp32(x[i].d); - const float min = ggml_fp16_to_fp32(x[i].dmin); - const uint8_t * q = x[i].qs; +#if QK_K == 256 + + const float d = ggml_fp16_to_fp32(x[i].d); + const float min = ggml_fp16_to_fp32(x[i].dmin); + int is = 0; uint8_t sc, m; for (int j = 0; j < QK_K; j += 64) { @@ -626,6 +736,15 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; q += 32; is += 2; } +#else + float d1 = ggml_fp16_to_fp32(x[i].d[0]), m1 = ggml_fp16_to_fp32(x[i].d[1]); + float d2 = ggml_fp16_to_fp32(x[i].d[2]), m2 = ggml_fp16_to_fp32(x[i].d[3]); + for (int l = 0; l < 32; ++l) { + y[l+ 0] = d1 * (q[l] & 0xF) - m1; + y[l+32] = d2 * (q[l] >> 4) - m2; + } + y += QK_K; +#endif } } @@ -654,11 +773,15 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict const int nb = k / QK_K; uint8_t L[QK_K]; +#if QK_K == 256 float mins[QK_K/32]; float scales[QK_K/32]; +#endif for (int i = 0; i < nb; i++) { +#if QK_K == 256 + float max_scale = 0; // as we are deducting the min, scales are always positive float max_min = 0; for (int j = 0; j < QK_K/32; ++j) { @@ -725,6 +848,42 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict m1 <<= 2; m2 <<= 2; ql += 32; } +#else + for (int j = 0; j < QK_K/32; ++j) { + float min; + float scale = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &min, 5); + y[i].d[2*j+0] = ggml_fp32_to_fp16(scale); + y[i].d[2*j+1] = ggml_fp32_to_fp16(min); + } + for (int j = 0; j < QK_K/32; ++j) { + const float d = ggml_fp16_to_fp32(y[i].d[2*j+0]); + if (!d) continue; + const float dm = ggml_fp16_to_fp32(y[i].d[2*j+1]); + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(31, l)); + L[32*j + ii] = l; + } + } + + uint8_t * restrict qh = y[i].qh; + uint8_t * restrict ql = y[i].qs; + memset(qh, 0, QK_K/8); + + for (int j = 0; j < 32; ++j) { + int jm = j%8; + int is = j/8; + 
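+            // the 5th bit of quant j lands in bit j/8 of qh[j%8]; its nibble
+            // partner, quant j+32, uses bit 4 + j/8 of the same qh byte: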
int l1 = L[j]; + if (l1 > 15) { + l1 -= 16; qh[jm] |= (1 << is); + } + int l2 = L[j + 32]; + if (l2 > 15) { + l2 -= 16; qh[jm] |= (1 << (4 + is)); + } + ql[j] = l1 | (l2 << 4); + } +#endif x += QK_K; @@ -737,12 +896,14 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int for (int i = 0; i < nb; i++) { - const float d = ggml_fp16_to_fp32(x[i].d); - const float min = ggml_fp16_to_fp32(x[i].dmin); - const uint8_t * ql = x[i].qs; const uint8_t * qh = x[i].qh; +#if QK_K == 256 + + const float d = ggml_fp16_to_fp32(x[i].d); + const float min = ggml_fp16_to_fp32(x[i].dmin); + int is = 0; uint8_t sc, m; uint8_t u1 = 1, u2 = 2; @@ -756,6 +917,21 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int ql += 32; is += 2; u1 <<= 2; u2 <<= 2; } +#else + float d1 = ggml_fp16_to_fp32(x[i].d[0]), m1 = ggml_fp16_to_fp32(x[i].d[1]); + float d2 = ggml_fp16_to_fp32(x[i].d[2]), m2 = ggml_fp16_to_fp32(x[i].d[3]); + for (int l = 0; l < 8; ++l) { + y[l+ 0] = d1 * ((ql[l+ 0] & 0xF) + (qh[l] & 0x01 ? 16 : 0)) - m1; + y[l+ 8] = d1 * ((ql[l+ 8] & 0xF) + (qh[l] & 0x02 ? 16 : 0)) - m1; + y[l+16] = d1 * ((ql[l+16] & 0xF) + (qh[l] & 0x04 ? 16 : 0)) - m1; + y[l+24] = d1 * ((ql[l+24] & 0xF) + (qh[l] & 0x08 ? 16 : 0)) - m1; + y[l+32] = d2 * ((ql[l+ 0] >> 4) + (qh[l] & 0x10 ? 16 : 0)) - m2; + y[l+40] = d2 * ((ql[l+ 8] >> 4) + (qh[l] & 0x20 ? 16 : 0)) - m2; + y[l+48] = d2 * ((ql[l+16] >> 4) + (qh[l] & 0x40 ? 16 : 0)) - m2; + y[l+56] = d2 * ((ql[l+24] >> 4) + (qh[l] & 0x80 ? 16 : 0)) - m2; + } + y += QK_K; +#endif } } @@ -823,6 +999,7 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict uint8_t * restrict ql = y[i].ql; uint8_t * restrict qh = y[i].qh; +#if QK_K == 256 for (int j = 0; j < QK_K; j += 128) { for (int l = 0; l < 32; ++l) { const uint8_t q1 = L[j + l + 0] & 0xF; @@ -836,6 +1013,16 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict ql += 64; qh += 32; } +#else + for (int l = 0; l < 32; ++l) { + const uint8_t q1 = L[l + 0] & 0xF; + const uint8_t q2 = L[l + 32] & 0xF; + ql[l] = q1 | (q2 << 4); + } + for (int l = 0; l < 16; ++l) { + qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6); + } +#endif x += QK_K; @@ -854,6 +1041,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int const uint8_t * restrict qh = x[i].qh; const int8_t * restrict sc = x[i].scales; +#if QK_K == 256 for (int n = 0; n < QK_K; n += 128) { for (int l = 0; l < 32; ++l) { int is = l/16; @@ -871,6 +1059,19 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int qh += 32; sc += 8; } +#else + for (int l = 0; l < 16; ++l) { + const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + y[l+ 0] = d * sc[0] * q1; + y[l+16] = d * sc[1] * q2; + y[l+32] = d * sc[2] * q3; + y[l+48] = d * sc[3] * q4; + } + y += 64; +#endif } } @@ -1611,18 +1812,23 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri for (int i = 0; i < nb; ++i) { +#if QK_K == 256 const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - memcpy(utmp, 
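        // unpack the 12 packed 6-bit scale/min bytes: after this shuffle,
        // utmp[0..1] hold the 8 scales and utmp[2..3] the 8 mins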
x[i].scales, 12); utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); const uint32_t uaux = utmp[1] & kmask1; utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); utmp[2] = uaux; utmp[0] &= kmask1; +#else + // TODO + const float d = 0; const float dmin = 0; +#endif + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); @@ -1840,18 +2046,23 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri for (int i = 0; i < nb; ++i) { - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); - const uint8_t * restrict q5 = x[i].qs; const int8_t * restrict q8 = y[i].qs; +#if QK_K == 256 + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + memcpy(utmp, x[i].scales, 12); utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); const uint32_t uaux = utmp[1] & kmask1; utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); utmp[2] = uaux; utmp[0] &= kmask1; +#else + // TODO + const float d = 0, dmin = 0; +#endif const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); diff --git a/k_quants.h b/k_quants.h index 10a0baac73a07..2f633389b6ec0 100644 --- a/k_quants.h +++ b/k_quants.h @@ -7,7 +7,13 @@ #include // Super-block size +#ifdef GGML_QKK_64 +#define QK_K 64 +#define K_SCALE_SIZE 4 +#else #define QK_K 256 +#define K_SCALE_SIZE 12 +#endif // // Super-block quantization structures @@ -32,35 +38,56 @@ static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "w typedef struct { uint8_t hmask[QK_K/8]; // quants - high bit uint8_t qs[QK_K/4]; // quants - low 2 bits - uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits +#ifdef GGML_QKK_64 + int8_t scales[K_SCALE_SIZE]; +#else + uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits +#endif ggml_fp16_t d; // super-block scale } block_q3_K; -static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding"); +static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding"); // 4-bit quantization // 16 blocks of 32 elements each // weight is represented as x = a * q + b // Effectively 4.5 bits per weight +#ifdef GGML_QKK_64 +typedef struct { + ggml_fp16_t d[2*QK_K/32]; // super-block scales/mins + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2, "wrong q4_K block size/padding"); +#else typedef struct { ggml_fp16_t d; // super-block scale for quantized scales ggml_fp16_t dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits uint8_t qs[QK_K/2]; // 4--bit quants } block_q4_K; -static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); +#endif +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding"); // 5-bit quantization // 16 blocks of 32 elements each // weight is represented as x = a * q + b // Effectively 5.5 bits per weight +#ifdef GGML_QKK_64 +typedef struct { + ggml_fp16_t d[2*QK_K/32]; // super-block 
scales/mins + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#else typedef struct { ggml_fp16_t d; // super-block scale for quantized scales ggml_fp16_t dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits uint8_t qh[QK_K/8]; // quants, high bit uint8_t qs[QK_K/2]; // quants, low 4 bits } block_q5_K; -static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#endif // 6-bit quantization // weight is represented as x = a * q From 9fe2a2b1dbfef1823c7484842f2a598e3d19d9bc Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 21 Jun 2023 13:41:30 +0300 Subject: [PATCH 02/35] k_quants: WIP super-blocks with 64 weights Q6_K scalar and AVX2 works --- k_quants.c | 247 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 246 insertions(+), 1 deletion(-) diff --git a/k_quants.c b/k_quants.c index e3313219ef0e4..2138512545df6 100644 --- a/k_quants.c +++ b/k_quants.c @@ -2184,7 +2184,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri } - +#if QK_K == 256 void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { assert(n % QK_K == 0); @@ -2453,3 +2453,248 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri *s = sumf; #endif } + +#else + +void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q6_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + float sum = 0; + + const uint8x16_t m4b = vdupq_n_u8(0xF); + const int32x4_t vzero = vdupq_n_s32(0); + //const int8x16_t m32s = vdupq_n_s8(32); + + const uint8x16_t mone = vdupq_n_u8(3); + + int8x16x4_t q6bytes; + uint8x16x4_t q6h; + + for (int i = 0; i < nb; ++i) { + + const float d_all = ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); + const int8x16_t scales = vld1q_s8(scale); + const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}; + + const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), + vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), + vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); + int32_t isum_mins = vaddvq_s32(prod); + + int32_t isum = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32; + uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64; + int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; + + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); + 
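+            // bits 2-3 of each qh byte are the high bits of the next 32 values;
+            // the usual -32 offset (m32s, commented out above) is instead folded
+            // into isum_mins and applied once after the loop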
q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 2); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + scale += 4; + +#else + + int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; + scale += 2; + + int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; + scale += 2; +#endif + + q8bytes = vld1q_s8_x4(q8); q8 += 64; + + shifted = vshrq_n_u8(qhbits.val[0], 4); + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[0], 6); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 6); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + 
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + scale += 4; + + //for (int l = 0; l < 4; ++l) { + // const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]); + // isum += vaddvq_s32(p) * *scale++; + //} +#else + p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; + scale += 2; + + p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; + scale += 2; +#endif + + } + //sum += isum * d_all * y[i].d; + sum += d_all * y[i].d * (isum - 32 * isum_mins); + + } + *s = sum; + +#elif defined z__AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(3); + const __m256i m32s = _mm256_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); + const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); + const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); + const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); + + __m256i sumi = _mm256_setzero_si256(); + + const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); + const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); + + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4); + + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); + + __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + + p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } + + *s = hsum_float_8(acc); + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + 
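+    // scalar fallback: unpack the 64 six-bit values (low 4 bits from ql,
+    // 2 high bits from qh) into signed ints in [-32, 31], then accumulate
+    // each group of 16 against its int8_t scale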
memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int l = 0; l < 16; ++l) { + a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l+32] = (int8_t)((q4[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l+48] = (int8_t)((q4[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +#endif From 1f6195c2f2eba5baf92e32100b259833b1f4f71f Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 21 Jun 2023 14:31:21 +0300 Subject: [PATCH 03/35] k_quants: WIP super-blocks with 64 weights Q4_K scalar and AVX2 works --- k_quants.c | 172 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 171 insertions(+), 1 deletion(-) diff --git a/k_quants.c b/k_quants.c index 2138512545df6..3ff196cc51fb0 100644 --- a/k_quants.c +++ b/k_quants.c @@ -1702,6 +1702,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri } +#if QK_K == 256 void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { assert(n % QK_K == 0); @@ -1932,6 +1933,175 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri *s = sumf; #endif } +#else +void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q4_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); +#ifdef __ARM_FEATURE_DOTPROD + const int32x4_t mzero = vdupq_n_s32(0); +#endif + + int8x16x2_t q4bytes; + int8x16x2_t q8bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + + const uint32x2_t mins8 = {utmp[1] & kmask1, ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4)}; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[0] &= kmask1; + + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + sumf -= dmin * vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + //int32x4_t isum = mzero; + + int32_t sumi1 = 0; + int32_t sumi2 = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32; + +#ifdef __ARM_FEATURE_DOTPROD 
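+            // dot-product path: one vdotq_s32 chain per 32-value half,
+            // each half weighted by its own scale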
+ q8bytes = vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + + const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + sumi1 += vaddvq_s32(p1) * scales[2*j+0]; + + q8bytes = vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + + const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + + sumi2 += vaddvq_s32(p2) * scales[2*j+1]; +#else + q8bytes = vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0]; + + q8bytes = vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1]; + +#endif + } + + sumf += d * (sumi1 + sumi2); + + } + + *s = sumf; + +#elif defined z__AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m1 = _mm256_set1_epi16(1); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + for (int i = 0; i < nb; ++i) { + + summs += y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) + + ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3])); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); + const __m256i q4l = _mm256_and_si256(q4bits, m4); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + + const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); + const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); + + const __m256i p32l = _mm256_madd_epi16(m1, p16l); + const __m256i p32h = _mm256_madd_epi16(m1, p16h); + + acc = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * ggml_fp16_to_fp32(x[i].d[0])), _mm256_cvtepi32_ps(p32l), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * ggml_fp16_to_fp32(x[i].d[2])), _mm256_cvtepi32_ps(p32h), acc); + + } + + *s = hsum_float_8(acc) - summs; + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + int8_t * restrict a = aux8; + for (int l = 0; l < 32; ++l) a[l+ 0] = (int8_t)(q4[l] & 0xF); + for (int l = 
0; l < 32; ++l) a[l+32] = (int8_t)(q4[l] >> 4); + + sumf -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) + + ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3])); + + for (int j = 0; j < QK_K/32; ++j) { + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[2*j]); + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) sums[l] += d * aux16[l]; + } + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} +#endif void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { assert(n % QK_K == 0); @@ -2601,7 +2771,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri } *s = sum; -#elif defined z__AVX2__ +#elif defined __AVX2__ const __m256i m4 = _mm256_set1_epi8(0xF); const __m256i m2 = _mm256_set1_epi8(3); From aebd5471e934ea7474906203f90e42f77c30b618 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 21 Jun 2023 15:20:18 +0300 Subject: [PATCH 04/35] k_quants: WIP super-blocks with 64 weights Q2_K scalar and AVX2 works. Q2_K is way too slow (it is actually slower than the scalar implementation) --- k_quants.c | 184 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 183 insertions(+), 1 deletion(-) diff --git a/k_quants.c b/k_quants.c index 3ff196cc51fb0..42fa9a402d80e 100644 --- a/k_quants.c +++ b/k_quants.c @@ -1203,6 +1203,7 @@ static inline __m128i get_scale_shuffle(int i) { } #endif +#if QK_K == 256 void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const block_q2_K * restrict x = vx; @@ -1402,6 +1403,187 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri #endif } +#else + +void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + + const block_q2_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m3 = vdupq_n_u8(0x3); + const uint8x16_t m4 = vdupq_n_u8(0xF); + const int32x4_t vzero = vdupq_n_s32(0); + + int8x16x2_t q2bytes; + uint8_t aux[16]; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint8_t * restrict sc = x[i].scales; + + const uint8x16_t mins_and_scales = vld1q_u8(sc); + const uint8x16_t scales = vandq_u8(mins_and_scales, m4); + vst1q_u8(aux, scales); + + const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); + const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); + const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}; + const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), + vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); + const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), + vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1]))); + sum += dmin * vaddvq_s32(vaddq_s32(s0, s1)); + + int isum = 0; 
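+        // note: dmin carries a minus sign from its definition above, so the
+        // "sum += dmin * ..." line subtracts the mins*bsums correction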
+ int is = 0; + +// We use this macro instead of a function call because for some reason +// the code runs 2-3% slower, even if the function is declared inline +#if defined(__ARM_FEATURE_DOTPROD) +#define MULTIPLY_ACCUM_WITH_SCALE(index)\ + isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ + isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; +#else +#define MULTIPLY_ACCUM_WITH_SCALE(index)\ + {\ + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\ + vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\ + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\ + vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\ + isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\ + } +#endif + +#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ + q8bytes = vld1q_s8_x2(q8); q8 += 32;\ + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ + MULTIPLY_ACCUM_WITH_SCALE((index)); + + + for (int j = 0; j < QK_K/128; ++j) { + + const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32; + + int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32; + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); + MULTIPLY_ACCUM_WITH_SCALE(0); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); + + is += 8; + } + sum += d * isum; + + } + + *s = sum; + +#elif defined z__AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + + __m256 acc = _mm256_setzero_ps(); + + uint32_t ud, um; + const uint8_t * restrict db = (const uint8_t *)&ud; + const uint8_t * restrict mb = (const uint8_t *)&um; + + float summs = 0; + + // TODO: optimize this + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + ud = (sc[0] >> 0) & 0x0f0f0f0f; + um = (sc[0] >> 4) & 0x0f0f0f0f; + + int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; + summs += dmin * smin; + + const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); + const __m256i q2_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 2), q2bits), m3); + const __m256i q2_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); + const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); + + const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0)); + const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1)); + const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0)); + const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1)); + + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), 
acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc); + } + + *s = hsum_float_8(acc) + summs; + +#else + + float sumf = 0; + + int isum[4]; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + int summs = 0; + for (int j = 0; j < QK_K/16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); + } + + const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + isum[0] = isum[1] = isum[2] = isum[3] = 0; + for (int l = 0; l < 16; ++l) { + isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3); + isum[1] += q8[l+16] * ((q2[l] >> 2) & 3); + isum[2] += q8[l+32] * ((q2[l] >> 4) & 3); + isum[3] += q8[l+48] * ((q2[l] >> 6) & 3); + } + for (int l = 0; l < 4; ++l) { + isum[l] *= (sc[l] & 0xF); + } + sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs; + } + *s = sumf; +#endif +} +#endif + void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { assert(n % QK_K == 0); @@ -2029,7 +2211,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri *s = sumf; -#elif defined z__AVX2__ +#elif defined __AVX2__ const __m256i m4 = _mm256_set1_epi8(0xF); const __m256i m1 = _mm256_set1_epi16(1); From 2b2ab31a8988345a5b8c066d5ce32e0a7ae38b37 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 21 Jun 2023 16:53:06 +0300 Subject: [PATCH 05/35] k_quants: WIP super-blocks with 64 weights Q3_K scalar and AVX2 works. --- Makefile | 5 +- k_quants.c | 239 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 242 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index a163bde40f66d..bda11791d5e23 100644 --- a/Makefile +++ b/Makefile @@ -43,8 +43,11 @@ endif # keep standard at C11 and C++11 # -Ofast tends to produce faster code, but may not be available for some compilers. -#OPT = -Ofast +ifdef LLAMA_FAST +OPT = -Ofast +else OPT = -O3 +endif CFLAGS = -I. $(OPT) -std=c11 -fPIC CXXFLAGS = -I. 
-I./examples $(OPT) -std=c++11 -fPIC LDFLAGS = diff --git a/k_quants.c b/k_quants.c index 42fa9a402d80e..32eea36607d89 100644 --- a/k_quants.c +++ b/k_quants.c @@ -1495,7 +1495,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri *s = sum; -#elif defined z__AVX2__ +#elif defined __AVX2__ const __m256i m3 = _mm256_set1_epi8(3); @@ -1584,6 +1584,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri } #endif +#if QK_K == 256 void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { assert(n % QK_K == 0); @@ -1884,6 +1885,242 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri } +#else + +void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q3_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + uint32_t aux[3]; + uint32_t utmp[4]; + + const uint8x16_t m3b = vdupq_n_u8(0x3); +#ifdef __ARM_FEATURE_DOTPROD + const int32x4_t vzero = vdupq_n_s32(0); +#endif + + const uint8x16_t m0 = vdupq_n_u8(1); + const uint8x16_t m1 = vshlq_n_u8(m0, 1); + const uint8x16_t m2 = vshlq_n_u8(m0, 2); + const uint8x16_t m3 = vshlq_n_u8(m0, 3); + const int8_t m32 = 32; + + int8x16x4_t q3bytes; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + uint8x16x2_t qhbits = vld1q_u8_x2(qh); + + uint8x16x4_t q3h; + + int32_t isum = 0; + + // Set up scales + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= m32; + + for (int j = 0; j < QK_K/128; ++j) { + + const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32; + const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64; + const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64; + + q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); + q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); + q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); + q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); + + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; +#else + int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 
(q8bytes_1.val[0])), + vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0]))); + int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])), + vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1]))); + int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])), + vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2]))); + int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])), + vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; +#endif + scale += 4; + + q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); + q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); + q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); + q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); + + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; +#else + p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])), + vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0]))); + p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])), + vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1]))); + p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])), + vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2]))); + p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])), + vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; +#endif + scale += 4; + + if (j == 0) { + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); + } + + } + sum += d * isum; + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i m1 = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + uint64_t aux64; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0])); + const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2])); + + // Set up scales + memcpy(&aux64, x[i].hmask, 8); + + __m256i q3h_0 = _mm256_set_m128i(_mm_set_epi64x(aux64 >> 3, aux64 >> 2), _mm_set_epi64x(aux64 >> 1, 
aux64 >> 0)); + __m256i q3h_1 = _mm256_set_m128i(_mm_set_epi64x(aux64 >> 7, aux64 >> 6), _mm_set_epi64x(aux64 >> 5, aux64 >> 4)); + q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2); + q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2); + + // load low 2 bits + const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); + + // prepare low and high bits + const __m256i q3l_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits), m3); + const __m256i q3l_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q3bits, 6), _mm_srli_epi16(q3bits, 4)), m3); + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + + __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + + // multiply with scales + p16_0 = _mm256_madd_epi16(scale_0, p16_0); + p16_1 = _mm256_madd_epi16(scale_1, p16_1); + + p16_0 = _mm256_add_epi32(p16_0, p16_1); + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc); + + } + + *s = hsum_float_8(acc); + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + int8_t * restrict a = aux8; + for (int l = 0; l < 8; ++l) { + a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4); + a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4); + a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4); + a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4); + a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4); + a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4); + a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4); + a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 
0 : 4); + } + + memset(aux32, 0, 8*sizeof(int32_t)); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; + q8 += 8; a += 8; + int sc = x[i].scales[j]; + for (int l = 0; l < 8; ++l) aux32[l] += sc * aux16[l]; + } + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} +#endif + #if QK_K == 256 void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { assert(n % QK_K == 0); From bcf8c5c384ed339fd8da232aa80fa6b0a7c8729b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 21 Jun 2023 18:28:40 +0300 Subject: [PATCH 06/35] k_quants: WIP super-blocks with 64 weights Q5_K scalar and AVX2 works, and with that all k_quants are done on AVX2 and scalar --- k_quants.c | 212 +++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 198 insertions(+), 14 deletions(-) diff --git a/k_quants.c b/k_quants.c index 32eea36607d89..86c224fc97345 100644 --- a/k_quants.c +++ b/k_quants.c @@ -2487,8 +2487,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri #else - int8_t aux8[QK_K]; - int16_t aux16[8]; + uint8_t aux8[QK_K]; + int16_t aux16[16]; float sums [8]; memset(sums, 0, 8*sizeof(float)); @@ -2496,24 +2496,20 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri for (int i = 0; i < nb; ++i) { const uint8_t * restrict q4 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - int8_t * restrict a = aux8; - for (int l = 0; l < 32; ++l) a[l+ 0] = (int8_t)(q4[l] & 0xF); - for (int l = 0; l < 32; ++l) a[l+32] = (int8_t)(q4[l] >> 4); + uint8_t * restrict a = aux8; + for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF; + for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4; sumf -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) + ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3])); for (int j = 0; j < QK_K/32; ++j) { const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[2*j]); - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) sums[l] += d * aux16[l]; + for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; + q8 += 16; a += 16; + for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l]; + q8 += 16; a += 16; + for (int l = 0; l < 8; ++l) sums[l] += d * (aux16[l] + aux16[l+8]); } } for (int l = 0; l < 8; ++l) sumf += sums[l]; @@ -2522,6 +2518,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri } #endif +#if QK_K == 256 void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { assert(n % QK_K == 0); @@ -2772,6 +2769,193 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri #endif } +#else + +void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q5_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); + const int32x4_t mzero = 
vdupq_n_s32(0); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + + int8x16x4_t q5bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + int32_t sumi_mins = vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + uint8x16x2_t qhbits = vld1q_u8_x2(qh); + + uint8x16x4_t q5h; + + int32_t sumi = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32; + const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; + + q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); + q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); + + q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); + q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); + q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); + q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; + sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; +#else + + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++; + + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++; +#endif + } + + sumf += d * sumi - dmin * sumi_mins; + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m1 = _mm256_set1_epi16(1); + const __m256i mone = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * 
restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const float d1 = y[i].d * ggml_fp16_to_fp32(x[i].d[0]); + const float d2 = y[i].d * ggml_fp16_to_fp32(x[i].d[2]); + summs -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) + + ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3])); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)(q5+ 0)); + + const uint64_t * restrict sc = (const uint64_t *)x[i].qh; + const __m128i hbits_0 = _mm_set_epi64x(sc[0] >> 1, sc[0] >> 0); + const __m128i hbits_1 = _mm_set_epi64x(sc[0] >> 3, sc[0] >> 2); + const __m128i hbits_2 = _mm_set_epi64x(sc[0] >> 5, sc[0] >> 4); + const __m128i hbits_3 = _mm_set_epi64x(sc[0] >> 7, sc[0] >> 6); + + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(hbits_1, hbits_0), mone), 4); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(hbits_3, hbits_2), mone), 4); + + const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + + const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); + const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m256i p16_0 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q5_0, q8_0)); + const __m256i p16_1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q5_1, q8_1)); + + acc = _mm256_fmadd_ps(_mm256_set1_ps(d1), _mm256_cvtepi32_ps(p16_0), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d2), _mm256_cvtepi32_ps(p16_1), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#else + + + uint8_t aux8[QK_K]; + int16_t aux16[16]; + float sums [8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + uint8_t * restrict a = aux8; + for (int l = 0; l < 32; ++l) { + a[l+ 0] = q4[l] & 0xF; + a[l+32] = q4[l] >> 4; + } + for (int is = 0; is < 8; ++is) { + uint8_t m = 1 << is; + for (int l = 0; l < 8; ++l) a[8*is + l] += (hm[l] & m ? 16 : 0); + } + + sumf -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) + + ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3])); + + for (int j = 0; j < QK_K/32; ++j) { + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[2*j]); + for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; + q8 += 16; a += 16; + for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l]; + q8 += 16; a += 16; + for (int l = 0; l < 8; ++l) sums[l] += d * (aux16[l] + aux16[8+l]); + } + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} +#endif + #if QK_K == 256 void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { From c6c35366bfa00d59f709fb8a2adc511cbe6e3e32 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 22 Jun 2023 00:39:21 +0300 Subject: [PATCH 07/35] k_quants: WIP super-blocks with 64 weights Q6_K working on CUDA. 
Cannot make it run quite as fast as with super-blocks of 256 weights: 8%
slower on the 4080, 20% slower on the 1660 (but there we fit one less layer
on the GPU because of the larger model size, so some fraction of that 20% is
due to that).
---
 CMakeLists.txt | 16 ++++++------
 ggml-cuda.cu   | 70 +++++++++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 75 insertions(+), 11 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6aae8e1661f3d..ffda74a700bef 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -226,6 +226,14 @@ if (LLAMA_BLAS)
     endif()
 endif()
 
+if (LLAMA_K_QUANTS)
+    set(GGML_SOURCES_EXTRA ${GGML_SOURCES_EXTRA} k_quants.c k_quants.h)
+    add_compile_definitions(GGML_USE_K_QUANTS)
+    if (LLAMA_QKK_64)
+        add_compile_definitions(GGML_QKK_64)
+    endif()
+endif()
+
 if (LLAMA_CUBLAS)
     cmake_minimum_required(VERSION 3.17)
 
@@ -290,14 +298,6 @@ if (LLAMA_METAL)
     )
 endif()
 
-if (LLAMA_K_QUANTS)
-    set(GGML_SOURCES_EXTRA ${GGML_SOURCES_EXTRA} k_quants.c k_quants.h)
-    add_compile_definitions(GGML_USE_K_QUANTS)
-    if (LLAMA_QKK_64)
-        add_compile_definitions(GGML_QKK_64)
-    endif()
-endif()
-
 if (LLAMA_CLBLAST)
     find_package(CLBlast)
     if (CLBlast_FOUND)
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 010682edb703c..ddef2400a3fdc 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -117,7 +117,14 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
 
 //================================= k-quants
 
+#ifdef GGML_QKK_64
+#define QK_K 64
+#define K_SCALE_SIZE 4
+#else
 #define QK_K 256
+#define K_SCALE_SIZE 12
+#endif
+
 
 typedef struct {
     uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
@@ -133,7 +140,7 @@ typedef struct {
     uint8_t scales[3*QK_K/64];
     half d;
 } block_q3_K;
-static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
+//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
 
 typedef struct {
     half d;             // super-block scale for quantized scales
@@ -141,7 +148,7 @@ typedef struct {
     uint8_t scales[3*QK_K/64];       // scales, quantized with 6 bits
     uint8_t qs[QK_K/2];              // 4--bit quants
 } block_q4_K;
-static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
+//static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
 
 typedef struct {
     half d;               // super-block scale for quantized scales
@@ -150,7 +157,7 @@ typedef struct {
     uint8_t qh[QK_K/8];          // quants, high bit
     uint8_t qs[QK_K/2];          // quants, low 4 bits
 } block_q5_K;
-static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+//static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
 
 typedef struct {
     uint8_t ql[QK_K/2];      // quants, lower 4 bits
@@ -482,6 +489,7 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
     const block_q6_K * x = (const block_q6_K *) vx;
 
     const int i = blockIdx.x;
+#if QK_K == 256
 
     // assume 64 threads - this is very slightly better than the one below
     const int tid = threadIdx.x;
@@ -501,6 +509,24 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
     y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
     y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
     y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+#else
+
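+    // each of the 32 threads reconstructs two of the 64 weights: the low and
+    // high nibble of one ql byte, each extended by its 2-bit high part from qh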
+ // assume 32 threads + const int tid = threadIdx.x; + const int ip = tid/16; // 0 or 1 + const int il = tid - 16*ip; // 0...15 + + float * y = yy + i*QK_K + 16*ip + il; + + const float d = x[i].d; + + const uint8_t ql = x[i].ql[16*ip + il]; + const uint8_t qh = x[i].qh[il] >> (2*ip); + const int8_t * sc = x[i].scales; + + y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32); +#endif } static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { @@ -820,6 +846,8 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float const block_q6_K * x = (const block_q6_K *)vx + ib0; +#if QK_K == 256 + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 @@ -874,6 +902,38 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float } +#else + + const int tid = threadIdx.x/4; // 0...7 + const int ix = threadIdx.x%4; // 0...3 + + const int step = tid; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += 4) { + + const float * y = yy + i * QK_K + step; + const uint8_t * ql = x[i].ql + step; + const uint8_t * qh = x[i].qh + step; + const int8_t * s = x[i].scales; + + const float d = x[i+0].d; + + float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[0] & 0x03) << 4)) - 32) + + y[ 8] * s[0] * d * ((int8_t)((ql[ 8] & 0xF) | ((qh[8] & 0x03) << 4)) - 32) + + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[0] & 0x0c) << 2)) - 32) + + y[24] * s[1] * d * ((int8_t)((ql[24] & 0xF) | ((qh[8] & 0x0c) << 2)) - 32) + + y[32] * s[2] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[0] & 0x30) >> 0)) - 32) + + y[40] * s[2] * d * ((int8_t)((ql[ 8] >> 4) | ((qh[8] & 0x30) >> 0)) - 32) + + y[48] * s[3] * d * ((int8_t)((ql[16] >> 4) | ((qh[0] & 0xc0) >> 2)) - 32) + + y[56] * s[3] * d * ((int8_t)((ql[24] >> 4) | ((qh[8] & 0xc0) >> 2)) - 32); + tmp += sum; + + } + +#endif + // sum up partial sums and write back result __syncthreads(); #pragma unroll @@ -1272,7 +1332,11 @@ static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cu static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; +#if QK_K == 256 dequantize_block_q6_K<<>>(vx, y); +#else + dequantize_block_q6_K<<>>(vx, y); +#endif } static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { From 5aae4b8d4ff50fe11065835953b1b80a4b9fdbcc Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 22 Jun 2023 09:40:33 +0300 Subject: [PATCH 08/35] k_quants: WIP super-blocks with 64 weights Q4_K working on CUDA. ~10% slower on GTX-1660, 16% slower on 4080. 
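
For reference, a scalar sketch of what the QK_K == 64 kernel computes per
super-block (illustrative only, not part of the diff; x points to one
block_q4_K and y to its 64 output floats):

    // d[0]/d[1] hold the fp16 scale/min for weights 0..31,
    // d[2]/d[3] the scale/min for weights 32..63
    for (int l = 0; l < 32; ++l) {
        y[l+ 0] = (float)x->d[0] * (x->qs[l] & 0xF) - (float)x->d[1];
        y[l+32] = (float)x->d[2] * (x->qs[l] >>  4) - (float)x->d[3];
    }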
--- ggml-cuda.cu | 91 +++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 65 insertions(+), 26 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index ddef2400a3fdc..2e86226fb5688 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -142,13 +142,21 @@ typedef struct { } block_q3_K; //static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding"); +#ifdef GGML_QKK_64 +typedef struct { + half d[2*QK_K/32]; // super-block scales/mins + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2, "wrong q4_K block size/padding"); +#else typedef struct { half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits uint8_t qs[QK_K/2]; // 4--bit quants } block_q4_K; -//static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); +#endif typedef struct { half d; // super-block scale for quantized scales @@ -420,13 +428,7 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) { const int i = blockIdx.x; - //// assume 64 threads - this is very slightly better than the one below - //const int tid = threadIdx.x; - //const int il = tid/16; - //const int ir = tid%16; - //const int is = 2*il; - //const int n = 2; - +#if QK_K == 256 // assume 32 threads const int tid = threadIdx.x; const int il = tid/8; @@ -450,6 +452,13 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) { y[l + 0] = d1 * (q[l] & 0xF) - m1; y[l +32] = d2 * (q[l] >> 4) - m2; } +#else + const int tid = threadIdx.x; + const uint8_t * q = x[i].qs; + float * y = yy + i*QK_K; + y[tid+ 0] = (float)x[i].d[0] * (q[tid] & 0xF) - (float)x[i].d[1]; + y[tid+32] = (float)x[i].d[2] * (q[tid] >> 4) - (float)x[i].d[3]; +#endif } static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { @@ -681,10 +690,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - const int row = blockIdx.y*blockDim.y + threadIdx.y; if (row > nrows) return; const int num_blocks_per_row = ncols / QK_K; @@ -693,6 +698,13 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 + const block_q4_K * x = (const block_q4_K *)vx + ib0; + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 const int il = tid/step; // 0...3 @@ -709,8 +721,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float uint16_t aux[4]; const uint8_t * sc = (const uint8_t *)aux; - const block_q4_K * x = (const block_q4_K *)vx + ib0; - float tmp = 0; // partial sum for thread in warp for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { @@ -739,6 +749,36 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float tmp += dall * (s.x * sc[0] + s.y * 
sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin; } +#else + const int step = tid * K_QUANTS_PER_ITERATION; + + float tmp = 0; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const uint8_t * q = x[i].qs + step; + const float * y = yy + i*QK_K + step; + const half2 * d = (const half2 *)x[i].d; + float2 df1 = __half22float2(d[0]); + float2 df2 = __half22float2(d[1]); + //float2 sum = {0.f, 0.f}; + //float smin = 0; + //for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + // sum.x += y[j+ 0] * (q[j+ 0] & 0xF) + y[j+16] * (q[j+16] & 0xF); + // sum.y += y[j+32] * (q[j+ 0] >> 4) + y[j+48] * (q[j+16] >> 4); + // smin += df1.y * (y[j] + y[j+16]) + df2.y * (y[j+32] + y[j+48]); + //} + //tmp += df1.x * sum.x + df2.x * sum.y - smin; + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * (df1.x * (q[j+ 0] & 0xF) - df1.y) + + y[j+16] * (df1.x * (q[j+16] & 0xF) - df1.y) + + y[j+32] * (df2.x * (q[j+ 0] >> 4) - df2.y) + + y[j+48] * (df2.x * (q[j+16] >> 4) - df2.y); + } + tmp += sum; + } + +#endif // sum up partial sums and write back result __syncthreads(); @@ -904,14 +944,14 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float #else - const int tid = threadIdx.x/4; // 0...7 - const int ix = threadIdx.x%4; // 0...3 + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3 - const int step = tid; + const int step = tid * K_QUANTS_PER_ITERATION; float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += 4) { + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { const float * y = yy + i * QK_K + step; const uint8_t * ql = x[i].ql + step; @@ -920,14 +960,13 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float const float d = x[i+0].d; - float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[0] & 0x03) << 4)) - 32) - + y[ 8] * s[0] * d * ((int8_t)((ql[ 8] & 0xF) | ((qh[8] & 0x03) << 4)) - 32) - + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[0] & 0x0c) << 2)) - 32) - + y[24] * s[1] * d * ((int8_t)((ql[24] & 0xF) | ((qh[8] & 0x0c) << 2)) - 32) - + y[32] * s[2] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[0] & 0x30) >> 0)) - 32) - + y[40] * s[2] * d * ((int8_t)((ql[ 8] >> 4) | ((qh[8] & 0x30) >> 0)) - 32) - + y[48] * s[3] * d * ((int8_t)((ql[16] >> 4) | ((qh[0] & 0xc0) >> 2)) - 32) - + y[56] * s[3] * d * ((int8_t)((ql[24] >> 4) | ((qh[8] & 0xc0) >> 2)) - 32); + float sum = 0; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32) + + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32) + + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32) + + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32); + } tmp += sum; } From 41e46ec1c208e404499d34cf4e4df94637e897e0 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 22 Jun 2023 11:39:00 +0300 Subject: [PATCH 09/35] k_quants: WIP super-blocks with 64 weights Q2_K working on CUDA. ~3% slower on GTX-1660, 10% slower on 4080. 
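
For reference, a scalar sketch of the QK_K == 64 q2_K dequantization the
kernel implements (illustrative only, not part of the diff; x is one
block_q2_K, y its 64 output floats, d/dmin the converted super-block scales):

    // scales[j] packs a 4-bit scale (low nibble) and a 4-bit min (high
    // nibble) for weights 16*j .. 16*j+15; qs[l] packs four 2-bit quants
    for (int j = 0; j < 4; ++j) {
        const float dl = d    * (x->scales[j] & 0xF);
        const float ml = dmin * (x->scales[j] >>  4);
        for (int l = 0; l < 16; ++l) {
            y[16*j + l] = dl * ((x->qs[l] >> (2*j)) & 3) - ml;
        }
    }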
--- ggml-cuda.cu | 57 ++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2e86226fb5688..a5fe7e572bf01 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -364,13 +364,14 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in static __global__ void dequantize_block_q2_K(const void * vx, float * yy) { const int i = blockIdx.x; + const block_q2_K * x = (const block_q2_K *) vx; + const int tid = threadIdx.x; +#if QK_K == 256 const int n = tid/32; const int l = tid - 32*n; const int is = 8*n + l/16; - const block_q2_K * x = (const block_q2_K *) vx; - const uint8_t q = x[i].qs[32*n + l]; float * y = yy + i*QK_K + 128*n; @@ -380,6 +381,16 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) { y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4); y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); +#else + const int is = tid/16; // 0 or 1 + const int il = tid%16; // 0...15 + const uint8_t q = x[i].qs[il] >> (2*is); + float * y = yy + i*QK_K + 16*is + il; + float dall = x[i].d; + float dmin = x[i].dmin; + y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); + y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4); +#endif } @@ -550,6 +561,9 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float const block_q2_K * x = (const block_q2_K *)vx + ib0; + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 @@ -563,8 +577,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float const int s_offset = 8*im; const int y_offset = 128*im + l0; - float tmp = 0; // partial sum for thread in warp - uint32_t aux[4]; const uint8_t * d = (const uint8_t *)aux; const uint8_t * m = (const uint8_t *)(aux + 2); @@ -600,6 +612,39 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float tmp += dall * sum1 - dmin * sum2; } +#else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 + const int offset = tid * K_QUANTS_PER_ITERATION; + + uint32_t uaux[2]; + const uint8_t * d = (const uint8_t *)uaux; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + offset; + const uint8_t * q = x[i].qs + offset; + const uint32_t * s = (const uint32_t *)x[i].scales; + + uaux[0] = s[0] & 0x0f0f0f0f; + uaux[1] = (s[0] >> 4) & 0x0f0f0f0f; + + const half2 * dh = (const half2 *)&x[i].d; + + const float2 dall = __half22float2(dh[0]); + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + const uint8_t ql = q[l]; + sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3) + + y[l+16] * d[1] * ((ql >> 2) & 3) + + y[l+32] * d[2] * ((ql >> 4) & 3) + + y[l+48] * d[3] * ((ql >> 6) & 3); + sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7]; + } + tmp += dall.x * sum1 - dall.y * sum2; + } +#endif // sum up partial sums and write back result __syncthreads(); @@ -1351,7 +1396,11 @@ static void dequantize_row_q8_0_cuda(const void * vx, float 
* y, const int k, cu static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; +#if QK_K == 256 dequantize_block_q2_K<<>>(vx, y); +#else + dequantize_block_q2_K<<>>(vx, y); +#endif } static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { From 460dd841b1920d2e4740a5e822d1a77c27855666 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 22 Jun 2023 12:51:18 +0300 Subject: [PATCH 10/35] k_quants: WIP super-blocks with 64 weights Q3_K working on CUDA. --- ggml-cuda.cu | 94 +++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 74 insertions(+), 20 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index a5fe7e572bf01..486bc866d10cd 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -125,7 +125,6 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo #define K_SCALE_SIZE 12 #endif - typedef struct { uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits uint8_t qs[QK_K/4]; // quants @@ -135,12 +134,16 @@ typedef struct { static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); typedef struct { - uint8_t hmask[QK_K/8]; - uint8_t qs[QK_K/4]; // nibbles / quants - uint8_t scales[3*QK_K/64]; - half d; + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits +#ifdef GGML_QKK_64 + int8_t scales[K_SCALE_SIZE]; // scales, quantized with 8 bits +#else + uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits +#endif + half d; // super-block scale } block_q3_K; -//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding"); +static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding"); #ifdef GGML_QKK_64 typedef struct { @@ -396,16 +399,17 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) { static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { - int r = threadIdx.x/4; - int i = blockIdx.x; - int tid = r/2; - int is0 = r%2; - int l0 = 16*is0 + 4*(threadIdx.x%4); - int n = tid / 4; - int j = tid - 4*n; - + const int i = blockIdx.x; const block_q3_K * x = (const block_q3_K *) vx; +#if QK_K == 256 + const int r = threadIdx.x/4; + const int tid = r/2; + const int is0 = r%2; + const int l0 = 16*is0 + 4*(threadIdx.x%4); + const int n = tid / 4; + const int j = tid - 4*n; + uint8_t m = 1 << (4*n + j); int is = 8*n + 2*j + is0; int shift = 2*j; @@ -422,6 +426,22 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { const uint8_t * hm = x[i].hmask; for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); +#else + const int tid = threadIdx.x; + const int is = tid/16; // 0 or 1 + const int il = tid%16; // 0...15 + const int im = il/8; // 0...1 + const int in = il%8; // 0...7 + + float * y = yy + i*QK_K + 16*is + il; + + const uint8_t q = x[i].qs[il] >> (2*is); + const uint8_t h = x[i].hmask[in] >> (2*is + im); + const float d = (float)x[i].d; + + y[ 0] = d * x[i].scales[is+0] * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); + y[32] = d * x[i].scales[is+2] * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 
0 : 4)); +#endif } @@ -660,9 +680,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { - const uint16_t kmask1 = 0x0303; - const uint16_t kmask2 = 0x0f0f; - const int row = blockIdx.y*blockDim.y + threadIdx.y; if (row > nrows) return; @@ -671,6 +688,13 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float const block_q3_K * x = (const block_q3_K *)vx + ib0; + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 @@ -690,8 +714,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float const uint16_t s_shift = 4*im; - float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { const float * y = yy + i * QK_K + y_offset; @@ -720,6 +742,34 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float tmp += d * sum; } +#else + + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 + const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14 + const int in = offset/8; // 0 or 1 + const int im = offset%8; // 0...7 + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + offset; + const uint8_t * q = x[i].qs + offset; + const int8_t * s = x[i].scales; + + const float dall = (float)x[i].d; + + float sum = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + const uint8_t hl = x[i].hmask[im+l] >> in; + const uint8_t ql = q[l]; + sum += y[l+ 0] * dall * s[0] * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4)) + + y[l+16] * dall * s[1] * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4)) + + y[l+32] * dall * s[2] * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4)) + + y[l+48] * dall * s[3] * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4)); + } + tmp += sum; + } +#endif // sum up partial sums and write back result __syncthreads(); @@ -728,7 +778,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); } - if (tid == 0) { + if (threadIdx.x == 0) { dst[row] = tmp; } } @@ -1405,7 +1455,11 @@ static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cu static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; +#if QK_K == 256 dequantize_block_q3_K<<>>(vx, y); +#else + dequantize_block_q3_K<<>>(vx, y); +#endif } static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { From 3bd9ae79d8679e7047466c875dfc3c69e03a42e3 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 22 Jun 2023 13:31:54 +0300 Subject: [PATCH 11/35] k_quants: WIP super-blocks with 64 weights Q5_K working on CUDA, and with this CUDA is done. 
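
For reference, a scalar sketch of the QK_K == 64 q5_K dequantization
(illustrative only, not part of the diff; x is one block_q5_K, y its 64
output floats):

    // d[0]/d[1] and d[2]/d[3] are fp16 scale/min pairs for the two 32-weight
    // halves; bit l/8 of qh[l%8] is the 5th bit of weight l, and bit l/8 + 4
    // that of weight l + 32
    for (int l = 0; l < 32; ++l) {
        const uint8_t h = x->qh[l%8] >> (l/8);
        y[l+ 0] = (float)x->d[0] * ((x->qs[l] & 0xF) | (((h >> 0) & 1) << 4)) - (float)x->d[1];
        y[l+32] = (float)x->d[2] * ((x->qs[l] >>  4) | (((h >> 4) & 1) << 4)) - (float)x->d[3];
    }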
--- ggml-cuda.cu | 86 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 64 insertions(+), 22 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 486bc866d10cd..ec66c884a7001 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -161,14 +161,23 @@ typedef struct { static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); #endif +#ifdef GGML_QKK_64 typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits + half d[2*QK_K/32]; // super-block scales/mins uint8_t qh[QK_K/8]; // quants, high bit uint8_t qs[QK_K/2]; // quants, low 4 bits } block_q5_K; -//static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +static_assert(sizeof(block_q5_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#else +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#endif typedef struct { uint8_t ql[QK_K/2]; // quants, lower 4 bits @@ -445,6 +454,7 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { } +#if QK_K == 256 static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { if (j < 4) { d = q[j] & 63; m = q[j + 4] & 63; @@ -453,6 +463,7 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); } } +#endif static __global__ void dequantize_block_q4_K(const void * vx, float * yy) { const block_q4_K * x = (const block_q4_K *) vx; @@ -497,6 +508,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { const int i = blockIdx.x; +#if QK_K == 256 // assume 64 threads - this is very slightly better than the one below const int tid = threadIdx.x; const int il = tid/16; // il is in 0...3 @@ -523,6 +535,16 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { hm <<= 1; y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; +#else + const int tid = threadIdx.x; + const uint8_t q = x[i].qs[tid]; + const int im = tid/8; // 0...3 + const int in = tid%8; // 0...7 + const uint8_t h = x[i].qh[in] >> im; + float * y = yy + i*QK_K + tid; + y[ 0] = (float)x[i].d[0] * ((q & 0xF) + ((h >> 0) & 1 ? 16 : 0)) - (float)x[i].d[1]; + y[32] = (float)x[i].d[2] * ((q >> 4) + ((h >> 4) & 1 ? 
16 : 0)) - (float)x[i].d[3]; +#endif } static __global__ void dequantize_block_q6_K(const void * vx, float * yy) { @@ -855,14 +877,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float const half2 * d = (const half2 *)x[i].d; float2 df1 = __half22float2(d[0]); float2 df2 = __half22float2(d[1]); - //float2 sum = {0.f, 0.f}; - //float smin = 0; - //for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { - // sum.x += y[j+ 0] * (q[j+ 0] & 0xF) + y[j+16] * (q[j+16] & 0xF); - // sum.y += y[j+32] * (q[j+ 0] >> 4) + y[j+48] * (q[j+16] >> 4); - // smin += df1.y * (y[j] + y[j+16]) + df2.y * (y[j+32] + y[j+48]); - //} - //tmp += df1.x * sum.x + df2.x * sum.y - smin; float sum = 0.f; for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { sum += y[j+ 0] * (df1.x * (q[j+ 0] & 0xF) - df1.y) @@ -889,15 +903,19 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) { - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - //const int row = blockIdx.x*blockDim.y + threadIdx.y; const int row = blockIdx.x; const int num_blocks_per_row = ncols / QK_K; const int ib0 = row*num_blocks_per_row; + const block_q5_K * x = (const block_q5_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + const int tid = threadIdx.x/2; // 0...15 const int ix = threadIdx.x%2; @@ -918,10 +936,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float uint16_t aux[4]; const uint8_t * sc = (const uint8_t *)aux; - const block_q5_K * x = (const block_q5_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += 2) { const uint8_t * ql1 = x[i].qs + q_offset; @@ -954,8 +968,32 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7]; } tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin; + } +#else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); + const int step = tid * K_QUANTS_PER_ITERATION; + const int im = step/8; + const int in = step%8; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const uint8_t * q = x[i].qs + step; + const float * y = yy + i*QK_K + step; + const half2 * d = (const half2 *)x[i].d; + float2 df1 = __half22float2(d[0]); + float2 df2 = __half22float2(d[1]); + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + const uint8_t h = x[i].qh[in+j] >> im; + sum += y[j+ 0] * (df1.x * ((q[j+ 0] & 0xF) + (((h >> 0) & 1) << 4)) - df1.y) + + y[j+16] * (df1.x * ((q[j+16] & 0xF) + (((h >> 2) & 1) << 4)) - df1.y) + + y[j+32] * (df2.x * ((q[j+ 0] >> 4) + (((h >> 4) & 1) << 4)) - df2.y) + + y[j+48] * (df2.x * ((q[j+16] >> 4) + (((h >> 6) & 1) << 4)) - df2.y); + } + tmp += sum; } +#endif // sum up partial sums and write back result __syncthreads(); @@ -964,7 +1002,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); } - if (tid == 0) { + if (threadIdx.x == 0) { dst[row] = tmp; } } @@ -1469,7 +1507,11 @@ static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int 
k, cu static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; +#if QK_K == 256 dequantize_block_q5_K<<>>(vx, y); +#else + dequantize_block_q5_K<<>>(vx, y); +#endif } static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { From 03f30c8eca735ca3656a2b53a84da53f688b2e6e Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 22 Jun 2023 15:10:41 +0300 Subject: [PATCH 12/35] k_quants: WIP super-blocks with 64 weights Q6_K working on ARM_NEON --- k_quants.c | 141 ++++++++++++++--------------------------------------- 1 file changed, 37 insertions(+), 104 deletions(-) diff --git a/k_quants.c b/k_quants.c index 86c224fc97345..a3898c603218d 100644 --- a/k_quants.c +++ b/k_quants.c @@ -261,6 +261,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t return scale; } +#if QK_K == 256 static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) { if (j < 4) { *d = q[j] & 63; *m = q[j + 4] & 63; @@ -269,6 +270,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); } } +#endif //========================- 2-bit (de)-quantization @@ -1895,7 +1897,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri const int nb = n / QK_K; -#ifdef __ARM_NEON +#ifdef z__ARM_NEON uint32_t aux[3]; uint32_t utmp[4]; @@ -2361,7 +2363,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri const int nb = n / QK_K; -#ifdef __ARM_NEON +#ifdef z__ARM_NEON const uint8x16_t m4b = vdupq_n_u8(0xf); #ifdef __ARM_FEATURE_DOTPROD @@ -2779,7 +2781,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri const int nb = n / QK_K; -#ifdef __ARM_NEON +#ifdef z__ARM_NEON const uint8x16_t m4b = vdupq_n_u8(0xf); const int32x4_t mzero = vdupq_n_s32(0); @@ -3243,7 +3245,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri const uint8x16_t m4b = vdupq_n_u8(0xF); const int32x4_t vzero = vdupq_n_s32(0); - //const int8x16_t m32s = vdupq_n_s8(32); + const int8x16_t m32s = vdupq_n_s8(32); const uint8x16_t mone = vdupq_n_u8(3); @@ -3252,7 +3254,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri for (int i = 0; i < nb; ++i) { - const float d_all = ggml_fp16_to_fp32(x[i].d); + const float d_all = (float)x[i].d; const uint8_t * restrict q6 = x[i].ql; const uint8_t * restrict qh = x[i].qh; @@ -3260,116 +3262,47 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri const int8_t * restrict scale = x[i].scales; - const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); - const int8x16_t scales = vld1q_s8(scale); - const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}; - - const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), - vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), - vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), - vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); - int32_t isum_mins = vaddvq_s32(prod); - int32_t isum = 0; - for (int j = 0; j < QK_K/128; ++j) { - - uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32; - uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64; - int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; - - 
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); - q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); - uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); - q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 2); - q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - - //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); - //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); - //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s); - //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s); - q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); - q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); - q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); - q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); - -#if defined(__ARM_FEATURE_DOTPROD) - - isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; - scale += 4; - -#else - - int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; - scale += 2; - - int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; - scale += 2; -#endif + uint8x16_t qhbits = vld1q_u8(qh); + uint8x16x2_t q6bits = vld1q_u8_x2(q6); + int8x16x4_t q8bytes = vld1q_s8_x4(q8); - q8bytes = vld1q_s8_x4(q8); q8 += 64; + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4); + uint8x16_t shifted = vshrq_n_u8(qhbits, 2); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits, 4); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits, 6); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[0], 4); - q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 4); - q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[0], 6); - q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 6); - q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - - //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s); - //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s); - //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), 
q6h.val[2])), m32s); - //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s); - q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); - q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); - q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); - q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); + q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); + q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); + q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s); + q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s); #if defined(__ARM_FEATURE_DOTPROD) - isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; - scale += 4; - - //for (int l = 0; l < 4; ++l) { - // const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]); - // isum += vaddvq_s32(p) * *scale++; - //} + isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; #else - p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; - scale += 2; - p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; - scale += 2; + int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; + + int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; #endif - } - //sum += isum * d_all * y[i].d; - sum += d_all * y[i].d * (isum - 32 * isum_mins); + sum += isum * d_all * y[i].d; } *s = sum; From 
cda47a6b2f273b54404468f494c12cbb7af34f01 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 22 Jun 2023 16:31:51 +0300 Subject: [PATCH 13/35] k_quants: WIP super-blocks with 64 weights Q4_K working on ARM_NEON, but quite a bit slower than 256 weights --- k_quants.c | 97 +++++++++++++++++++++--------------------------------- 1 file changed, 38 insertions(+), 59 deletions(-) diff --git a/k_quants.c b/k_quants.c index a3898c603218d..a2e2793f4bd54 100644 --- a/k_quants.c +++ b/k_quants.c @@ -2363,92 +2363,71 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri const int nb = n / QK_K; -#ifdef z__ARM_NEON +#ifdef __ARM_NEON const uint8x16_t m4b = vdupq_n_u8(0xf); + #ifdef __ARM_FEATURE_DOTPROD const int32x4_t mzero = vdupq_n_s32(0); #endif - int8x16x2_t q4bytes; - int8x16x2_t q8bytes; - float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); - - const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); - - memcpy(utmp, x[i].scales, 12); - - const uint32x2_t mins8 = {utmp[1] & kmask1, ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4)}; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[0] &= kmask1; + int8x16x2_t q4bytes; + int8x16x4_t q8bytes; - const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); - const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), - vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); - sumf -= dmin * vaddvq_s32(prod); + float sum_mins = 0.f; - const uint8_t * scales = (const uint8_t *)utmp; + for (int i = 0; i < nb; ++i) { const uint8_t * restrict q4 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - //int32x4_t isum = mzero; - - int32_t sumi1 = 0; - int32_t sumi2 = 0; - - for (int j = 0; j < QK_K/64; ++j) { + const float32x4_t dsc = vcvt_f32_f16(vld1_f16(x[i].d)); + float summ = vgetq_lane_f32(dsc, 1) * (y[i].bsums[0] + y[i].bsums[1]) + + vgetq_lane_f32(dsc, 3) * (y[i].bsums[2] + y[i].bsums[3]); + sum_mins += y[i].d * summ; - const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32; + const uint8x16x2_t q4bits = vld1q_u8_x2(q4); #ifdef __ARM_FEATURE_DOTPROD - q8bytes = vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + q8bytes = vld1q_s8_x4(q8); + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); - sumi1 += vaddvq_s32(p1) * scales[2*j+0]; + const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + const float sumf1 = vaddvq_s32(p1) * vgetq_lane_f32(dsc, 0); - q8bytes = vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]); + const float sumf2 = 
vaddvq_s32(p2) * vgetq_lane_f32(dsc, 2); - sumi2 += vaddvq_s32(p2) * scales[2*j+1]; #else - q8bytes = vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0]; - - q8bytes = vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1]; + q8bytes = vld1q_s8_x4(q8); + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + float sumf1 = vaddvq_s16(vaddq_s16(p0, p1)) * vgetq_lane_f32(dsc, 0); + + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3]))); + float sumf2 = vaddvq_s16(vaddq_s16(p2, p3)) * vgetq_lane_f32(dsc, 2); #endif - } - - sumf += d * (sumi1 + sumi2); + sumf += y[i].d * (sumf1 + sumf2); } - *s = sumf; + *s = sumf - sum_mins; #elif defined __AVX2__ From 80c75fe821c417f15212592c862b0deca1e1e11c Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 22 Jun 2023 18:16:08 +0300 Subject: [PATCH 14/35] k_quants: WIP super-blocks with 64 weights Q2_K working on ARM_NEON, but quite a bit slower than 256 weights --- k_quants.c | 96 +++++++++++++++++++++--------------------------------- 1 file changed, 38 insertions(+), 58 deletions(-) diff --git a/k_quants.c b/k_quants.c index a2e2793f4bd54..f3b49dd01d4fc 100644 --- a/k_quants.c +++ b/k_quants.c @@ -1417,81 +1417,61 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri #ifdef __ARM_NEON const uint8x16_t m3 = vdupq_n_u8(0x3); - const uint8x16_t m4 = vdupq_n_u8(0xF); const int32x4_t vzero = vdupq_n_s32(0); - int8x16x2_t q2bytes; - uint8_t aux[16]; + int8x16x4_t q2bytes; + + uint32_t aux32[2]; + const uint8_t * scales = (const uint8_t *)aux32; float sum = 0; for (int i = 0; i < nb; ++i) { - const float d = y[i].d * 
ggml_fp16_to_fp32(x[i].d); - const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + const float d = y[i].d * (float)x[i].d; + const float dmin = -y[i].d * (float)x[i].dmin; const uint8_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; - const uint8_t * restrict sc = x[i].scales; - - const uint8x16_t mins_and_scales = vld1q_u8(sc); - const uint8x16_t scales = vandq_u8(mins_and_scales, m4); - vst1q_u8(aux, scales); - - const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); - const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); - const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}; - const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), - vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); - const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), - vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1]))); - sum += dmin * vaddvq_s32(vaddq_s32(s0, s1)); - - int isum = 0; - int is = 0; - -// We use this macro instead of a function call because for some reason -// the code runs 2-3% slower, even if the function is declared inline -#if defined(__ARM_FEATURE_DOTPROD) -#define MULTIPLY_ACCUM_WITH_SCALE(index)\ - isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ - isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; -#else -#define MULTIPLY_ACCUM_WITH_SCALE(index)\ - {\ - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\ - vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\ - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\ - vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\ - isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\ - } -#endif + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; -#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ - q8bytes = vld1q_s8_x2(q8); q8 += 32;\ - q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ - MULTIPLY_ACCUM_WITH_SCALE((index)); + aux32[0] = sc[0] & 0x0f0f0f0f; + aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; + sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); - for (int j = 0; j < QK_K/128; ++j) { + int isum1 = 0, isum2 = 0; - const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32; + const uint8x16_t q2bits = vld1q_u8(q2); - int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32; - q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); - MULTIPLY_ACCUM_WITH_SCALE(0); + const int8x16x4_t q8bytes = vld1q_s8_x4(q8); - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); - - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3)); + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3)); + q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3)); + q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3)); - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); - - is += 8; - } - sum += d * isum; +#if defined(__ARM_FEATURE_DOTPROD) + isum1 += 
vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0]; + isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1]; + isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2]; + isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3]; +#else + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum1 += vaddvq_s16(p1) * scales[0]; + isum2 += vaddvq_s16(p2) * scales[1]; + + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum1 += vaddvq_s16(p3) * scales[2]; + isum2 += vaddvq_s16(p4) * scales[3]; +#endif + sum += d * (isum1 + isum2); } From 2b2a13c4f9d3570b039ccde3ed75cf3c64f5be7e Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 22 Jun 2023 19:41:28 +0300 Subject: [PATCH 15/35] k_quants: WIP super-blocks with 64 weights Q3_K working on ARM_NEON, but quite a bit slower than 256 weights. --- k_quants.c | 122 +++++++++++++++-------------------------------------- 1 file changed, 34 insertions(+), 88 deletions(-) diff --git a/k_quants.c b/k_quants.c index f3b49dd01d4fc..023dec1fcf9be 100644 --- a/k_quants.c +++ b/k_quants.c @@ -1877,21 +1877,14 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri const int nb = n / QK_K; -#ifdef z__ARM_NEON - - uint32_t aux[3]; - uint32_t utmp[4]; +#ifdef __ARM_NEON - const uint8x16_t m3b = vdupq_n_u8(0x3); #ifdef __ARM_FEATURE_DOTPROD const int32x4_t vzero = vdupq_n_s32(0); #endif - const uint8x16_t m0 = vdupq_n_u8(1); - const uint8x16_t m1 = vshlq_n_u8(m0, 1); - const uint8x16_t m2 = vshlq_n_u8(m0, 2); - const uint8x16_t m3 = vshlq_n_u8(m0, 3); - const int8_t m32 = 32; + const uint8x16_t m3b = vdupq_n_u8(0x3); + const uint8x16_t m0 = vdupq_n_u8(1); int8x16x4_t q3bytes; @@ -1899,96 +1892,49 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri for (int i = 0; i < nb; ++i) { - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict qh = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; - - uint8x16x2_t qhbits = vld1q_u8_x2(qh); - uint8x16x4_t q3h; - int32_t isum = 0; - - // Set up scales - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + const uint8x8_t hbits = vld1_u8(x[i].hmask); + const uint8x16_t q3bits = vld1q_u8(x[i].qs); + const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs); - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= m32; + const int8_t * restrict scales = x[i].scales; + int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[1] * y[i].bsums[1] + scales[2] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); - for (int j = 0; j < QK_K/128; ++j) { + const float d = 
y[i].d * (float)x[i].d; - const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32; - const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64; - const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64; + q3h.val[0] = vandq_u8(m0, vcombine_u8(hbits, vshr_n_u8(hbits, 1))); + q3h.val[1] = vandq_u8(m0, vcombine_u8(vshr_n_u8(hbits, 2), vshr_n_u8(hbits, 3))); + q3h.val[2] = vandq_u8(m0, vcombine_u8(vshr_n_u8(hbits, 4), vshr_n_u8(hbits, 5))); + q3h.val[3] = vandq_u8(m0, vcombine_u8(vshr_n_u8(hbits, 6), vshr_n_u8(hbits, 7))); - q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); - q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); - q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); - q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); + q3h.val[0] = vshlq_n_u8(q3h.val[0], 2); + q3h.val[1] = vshlq_n_u8(q3h.val[1], 2); + q3h.val[2] = vshlq_n_u8(q3h.val[2], 2); + q3h.val[3] = vshlq_n_u8(q3h.val[3], 2); - q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); - q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); - q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); - q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + q3bytes.val[0] = vreinterpretq_s8_u8(vaddq_u8(vandq_u8(q3bits, m3b), q3h.val[0])); + q3bytes.val[1] = vreinterpretq_s8_u8(vaddq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1])); + q3bytes.val[2] = vreinterpretq_s8_u8(vaddq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2])); + q3bytes.val[3] = vreinterpretq_s8_u8(vaddq_u8(vandq_u8(vshrq_n_u8(q3bits, 6), m3b), q3h.val[3])); #if defined(__ARM_FEATURE_DOTPROD) - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[1]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[2]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3]; #else - int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])), - vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0]))); - int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])), - vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1]))); - int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])), - vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2]))); - int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])), - vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = 
vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[1] + vaddvq_s16(p2) * scales[2] + vaddvq_s16(p3) * scales[3]; #endif - scale += 4; - q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); - q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); - q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); - q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); - - q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); - q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); - q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); - q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); - -#if defined(__ARM_FEATURE_DOTPROD) - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; -#else - p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])), - vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0]))); - p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])), - vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1]))); - p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])), - vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2]))); - p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])), - vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; -#endif - scale += 4; - - if (j == 0) { - qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); - qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); - } - - } sum += d * isum; } From 9d27d8d0eaa413d16f08cdbbbb158cccc2b66704 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 22 Jun 2023 20:04:18 +0300 Subject: [PATCH 16/35] k_quants: WIP super-blocks with 64 weights Q5_K working on ARM_NEON, but quite a bit slower than 256 weights. With that, we have full support for ARM_NEON, although performance is not quite there. 
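For reference, a scalar sketch of what the vectorized loop below computes for the 64-weight super-blocks. The struct and helper are illustrative stand-ins only (the real block_q5_K stores the four scale/min values as fp16; here they are plain floats), not part of the patch:

    #include <stdint.h>

    // Simplified QK_K = 64 block_q5_K layout: d = {scale_1, min_1,
    // scale_2, min_2}; bit g of qh[l] is the high (5th) bit of weight
    // 8*g + l; qs packs two low 4-bit quants per byte.
    typedef struct {
        float   d[4];
        uint8_t qh[8];
        uint8_t qs[32];
    } sketch_block_q5_K;

    // Dot product of one super-block with 64 q8 values:
    // weight_j = scale * ((low nibble) | (high bit << 4)) - min,
    // where d8 is the Q8_K block scale y[i].d.
    static float q5_K_64_block_dot(const sketch_block_q5_K * x,
                                   const int8_t q8[64], float d8) {
        float sum = 0.f;
        for (int j = 0; j < 64; ++j) {
            const int g  = j / 8;              // 8 groups of 8 weights
            const int l  = j % 8;
            const int hi = (x->qh[l] >> g) & 1;
            const int lo = g < 4 ? (x->qs[8*g + l] & 0xF)
                                 : (x->qs[8*(g - 4) + l] >> 4);
            const float sc = g < 4 ? x->d[0] : x->d[2];
            const float mn = g < 4 ? x->d[1] : x->d[3];
            sum += (sc * (lo | (hi << 4)) - mn) * (float)q8[j];
        }
        return d8 * sum;
    }

The NEON code reaches the same result without touching the mins per weight: both mins are folded into the precomputed Q8_K block sums (the bsums term at the top of the loop), so the inner 8-bit multiplies only have to handle the scale part.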
--- k_quants.c | 83 ++++++++++++++++++------------------------------------ 1 file changed, 27 insertions(+), 56 deletions(-) diff --git a/k_quants.c b/k_quants.c index 023dec1fcf9be..373a67e6dfc61 100644 --- a/k_quants.c +++ b/k_quants.c @@ -2686,87 +2686,58 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri const int nb = n / QK_K; -#ifdef z__ARM_NEON +#ifdef __ARM_NEON const uint8x16_t m4b = vdupq_n_u8(0xf); const int32x4_t mzero = vdupq_n_s32(0); const uint8x16_t mone = vdupq_n_u8(1); - const uint8x16_t mtwo = vdupq_n_u8(2); int8x16x4_t q5bytes; + uint8x16x4_t q5h; float sumf = 0; for (int i = 0; i < nb; ++i) { - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); - - const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); - const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); - const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), - vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); - int32_t sumi_mins = vaddvq_s32(prod); - - const uint8_t * scales = (const uint8_t *)utmp; + sumf -= y[i].d * ((float)x[i].d[1] * (y[i].bsums[0] + y[i].bsums[1]) + (float)x[i].d[3] * (y[i].bsums[2] + y[i].bsums[3])); const uint8_t * restrict q5 = x[i].qs; const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; - uint8x16x2_t qhbits = vld1q_u8_x2(qh); - - uint8x16x4_t q5h; - - int32_t sumi = 0; + const uint8x8_t qhbits = vld1_u8(qh); - for (int j = 0; j < QK_K/64; ++j) { - - const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32; - const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; + const uint8x16x2_t q5bits = vld1q_u8_x2(q5); + const int8x16x4_t q8bytes = vld1q_s8_x4(q8); - q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); - q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); - q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); - q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); - qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); - qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); + q5h.val[0] = vandq_u8(mone, vcombine_u8(qhbits, vshr_n_u8(qhbits, 1))); + q5h.val[1] = vandq_u8(mone, vcombine_u8(vshr_n_u8(qhbits, 2), vshr_n_u8(qhbits, 3))); + q5h.val[2] = vandq_u8(mone, vcombine_u8(vshr_n_u8(qhbits, 4), vshr_n_u8(qhbits, 5))); + q5h.val[3] = vandq_u8(mone, vcombine_u8(vshr_n_u8(qhbits, 6), vshr_n_u8(qhbits, 7))); - q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); - q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); - q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); - q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); + q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), vshlq_n_u8(q5h.val[0], 4))); + q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), vshlq_n_u8(q5h.val[1], 4))); + q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), vshlq_n_u8(q5h.val[2], 4))); + q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), 
vshlq_n_u8(q5h.val[3], 4))); #if defined(__ARM_FEATURE_DOTPROD) - sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; - sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; + sumf += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * y[i].d * (float)x[i].d[0]; + sumf += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * y[i].d * (float)x[i].d[2]; #else - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++; - - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++; + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + sumf += vaddvq_s16(vaddq_s16(p0, p1)) * y[i].d * (float)x[i].d[0]; + + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + sumf += vaddvq_s16(vaddq_s16(p2, p3)) * y[i].d * (float)x[i].d[2]; #endif - } - - sumf += d * sumi - dmin * sumi_mins; } From 2ff543c1476cfb1edbefdee3725f5de603934cf3 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 22 Jun 2023 23:15:41 +0300 Subject: [PATCH 17/35] k_quants: WIP super-blocks with 64 weights Slightly more efficient Q3_K and Q5_K --- k_quants.c | 52 ++++++++++++++++++++++++---------------------------- 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/k_quants.c b/k_quants.c index 373a67e6dfc61..3c39285902693 100644 --- a/k_quants.c +++ b/k_quants.c @@ -1884,7 +1884,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri #endif const uint8x16_t m3b = vdupq_n_u8(0x3); - const uint8x16_t m0 = vdupq_n_u8(1); + const uint8x16_t mh = vdupq_n_u8(4); int8x16x4_t q3bytes; @@ -1903,20 +1903,15 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri const float d = y[i].d * (float)x[i].d; - q3h.val[0] = vandq_u8(m0, vcombine_u8(hbits, vshr_n_u8(hbits, 1))); - q3h.val[1] = vandq_u8(m0, vcombine_u8(vshr_n_u8(hbits, 2), vshr_n_u8(hbits, 3))); - q3h.val[2] = vandq_u8(m0, vcombine_u8(vshr_n_u8(hbits, 4), vshr_n_u8(hbits, 5))); - q3h.val[3] = vandq_u8(m0, vcombine_u8(vshr_n_u8(hbits, 6), vshr_n_u8(hbits, 7))); + q3h.val[0] = vandq_u8(mh, 
vcombine_u8(vshl_n_u8(hbits, 2), vshl_n_u8(hbits, 1))); + q3h.val[1] = vandq_u8(mh, vcombine_u8(hbits, vshr_n_u8(hbits, 1))); + q3h.val[2] = vandq_u8(mh, vcombine_u8(vshr_n_u8(hbits, 2), vshr_n_u8(hbits, 3))); + q3h.val[3] = vandq_u8(mh, vcombine_u8(vshr_n_u8(hbits, 4), vshr_n_u8(hbits, 5))); - q3h.val[0] = vshlq_n_u8(q3h.val[0], 2); - q3h.val[1] = vshlq_n_u8(q3h.val[1], 2); - q3h.val[2] = vshlq_n_u8(q3h.val[2], 2); - q3h.val[3] = vshlq_n_u8(q3h.val[3], 2); - - q3bytes.val[0] = vreinterpretq_s8_u8(vaddq_u8(vandq_u8(q3bits, m3b), q3h.val[0])); - q3bytes.val[1] = vreinterpretq_s8_u8(vaddq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1])); - q3bytes.val[2] = vreinterpretq_s8_u8(vaddq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2])); - q3bytes.val[3] = vreinterpretq_s8_u8(vaddq_u8(vandq_u8(vshrq_n_u8(q3bits, 6), m3b), q3h.val[3])); + q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0])); + q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1])); + q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2])); + q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3])); #if defined(__ARM_FEATURE_DOTPROD) isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0]; @@ -2690,7 +2685,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri const uint8x16_t m4b = vdupq_n_u8(0xf); const int32x4_t mzero = vdupq_n_s32(0); - const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mh = vdupq_n_u8(16); int8x16x4_t q5bytes; uint8x16x4_t q5h; @@ -2699,7 +2694,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri for (int i = 0; i < nb; ++i) { - sumf -= y[i].d * ((float)x[i].d[1] * (y[i].bsums[0] + y[i].bsums[1]) + (float)x[i].d[3] * (y[i].bsums[2] + y[i].bsums[3])); + float sumb = -((float)x[i].d[1] * (y[i].bsums[0] + y[i].bsums[1]) + (float)x[i].d[3] * (y[i].bsums[2] + y[i].bsums[3])); const uint8_t * restrict q5 = x[i].qs; const uint8_t * restrict qh = x[i].qh; @@ -2710,34 +2705,35 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri const uint8x16x2_t q5bits = vld1q_u8_x2(q5); const int8x16x4_t q8bytes = vld1q_s8_x4(q8); - q5h.val[0] = vandq_u8(mone, vcombine_u8(qhbits, vshr_n_u8(qhbits, 1))); - q5h.val[1] = vandq_u8(mone, vcombine_u8(vshr_n_u8(qhbits, 2), vshr_n_u8(qhbits, 3))); - q5h.val[2] = vandq_u8(mone, vcombine_u8(vshr_n_u8(qhbits, 4), vshr_n_u8(qhbits, 5))); - q5h.val[3] = vandq_u8(mone, vcombine_u8(vshr_n_u8(qhbits, 6), vshr_n_u8(qhbits, 7))); + q5h.val[0] = vandq_u8(mh, vcombine_u8(vshl_n_u8(qhbits, 4), vshl_n_u8(qhbits, 3))); + q5h.val[1] = vandq_u8(mh, vcombine_u8(vshl_n_u8(qhbits, 2), vshl_n_u8(qhbits, 1))); + q5h.val[2] = vandq_u8(mh, vcombine_u8(qhbits, vshr_n_u8(qhbits, 1))); + q5h.val[3] = vandq_u8(mh, vcombine_u8(vshr_n_u8(qhbits, 2), vshr_n_u8(qhbits, 3))); - q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), vshlq_n_u8(q5h.val[0], 4))); - q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), vshlq_n_u8(q5h.val[1], 4))); - q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), vshlq_n_u8(q5h.val[2], 4))); - q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), vshlq_n_u8(q5h.val[3], 4))); + q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); + q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], 
m4b), q5h.val[1])); + q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); + q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); #if defined(__ARM_FEATURE_DOTPROD) - sumf += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * y[i].d * (float)x[i].d[0]; - sumf += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * y[i].d * (float)x[i].d[2]; + sumb += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * (float)x[i].d[0]; + sumb += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * (float)x[i].d[2]; #else const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - sumf += vaddvq_s16(vaddq_s16(p0, p1)) * y[i].d * (float)x[i].d[0]; + sumb += vaddvq_s16(vaddq_s16(p0, p1)) * (float)x[i].d[0]; const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - sumf += vaddvq_s16(vaddq_s16(p2, p3)) * y[i].d * (float)x[i].d[2]; + sumb += vaddvq_s16(vaddq_s16(p2, p3)) * (float)x[i].d[2]; #endif + sumf += y[i].d * sumb; } From d92c5a9e2978bcd18cde42875f92abe8b94cfd7b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 23 Jun 2023 08:43:04 +0300 Subject: [PATCH 18/35] k_quants: WIP super-blocks with 64 weights Another small improvement for Q3_K and Q5_K on ARM_NEON --- k_quants.c | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/k_quants.c b/k_quants.c index 3c39285902693..b43286f4c74a1 100644 --- a/k_quants.c +++ b/k_quants.c @@ -1903,10 +1903,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri const float d = y[i].d * (float)x[i].d; - q3h.val[0] = vandq_u8(mh, vcombine_u8(vshl_n_u8(hbits, 2), vshl_n_u8(hbits, 1))); - q3h.val[1] = vandq_u8(mh, vcombine_u8(hbits, vshr_n_u8(hbits, 1))); - q3h.val[2] = vandq_u8(mh, vcombine_u8(vshr_n_u8(hbits, 2), vshr_n_u8(hbits, 3))); - q3h.val[3] = vandq_u8(mh, vcombine_u8(vshr_n_u8(hbits, 4), vshr_n_u8(hbits, 5))); + const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1)); + q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2)); + q3h.val[1] = vandq_u8(mh, htmp); + q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2)); + q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4)); q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0])); q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1])); @@ -2705,10 +2706,11 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri const uint8x16x2_t q5bits = vld1q_u8_x2(q5); const int8x16x4_t q8bytes = vld1q_s8_x4(q8); - q5h.val[0] = vandq_u8(mh, vcombine_u8(vshl_n_u8(qhbits, 4), vshl_n_u8(qhbits, 3))); - q5h.val[1] = vandq_u8(mh, vcombine_u8(vshl_n_u8(qhbits, 2), vshl_n_u8(qhbits, 1))); - q5h.val[2] = vandq_u8(mh, vcombine_u8(qhbits, 
vshr_n_u8(qhbits, 1))); - q5h.val[3] = vandq_u8(mh, vcombine_u8(vshr_n_u8(qhbits, 2), vshr_n_u8(qhbits, 3))); + const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1)); + q5h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 4)); + q5h.val[1] = vandq_u8(mh, vshlq_n_u8(htmp, 2)); + q5h.val[2] = vandq_u8(mh, htmp); + q5h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 2)); q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); From fae24afd011c5634f0c5698225968796dfd21319 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 23 Jun 2023 10:11:35 +0300 Subject: [PATCH 19/35] k_quants: WIP super-blocks with 64 weights Yet another speedup for Q5_K on ARM_NEON. We are now within 10% of the QK_K = 256 version. --- k_quants.c | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/k_quants.c b/k_quants.c index b43286f4c74a1..b05ef45b2e68e 100644 --- a/k_quants.c +++ b/k_quants.c @@ -2693,9 +2693,19 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri float sumf = 0; + float32x4_t acc1 = vdupq_n_f32(0.f); + float32x4_t acc2 = vdupq_n_f32(0.f); + for (int i = 0; i < nb; ++i) { - float sumb = -((float)x[i].d[1] * (y[i].bsums[0] + y[i].bsums[1]) + (float)x[i].d[3] * (y[i].bsums[2] + y[i].bsums[3])); + const float16x4_t s16 = vld1_f16(x[i].d); + const float32x4_t s32 = vmulq_n_f32(vcvt_f32_f16(s16), y[i].d); + //const int16x4_t bi16 = vld1_s16(y[i].bsums); + //const int32x4_t bi32 = vmovl_s16(vpadd_s16(bi16, bi16)); + //const float32x4_t bf32 = vcvtq_f32_s32(bi32); + //sumf -= (vgetq_lane_f32(s32, 1) * vgetq_lane_f32(bf32, 0) + vgetq_lane_f32(s32, 3) * vgetq_lane_f32(bf32, 1)); + // The above is slightly slower than just this: + sumf -= (vgetq_lane_f32(s32, 1) * (y[i].bsums[0] + y[i].bsums[1]) + vgetq_lane_f32(s32, 3) * (y[i].bsums[2] + y[i].bsums[3])); const uint8_t * restrict q5 = x[i].qs; const uint8_t * restrict qh = x[i].qh; @@ -2719,27 +2729,32 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri #if defined(__ARM_FEATURE_DOTPROD) - sumb += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * (float)x[i].d[0]; - sumb += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * (float)x[i].d[2]; + acc1 = vmlaq_n_f32(acc1, vcvtq_f32_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])), + vgetq_lane_f32(s32, 0)); + acc2 = vmlaq_n_f32(acc2, vcvtq_f32_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])), + vgetq_lane_f32(s32, 2)); #else const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - sumb += vaddvq_s16(vaddq_s16(p0, p1)) * (float)x[i].d[0]; + const int16x8_t p01_16 = vaddq_s16(p0, p1); + const int32x4_t p01_32 = vaddq_s32(vmovl_s16(vget_low_s16(p01_16)), vmovl_s16(vget_high_s16(p01_16))); + acc1 = vmlaq_n_f32(acc1, vcvtq_f32_s32(p01_32), vgetq_lane_f32(s32, 0)); const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), vmull_s8(vget_high_s8(q5bytes.val[2]), 
vget_high_s8(q8bytes.val[2]))); const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - sumb += vaddvq_s16(vaddq_s16(p2, p3)) * (float)x[i].d[2]; + const int16x8_t p02_16 = vaddq_s16(p2, p3); + const int32x4_t p02_32 = vaddq_s32(vmovl_s16(vget_low_s16(p02_16)), vmovl_s16(vget_high_s16(p02_16))); + acc2 = vmlaq_n_f32(acc2, vcvtq_f32_s32(p02_32), vgetq_lane_f32(s32, 2)); #endif - sumf += y[i].d * sumb; } - *s = sumf; + *s = vaddvq_f32(vaddq_f32(acc1, acc2)) + sumf; #elif defined __AVX2__ From e1bbcfc5cb8a4ecac40899069db27968050fd110 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 23 Jun 2023 11:32:14 +0300 Subject: [PATCH 20/35] k_quants: WIP super-blocks with 64 weights * We are able to pass preprocessor macros to the Metal compiler * Q6_K works and is actually slightly more efficient than the QK_K = 256 version (25.2 ms vs 25.8 ms) --- ggml-metal.m | 6 ++++++ ggml-metal.metal | 46 +++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a7e104dc76fca..f84a2c4337b52 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -132,7 +132,13 @@ @implementation GGMLMetalClass exit(1); } +#ifdef GGML_QKK_64 + MTLCompileOptions* options = [MTLCompileOptions new]; + options.preprocessorMacros = @{ @"QK_K" : @(64) }; + ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; +#else ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error]; +#endif if (error) { fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]); exit(1); diff --git a/ggml-metal.metal b/ggml-metal.metal index d1e49222db2eb..03f599d806ce6 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -428,7 +428,7 @@ kernel void kernel_mul_mat_q4_0_f32( } threadgroup_barrier(mem_flags::mem_threadgroup); if (ith == 0) { - for (uint i = 16; i < nth; i += 16) sum[0] += sum[i]; + for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; dst[r1*ne0 + r0] = sum[0]; } } @@ -497,7 +497,7 @@ kernel void kernel_mul_mat_q4_1_f32( } threadgroup_barrier(mem_flags::mem_threadgroup); if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; + for (uint i = 16; i < nth; i += 16) sum[0] += sum[i]; dst[r1*ne0 + r0] = sum[0]; } } @@ -775,7 +775,11 @@ kernel void kernel_cpy_f32_f32( //============================================ k-quants ====================================================== +#ifndef QK_K #define QK_K 256 +#else +static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64"); +#endif typedef struct { uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits @@ -988,6 +992,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i const float d = x[i].d; +#if QK_K == 256 for (int n = 0; n < QK_K; n += 128) { for (int l = 0; l < 32; ++l) { int is = l/16; @@ -1005,6 +1010,19 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i qh += 32; sc += 8; } +#else + for (int l = 0; l < 16; ++l) { + const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + y[l+ 0] = d * sc[0] * q1; + y[l+16] = d * sc[1] * q2; + y[l+32] = d * sc[2] * q3; + y[l+48] = 
d * sc[3] * q4; + } + y += 64; +#endif } } @@ -1528,6 +1546,9 @@ kernel void kernel_mul_mat_q6_k_f32( const int nth = tptg.x*tptg.y; const int ith = tptg.y*tpitg.x + tpitg.y; + float sumf = 0; + +#if QK_K == 256 // Note: we absolutely assume that tptg.y = 16 and QK_K = 256! const int iqs = 16 * tpitg.y; const int ip = iqs / 128; // 0 or 1 @@ -1540,7 +1561,6 @@ kernel void kernel_mul_mat_q6_k_f32( const int q_offset_l = 64*ip + l0; const int q_offset_h = 32*ip + l0; - float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { device const uint8_t * ql = x[i].ql + q_offset_l; @@ -1562,6 +1582,26 @@ kernel void kernel_mul_mat_q6_k_f32( sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); } +#else + const int il = 4*tpitg.x; // 0, 4, 8, 12 + + for (int i = tpitg.y; i < nb; i += tptg.y) { + device const float * y = yy + i * QK_K + il; + device const uint8_t * ql = x[i].ql + il; + device const uint8_t * qh = x[i].qh + il; + device const int8_t * s = x[i].scales; + + const float d = x[i].d; + + for (int l = 0; l < 4; ++l) { + sumf += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & 0x03) << 4)) - 32) + + y[l+16] * s[1] * d * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & 0x0c) << 2)) - 32) + + y[l+32] * s[2] * d * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & 0x30) >> 0)) - 32) + + y[l+48] * s[3] * d * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & 0xc0) >> 2)) - 32); + } + } + +#endif sum[ith] = sumf; From 167a0bbe3406909fd8ff6dd9845c2d9281a1f728 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 23 Jun 2023 12:34:07 +0300 Subject: [PATCH 21/35] k_quants: WIP super-blocks with 64 weights Q4_K works on Metal and is actually slightly faster than QK_K = 256 (21.95 ms vs 24.0 ms). --- ggml-metal.metal | 65 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 12 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 03f599d806ce6..11f390d70e656 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -781,6 +781,12 @@ kernel void kernel_cpy_f32_f32( static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64"); #endif +#if QK_K == 256 +#define K_SCALE_SIZE 12 +#else +#define K_SCALE_SIZE 4 +#endif + typedef struct { uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits uint8_t qs[QK_K/4]; // quants @@ -797,13 +803,19 @@ typedef struct { } block_q3_k; // 110 bytes / block +#if QK_K == 64 +typedef struct { + half4 d; // super-block scales/mins + uint8_t qs[QK_K/2]; // 4-bit quants +} block_q4_k; +#else typedef struct { half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits uint8_t qs[QK_K/2]; // 4--bit quants } block_q4_k; -// 144 bytes / block +#endif typedef struct { half d; // super-block scale for quantized scales @@ -934,10 +946,12 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i for (int i = 0; i < nb; i++) { + device const uint8_t * q = x[i].qs; + +#if QK_K == 256 const float d = x[i].d; const float min = x[i].dmin; - device const uint8_t * q = x[i].qs; device const uint8_t * scales = x[i].scales; int is = 0; @@ -949,6 +963,14 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; q += 32; is += 2; } +#else + float4 d4 = (float4)x[i].d; + for (int l = 0; l < 32; ++l) { + y[l+ 0] = d4[0] * (q[l] & 0xF) 
- d4[1]; + y[l+32] = d4[2] * (q[l] >> 4) - d4[3]; + } + y += QK_K; +#endif } } @@ -1332,11 +1354,15 @@ kernel void kernel_mul_mat_q4_k_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; + const int nth = tptg.x*tptg.y; + const int ith = tptg.y*tpitg.x + tpitg.y; + device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb; device const float * yy = (device const float *) src1 + r1*ne10; - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; + float sumf = 0; + +#if QK_K == 256 const int tid = tpitg.y; // 0...16 const int il = tid/4; // 0...3 @@ -1350,11 +1376,8 @@ kernel void kernel_mul_mat_q4_k_f32( const int q_offset = 32*im + l0; const int y_offset = 64*im + l0; - sum[ith] = 0.0f; - uchar2 sc1, sc2, sc3, sc4; - float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { device const uint8_t * q1 = (x + i)->qs + q_offset; @@ -1383,6 +1406,22 @@ kernel void kernel_mul_mat_q4_k_f32( sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin; } +#else + const int il = 4*tpitg.x; + + for (int i = tpitg.y; i < nb; i += tptg.y) { + + device const uint8_t * q = x[i].qs + il; + device const float * y = yy + i * QK_K + il; + + const float4 d4 = (float4)x[i].d; + + for (int l = 0; l < 4; ++l) { + sumf += d4[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - d4[1] * (y[l+ 0] + y[l+16]) + + d4[2] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - d4[3] * (y[l+32] + y[l+48]); + } + } +#endif sum[ith] = sumf; @@ -1593,12 +1632,14 @@ kernel void kernel_mul_mat_q6_k_f32( const float d = x[i].d; + float4 sums = {0.f, 0.f, 0.f, 0.f}; for (int l = 0; l < 4; ++l) { - sumf += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & 0x03) << 4)) - 32) - + y[l+16] * s[1] * d * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & 0x0c) << 2)) - 32) - + y[l+32] * s[2] * d * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & 0x30) >> 0)) - 32) - + y[l+48] * s[3] * d * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & 0xc0) >> 2)) - 32); + sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & 0x03) << 4)) - 32); + sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & 0x0c) << 2)) - 32); + sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & 0x30) >> 0)) - 32); + sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & 0xc0) >> 2)) - 32); } + sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]); } #endif From 6081a65527b2d1972d2c7a3deb564869cc9d3f10 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 23 Jun 2023 13:04:30 +0300 Subject: [PATCH 22/35] k_quants: WIP super-blocks with 64 weights Q2_K works on Metal and is very slightly faster than QK_K = 256 (23.8 ms vs 24.2 ms). 
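The nibble unpacking is the part worth spelling out: each of the four scale bytes holds a 4-bit scale in its low nibble and a 4-bit min in its high nibble, and two masked copies of the 32-bit word expose them as plain bytes. An illustrative scalar equivalent of the kernel below (the helper name is made up; assumes a little-endian target, which Metal guarantees):

    #include <stdint.h>
    #include <string.h>

    // Dequantize one QK_K = 64 Q2_K super-block: group g of 16 weights
    // uses scale (scales[g] & 0xF) and min (scales[g] >> 4); its 2-bit
    // quants sit in qs[l] at shift 2*g. dall/dmin are the fp16
    // super-block scales, already converted to float.
    static void q2_K_64_dequant(const uint8_t scales[4], const uint8_t qs[16],
                                float dall, float dmin, float y[64]) {
        uint32_t a, aux[2];
        memcpy(&a, scales, 4);
        aux[0] = a & 0x0f0f0f0f;         // the four 4-bit scales, one per byte
        aux[1] = (a >> 4) & 0x0f0f0f0f;  // the four 4-bit mins, one per byte
        const uint8_t * d = (const uint8_t *)&aux[0];
        const uint8_t * m = (const uint8_t *)&aux[1];
        for (int g = 0; g < 4; ++g) {
            for (int l = 0; l < 16; ++l) {
                y[16*g + l] = dall * d[g] * ((qs[l] >> 2*g) & 3) - dmin * m[g];
            }
        }
    }

The same two-mask trick already appears in the ARM_NEON version of ggml_vec_dot_q2_K_q8_K earlier in the series.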
--- ggml-metal.metal | 62 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 18 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 11f390d70e656..8c9aadf1e4db6 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -863,6 +863,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i device const uint8_t * q = x[i].qs; +#if QK_K == 256 int is = 0; float dl, ml; for (int n = 0; n < QK_K; n += 128) { @@ -881,6 +882,19 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i } q += 32; } +#else + float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4); + float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4); + float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4); + float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4); + for (int l = 0; l < 16; ++l) { + y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1; + y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2; + y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3; + y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4; + } + y += QK_K; +#endif } } @@ -1153,6 +1167,9 @@ kernel void kernel_mul_mat_q2_k_f32( const int nth = tptg.x*tptg.y; const int ith = tptg.y*tpitg.x + tpitg.y; + float sumf = 0; + +#if QK_K == 256 const int tid = tpitg.y; // 0...16 const int il = tid/4; // 0...3 const int ir = tid%4; // 0...3 @@ -1165,9 +1182,6 @@ kernel void kernel_mul_mat_q2_k_f32( const int y_offset = 64*il + n*ir; const int q_offset = 32*ip + n*ir; - sum[ith] = 0.0f; - - float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { device const uint8_t * q = x[i].qs + q_offset; @@ -1180,7 +1194,6 @@ kernel void kernel_mul_mat_q2_k_f32( device const float * y = yy + i*QK_K + y_offset; - //float4 s = {0.f, 0.f, 0.f, 0.f}; float2 s = {0.f, 0.f}; float smin = 0; for (int l = 0; l < n; ++l) { @@ -1195,25 +1208,38 @@ kernel void kernel_mul_mat_q2_k_f32( sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin; } - sum[ith] = sumf; +#else + const int il = 4 * tpitg.x; - //int mask1 = (ith%4 == 0); - //int mask2 = (ith%16 == 0); + uint32_t aux[2]; + thread const uint8_t * d = (thread const uint8_t *)aux; + thread const uint8_t * m = (thread const uint8_t *)aux + 4; - //threadgroup_barrier(mem_flags::mem_threadgroup); - //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i]; - //threadgroup_barrier(mem_flags::mem_threadgroup); - //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i]; - //threadgroup_barrier(mem_flags::mem_threadgroup); - //if (ith == 0) { - // for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - // dst[r1*ne0 + r0] = sum[0]; - //} + for (int i = tpitg.y; i < nb; i += tptg.y) { + + device const uint8_t * q = x[i].qs + il; + device const float * y = yy + i*QK_K + il; + + const float dall = (float)x[i].d; + const float dmin = (float)x[i].dmin; + + device const uint32_t * a = (device const uint32_t *)x[i].scales; + aux[0] = a[0] & 0x0f0f0f0f; + aux[1] = (a[0] >> 4) & 0x0f0f0f0f; + + for (int l = 0; l < 4; ++l) { + sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0]) + + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1]) + + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2]) + + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]); + } + } +#endif + + sum[ith] = sumf; // // Accumulate the sum from all threads in the threadgroup - // This version is slightly faster than the commented out one below, - // which I copy-pasted from ggerganov's q4_0 dot product for metal. 
// threadgroup_barrier(mem_flags::mem_threadgroup); if (ith%4 == 0) { From ff83e32c6ab4b6500625df347750c896cffe4bb5 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 23 Jun 2023 13:34:02 +0300 Subject: [PATCH 23/35] k_quants: WIP super-blocks with 64 weights Q3_K works on Metal and is slightly faster than QK_K = 256 (26.6 ms vs 28.3 ms). --- ggml-metal.metal | 88 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 77 insertions(+), 11 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 8c9aadf1e4db6..ccd98956f8de6 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -798,10 +798,13 @@ typedef struct { typedef struct { uint8_t hmask[QK_K/8]; // quants - high bit uint8_t qs[QK_K/4]; // quants - low 2 bits - uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits - half d; // super-block scale +#if QK_K == 64 + int8_t scales[K_SCALE_SIZE]; +#else + uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits +#endif + half d; // super-block scale } block_q3_k; -// 110 bytes / block #if QK_K == 64 typedef struct { @@ -903,6 +906,8 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i assert(k % QK_K == 0); const int nb = k / QK_K; +#if QK_K == 256 + const uint16_t kmask1 = 0x0303; const uint16_t kmask2 = 0x0f0f; @@ -948,8 +953,34 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i } q += 32; } + } +#else + for (int i = 0; i < nb; i++) { + + const float d_all = (float)(x[i].d); + device const uint8_t * q = x[i].qs; + device const uint8_t * hm = x[i].hmask; + + const float d1 = d_all * x[i].scales[0]; + const float d2 = d_all * x[i].scales[1]; + const float d3 = d_all * x[i].scales[2]; + const float d4 = d_all * x[i].scales[3]; + + for (int l = 0; l < 8; ++l) { + uint8_t h = hm[l]; + y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4)); + y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4)); + y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4)); + y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4)); + y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4)); + y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4)); + y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4)); + y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4)); + } + y += QK_K; } +#endif } @@ -1286,6 +1317,8 @@ kernel void kernel_mul_mat_q3_k_f32( const int nth = tptg.x*tptg.y; const int ith = tptg.y*tpitg.x + tpitg.y; +#if QK_K == 256 + const int tid = tpitg.y; // expecting 16 const int ip = tid/8; // 0 or 1 const int il = tid/2 - 4*ip; // 0...3 @@ -1339,6 +1372,39 @@ kernel void kernel_mul_mat_q3_k_f32( //sum[ith] = sumf; sum[ith] = sumf1 - 32.f*sumf2; +#else + const int il = 4 * tpitg.x; // 0, 4, 8, 12 + const int im = il/8; // 0, 0, 1, 1 + const int in = il%8; // 0, 4, 0, 4 + + float sumf = 0; + + for (int i = tpitg.y; i < nb; i += tptg.y) { + + const float d_all = (float)(x[i].d); + + device const uint8_t * q = x[i].qs + il; + device const uint8_t * h = x[i].hmask + in; + device const float * y = yy + i * QK_K + il; + + const float d1 = d_all * x[i].scales[0]; + const float d2 = d_all * x[i].scales[1]; + const float d3 = d_all * x[i].scales[2]; + const float d4 = d_all * x[i].scales[3]; + + for (int l = 0; l < 4; ++l) { + const uint8_t hm = h[l] >> im; + sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4)) + + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 
0 : 4)) + + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4)) + + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4)); + } + + } + + sum[ith] = sumf; + +#endif // // Accumulate the sum from all threads in the threadgroup @@ -1371,10 +1437,6 @@ kernel void kernel_mul_mat_q4_k_f32( uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; @@ -1390,6 +1452,10 @@ kernel void kernel_mul_mat_q4_k_f32( #if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + const int tid = tpitg.y; // 0...16 const int il = tid/4; // 0...3 const int ir = tid - 4*il;// 0...3 @@ -1660,10 +1726,10 @@ kernel void kernel_mul_mat_q6_k_f32( float4 sums = {0.f, 0.f, 0.f, 0.f}; for (int l = 0; l < 4; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & 0x03) << 4)) - 32); - sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & 0x0c) << 2)) - 32); - sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & 0x30) >> 0)) - 32); - sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & 0xc0) >> 2)) - 32); + sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32); + sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); } sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]); } From 285eeb15314de97f96283419804a8f8f92a293ec Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 23 Jun 2023 14:11:47 +0300 Subject: [PATCH 24/35] k_quants: WIP super-blocks with 64 weights Q5_K works on Metal and is slightly faster than QK_K = 256 (23.7 ms vs 26.3 ms). --- ggml-metal.metal | 65 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 5 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index ccd98956f8de6..42c3a041231ff 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -820,6 +820,13 @@ typedef struct { } block_q4_k; #endif +#if QK_K == 64 +typedef struct { + half4 d; // super-block scales/mins + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_k; +#else typedef struct { half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins @@ -828,6 +835,7 @@ typedef struct { uint8_t qs[QK_K/2]; // quants, low 4 bits } block_q5_k; // 176 bytes / block +#endif typedef struct { uint8_t ql[QK_K/2]; // quants, lower 4 bits @@ -1024,6 +1032,7 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i assert(k % QK_K == 0); const int nb = k / QK_K; +#if QK_K == 256 for (int i = 0; i < nb; i++) { const float d = (float)(x[i].d); @@ -1044,6 +1053,27 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i u1 <<= 2; u2 <<= 2; } } +#else + for (int i = 0; i < nb; i++) { + + const float4 d = (float4)x[i].d; + + device const uint8_t * ql = x[i].qs; + device const uint8_t * qh = x[i].qh; + + for (int l = 0; l < 8; ++l) { + y[l+ 0] = d[0] * ((ql[l+ 0] & 0xF) + (qh[l] & 0x01 ? 16 : 0)) - d[1]; + y[l+ 8] = d[0] * ((ql[l+ 8] & 0xF) + (qh[l] & 0x02 ? 16 : 0)) - d[1]; + y[l+16] = d[0] * ((ql[l+16] & 0xF) + (qh[l] & 0x04 ? 
16 : 0)) - d[1]; + y[l+24] = d[0] * ((ql[l+24] & 0xF) + (qh[l] & 0x08 ? 16 : 0)) - d[1]; + y[l+32] = d[2] * ((ql[l+ 0] >> 4) + (qh[l] & 0x10 ? 16 : 0)) - d[3]; + y[l+40] = d[2] * ((ql[l+ 8] >> 4) + (qh[l] & 0x20 ? 16 : 0)) - d[3]; + y[l+48] = d[2] * ((ql[l+16] >> 4) + (qh[l] & 0x40 ? 16 : 0)) - d[3]; + y[l+56] = d[2] * ((ql[l+24] >> 4) + (qh[l] & 0x80 ? 16 : 0)) - d[3]; + } + y += QK_K; + } +#endif } @@ -1562,10 +1592,6 @@ kernel void kernel_mul_mat_q5_k_f32( uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; @@ -1577,6 +1603,14 @@ kernel void kernel_mul_mat_q5_k_f32( const int nth = tptg.x*tptg.y; const int ith = tptg.y*tpitg.x + tpitg.y; + float sumf = 0; + +#if QK_K == 256 + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + const int tid = tpitg.y; // 0...16 const int il = tid/4; // 0...3 const int ir = tid - 4*il;// 0...3 @@ -1596,7 +1630,6 @@ kernel void kernel_mul_mat_q5_k_f32( uchar2 sc1, sc2, sc3, sc4; - float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { device const uint8_t * q1 = (x + i)->qs + q_offset; @@ -1628,6 +1661,28 @@ kernel void kernel_mul_mat_q5_k_f32( sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin; } +#else + const int il = 4 * tpitg.x; // 0, 4, 8, 12 + const int im = il/8; // 0, 0, 1, 1 + const int in = il%8; // 0, 4, 0, 4 + + for (int i = tpitg.y; i < nb; i += tptg.y) { + + device const uint8_t * q = x[i].qs + il; + device const uint8_t * h = x[i].qh + in; + device const float * y = yy + i*QK_K + il; + + const float4 d = (float4)x[i].d; + + for (int l = 0; l < 4; ++l) { + const uint8_t hl = h[l] >> im; + sumf += y[l+ 0] * (d[0] * ((q[l+ 0] & 0xF) + (hl & 0x01 ? 16 : 0)) - d[1]) + + y[l+16] * (d[0] * ((q[l+16] & 0xF) + (hl & 0x04 ? 16 : 0)) - d[1]) + + y[l+32] * (d[2] * ((q[l+ 0] >> 4) + (hl & 0x10 ? 16 : 0)) - d[3]) + + y[l+48] * (d[2] * ((q[l+16] >> 4) + (hl & 0x40 ? 
16 : 0)) - d[3]); + } + } +#endif sum[ith] = sumf; // From 8b98d01e31daf2ad32671ba9446d858e5c4769de Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 23 Jun 2023 14:16:12 +0300 Subject: [PATCH 25/35] k_quants: call them _K, not _k, also on Metal --- ggml-metal.m | 60 +++++++++++++++++++-------------------- ggml-metal.metal | 74 ++++++++++++++++++++++++------------------------ 2 files changed, 67 insertions(+), 67 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index f84a2c4337b52..7551231b9cf32 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -51,21 +51,21 @@ GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_1); - GGML_METAL_DECL_KERNEL(get_rows_q2_k); - GGML_METAL_DECL_KERNEL(get_rows_q3_k); - GGML_METAL_DECL_KERNEL(get_rows_q4_k); - GGML_METAL_DECL_KERNEL(get_rows_q5_k); - GGML_METAL_DECL_KERNEL(get_rows_q6_k); + GGML_METAL_DECL_KERNEL(get_rows_q2_K); + GGML_METAL_DECL_KERNEL(get_rows_q3_K); + GGML_METAL_DECL_KERNEL(get_rows_q4_K); + GGML_METAL_DECL_KERNEL(get_rows_q5_K); + GGML_METAL_DECL_KERNEL(get_rows_q6_K); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32); GGML_METAL_DECL_KERNEL(rope); GGML_METAL_DECL_KERNEL(alibi_f32); GGML_METAL_DECL_KERNEL(cpy_f32_f16); @@ -165,21 +165,21 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_1); - GGML_METAL_ADD_KERNEL(get_rows_q2_k); - GGML_METAL_ADD_KERNEL(get_rows_q3_k); - GGML_METAL_ADD_KERNEL(get_rows_q4_k); - GGML_METAL_ADD_KERNEL(get_rows_q5_k); - GGML_METAL_ADD_KERNEL(get_rows_q6_k); + GGML_METAL_ADD_KERNEL(get_rows_q2_K); + GGML_METAL_ADD_KERNEL(get_rows_q3_K); + GGML_METAL_ADD_KERNEL(get_rows_q4_K); + GGML_METAL_ADD_KERNEL(get_rows_q5_K); + GGML_METAL_ADD_KERNEL(get_rows_q6_K); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32); GGML_METAL_ADD_KERNEL(rope); GGML_METAL_ADD_KERNEL(alibi_f32); GGML_METAL_ADD_KERNEL(cpy_f32_f16); @@ -668,7 +668,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32]; } break; case GGML_TYPE_Q3_K: { @@ -677,7 +677,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32]; + [encoder 
setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32]; } break; case GGML_TYPE_Q4_K: { @@ -686,7 +686,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32]; } break; case GGML_TYPE_Q5_K: { @@ -695,7 +695,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32]; } break; case GGML_TYPE_Q6_K: { @@ -704,7 +704,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32]; } break; default: { @@ -756,11 +756,11 @@ void ggml_metal_graph_compute( case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; - case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break; - case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break; - case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break; - case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break; - case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break; + case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break; + case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break; + case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break; + case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break; + case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break; default: GGML_ASSERT(false && "not implemented"); } diff --git a/ggml-metal.metal b/ggml-metal.metal index 42c3a041231ff..3b4eac2bf4f65 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -792,7 +792,7 @@ typedef struct { uint8_t qs[QK_K/4]; // quants half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins -} block_q2_k; +} block_q2_K; // 84 bytes / block typedef struct { @@ -804,20 +804,20 @@ typedef struct { uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits #endif half d; // super-block scale -} block_q3_k; +} block_q3_K; #if QK_K == 64 typedef struct { half4 d; // super-block scales/mins uint8_t qs[QK_K/2]; // 4-bit quants -} block_q4_k; +} block_q4_K; #else typedef struct { half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_k; +} block_q4_K; #endif #if QK_K == 64 @@ -825,7 +825,7 @@ typedef struct { half4 d; // super-block scales/mins uint8_t qh[QK_K/8]; // quants, high bit uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_k; +} block_q5_K; #else typedef struct { half d; // super-block scale for quantized scales @@ -833,7 +833,7 @@ typedef struct { uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits uint8_t qh[QK_K/8]; // quants, high bit uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_k; +} block_q5_K; // 176 bytes / 
 #endif

@@ -842,7 +842,7 @@ typedef struct {
     uint8_t qh[QK_K/4];      // quants, upper 2 bits
     int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
     half d;                  // super-block scale
-} block_q6_k;
+} block_q6_K;
 // 210 bytes / block

 static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
@@ -863,7 +863,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {

 //========================================== dequantization =============================

-static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
+static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;

@@ -910,7 +910,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
     }
 }

-static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
+static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;

@@ -992,7 +992,7 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i

 }

-static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
+static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;

@@ -1028,7 +1028,7 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
     }
 }

-static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
+static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;

@@ -1077,7 +1077,7 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i

 }

-static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
+static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;

@@ -1123,7 +1123,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
     }
 }

-kernel void kernel_get_rows_q2_k(
+kernel void kernel_get_rows_q2_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1134,12 +1134,12 @@ kernel void kernel_get_rows_q2_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];

-    dequantize_row_q2_k(
-            (device const block_q2_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q2_K(
+            (device const block_q2_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }

-kernel void kernel_get_rows_q3_k(
+kernel void kernel_get_rows_q3_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1150,12 +1150,12 @@ kernel void kernel_get_rows_q3_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];

-    dequantize_row_q3_k(
-            (device const block_q3_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q3_K(
+            (device const block_q3_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }

-kernel void kernel_get_rows_q4_k(
+kernel void kernel_get_rows_q4_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1166,12 +1166,12 @@ kernel void kernel_get_rows_q4_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];

-    dequantize_row_q4_k(
-            (device const block_q4_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q4_K(
+            (device const block_q4_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }

-kernel void kernel_get_rows_q5_k(
+kernel void kernel_get_rows_q5_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1182,12 +1182,12 @@ kernel void kernel_get_rows_q5_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];

-    dequantize_row_q5_k(
-            (device const block_q5_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q5_K(
+            (device const block_q5_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }

-kernel void kernel_get_rows_q6_k(
+kernel void kernel_get_rows_q6_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1198,14 +1198,14 @@ kernel void kernel_get_rows_q6_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];

-    dequantize_row_q6_k(
-            (device const block_q6_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q6_K(
+            (device const block_q6_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }

 //====================================== dot products =========================

-kernel void kernel_mul_mat_q2_k_f32(
+kernel void kernel_mul_mat_q2_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1222,7 +1222,7 @@ kernel void kernel_mul_mat_q2_k_f32(
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;

-    device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
+    device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;

     const int nth = tptg.x*tptg.y;
@@ -1317,7 +1317,7 @@ kernel void kernel_mul_mat_q2_k_f32(
     }
 }

-kernel void kernel_mul_mat_q3_k_f32(
+kernel void kernel_mul_mat_q3_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1341,7 +1341,7 @@ kernel void kernel_mul_mat_q3_k_f32(
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;

-    device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
+    device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;

     const int nth = tptg.x*tptg.y;
@@ -1455,7 +1455,7 @@
 }

-kernel void kernel_mul_mat_q4_k_f32(
+kernel void kernel_mul_mat_q4_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1475,7 +1475,7 @@ kernel void kernel_mul_mat_q4_k_f32(
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;

-    device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
+    device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;

     float sumf = 0;
@@ -1580,7 +1580,7 @@
     //}
 }

-kernel void kernel_mul_mat_q5_k_f32(
+kernel void kernel_mul_mat_q5_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1597,7 +1597,7 @@ kernel void kernel_mul_mat_q5_k_f32(
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;

-    device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
+    device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;

     const int nth = tptg.x*tptg.y;
@@ -1704,7 +1704,7 @@
 }

-kernel void kernel_mul_mat_q6_k_f32(
+kernel void kernel_mul_mat_q6_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1726,7 +1726,7 @@ kernel void kernel_mul_mat_q6_k_f32(
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;

-    device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
+    device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;

     const int nth = tptg.x*tptg.y;

From 558a19427bd261a8aa87cd41bfd53f7d879202f2 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Fri, 23 Jun 2023 14:35:21 +0300
Subject: [PATCH 26/35] k_quants: correctly define QK_K in llama.cpp

---
 llama.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/llama.cpp b/llama.cpp
index ac22a48f8ab97..a794aa7259514 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -21,9 +21,13 @@
 #endif
 #ifdef GGML_USE_K_QUANTS
 #ifndef QK_K
+#ifdef GGML_QKK_64
+#define QK_K 64
+#else
 #define QK_K 256
 #endif
 #endif
+#endif

 #include
 #include

From 333ffcc5baea8855555e56c3809b9745f09b3e67 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Fri, 23 Jun 2023 17:46:11 +0300
Subject: [PATCH 27/35] Fixed bug in q4_K quantization added with the 64-block addition

---
 k_quants.c | 1 +
 1 file changed, 1 insertion(+)

diff --git a/k_quants.c b/k_quants.c
index b05ef45b2e68e..88291abec2bc5 100644
--- a/k_quants.c
+++ b/k_quants.c
@@ -707,6 +707,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
         uint8_t * q = y[i].qs;
         for (int j = 0; j < QK_K; j += 64) {
             for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
+            q += 32;
         }

         x += QK_K;

From 88412a1aa02a53a40146eb77c836104feaa01e73 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Fri, 23 Jun 2023 17:50:53 +0300
Subject: [PATCH 28/35] Simplify via lambda

---
 llama.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index a794aa7259514..78eb8427c3e4a 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2474,6 +2474,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     std::vector<std::thread> workers;
     std::mutex mutex;

+    auto use_more_bits = [] (int i_layer, int num_layers) -> bool {
+        return i_layer < num_layers/8 || i_layer >= 7*num_layers/8 || (i_layer - num_layers/8)%3 == 2;
+    };
+
     size_t idx = 0;
     for (llama_load_tensor & tensor : model_loader->tensors_map.tensors) {
         llama_buffer read_data;
@@ -2528,15 +2532,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                 if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
                 else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
                 else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
-                        (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8 ||
-                        (i_attention_wv - n_attention_wv/8)%3 == 2)) new_type = GGML_TYPE_Q6_K;
+                        use_more_bits(i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K;
                 ++i_attention_wv;
             } else if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) {
                 if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
                 else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
                 else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
-                        (i_feed_forward_w2 < n_feed_forward_w2/8 || i_feed_forward_w2 >= 7*n_feed_forward_w2/8 ||
-                        (i_feed_forward_w2 - n_feed_forward_w2/8)%3 == 2)) new_type = GGML_TYPE_Q6_K;
+                        use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
                 ++i_feed_forward_w2;
             } else if (tensor.name.find("attention.wo.weight") != std::string::npos) {
                 if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;

From aeefd4e7812af36bf612cb9a36e903c58fd7ac80 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sat, 24 Jun 2023 09:57:50 +0300
Subject: [PATCH 29/35] k_quants: switch Q3_K to 4-bit scales when QK_K = 64

Otherwise there isn't much benefit from this quantization type.
There is some very slight loss in accuracy, but we reduce size by ~7%.
E.g., for OpenLLaMA-3B, Q3_K_S perplexity is 8.6131 with 8-bit scales
and 8.6352 with 4-bit, while file size decreases from 1.53G to 1.44G.
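In scalar form, the new packing amounts to the following minimal sketch
(an illustration, not part of the diff; the helper names are made up):

    #include <stdint.h>

    // Two signed 4-bit block scales, each clamped to [-8, 7], are stored
    // with a +8 offset so that both fit as unsigned nibbles in one byte.
    static uint8_t pack_scales(int l1, int l2) {
        return (uint8_t)((l1 + 8) | ((l2 + 8) << 4));
    }

    // The dequantization side recovers the signed values.
    static void unpack_scales(uint8_t s, int * l1, int * l2) {
        *l1 = (s & 0xF) - 8; // low nibble
        *l2 = (s >> 4) - 8;  // high nibble
    }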
---
 ggml-cuda.cu     | 23 +++++++++------
 ggml-metal.metal | 18 +++++------
 k_quants.c       | 77 ++++++++++++++++++++++++++++++++----------------
 k_quants.h       | 16 ++++++----
 4 files changed, 85 insertions(+), 49 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index ec66c884a7001..22f13d3069d78 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -137,13 +137,13 @@ typedef struct {
     uint8_t hmask[QK_K/8]; // quants - high bit
     uint8_t qs[QK_K/4];    // quants - low 2 bits
 #ifdef GGML_QKK_64
-    int8_t  scales[K_SCALE_SIZE]; // scales, quantized with 8 bits
+    uint8_t scales[2]; // scales, quantized with 4 bits
 #else
     uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
 #endif
     half d;             // super-block scale
 } block_q3_K;
-static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
+//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");

 #ifdef GGML_QKK_64
 typedef struct {
@@ -448,8 +448,13 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
     const uint8_t h = x[i].hmask[in] >> (2*is + im);
     const float   d = (float)x[i].d;

-    y[ 0] = d * x[i].scales[is+0] * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
-    y[32] = d * x[i].scales[is+2] * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+    if (is == 0) {
+        y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+        y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+    } else {
+        y[ 0] = d * ((x[i].scales[0] >>  4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+        y[32] = d * ((x[i].scales[1] >>  4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+    }
 #endif

 }
@@ -776,7 +781,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
     const float   * y = yy + i * QK_K + offset;
     const uint8_t * q = x[i].qs + offset;
-    const int8_t  * s = x[i].scales;
+    const uint8_t * s = x[i].scales;

     const float dall = (float)x[i].d;

@@ -784,10 +789,10 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
     for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
         const uint8_t hl = x[i].hmask[im+l] >> in;
         const uint8_t ql = q[l];
-        sum += y[l+ 0] * dall * s[0] * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
-             + y[l+16] * dall * s[1] * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
-             + y[l+32] * dall * s[2] * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
-             + y[l+48] * dall * s[3] * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
+        sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
+             + y[l+16] * dall * ((s[0] >>  4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
+             + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
+             + y[l+48] * dall * ((s[1] >>  4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
     }
     tmp += sum;
 }
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 3b4eac2bf4f65..7185b7f2cf188 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -799,7 +799,7 @@ typedef struct {
     uint8_t hmask[QK_K/8]; // quants - high bit
     uint8_t qs[QK_K/4];    // quants - low 2 bits
 #if QK_K == 64
-    int8_t  scales[K_SCALE_SIZE];
+    uint8_t scales[2];
 #else
     uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
 #endif
@@ -970,10 +970,10 @@ static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, i
         device const uint8_t * q  = x[i].qs;
         device const uint8_t * hm = x[i].hmask;

-        const float d1 = d_all * x[i].scales[0];
-        const float d2 = d_all * x[i].scales[1];
-        const float d3 = d_all * x[i].scales[2];
-        const float d4 = d_all * x[i].scales[3];
+        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
+        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);

         for (int l = 0; l < 8; ++l) {
             uint8_t h = hm[l];
@@ -1417,10 +1417,10 @@ kernel void kernel_mul_mat_q3_K_f32(
         device const uint8_t * h = x[i].hmask + in;
         device const float   * y = yy + i * QK_K + il;

-        const float d1 = d_all * x[i].scales[0];
-        const float d2 = d_all * x[i].scales[1];
-        const float d3 = d_all * x[i].scales[2];
-        const float d4 = d_all * x[i].scales[3];
+        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
+        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);

         for (int l = 0; l < 4; ++l) {
             const uint8_t hm = h[l] >> im;
diff --git a/k_quants.c b/k_quants.c
index 88291abec2bc5..c2c96d4a5da1e 100644
--- a/k_quants.c
+++ b/k_quants.c
@@ -469,21 +469,24 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
         }
 #else
         if (max_scale) {
-            float iscale = -128.f/max_scale;
-            for (int j = 0; j < QK_K/16; ++j) {
-                int l = nearest_int(iscale*scales[j]);
-                l = MAX(-128, MIN(127, l));
-                y[i].scales[j] = l;
+            float iscale = -8.f/max_scale;
+            for (int j = 0; j < QK_K/16; j+=2) {
+                int l1 = nearest_int(iscale*scales[j]);
+                l1 = 8 + MAX(-8, MIN(7, l1));
+                int l2 = nearest_int(iscale*scales[j+1]);
+                l2 = 8 + MAX(-8, MIN(7, l2));
+                y[i].scales[j/2] = l1 | (l2 << 4);
             }
             y[i].d = ggml_fp32_to_fp16(1/iscale);
         } else {
-            for (int j = 0; j < QK_K/16; ++j) {
-                y[i].scales[j] = 0;
+            for (int j = 0; j < QK_K/16; j+=2) {
+                y[i].scales[j/2] = 0;
             }
             y[i].d = ggml_fp32_to_fp16(0.f);
         }
         for (int j = 0; j < QK_K/16; ++j) {
-            float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
+            int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
+            float d = ggml_fp16_to_fp32(y[i].d) * (s - 8);
             if (!d) {
                 continue;
             }
@@ -587,10 +590,10 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int
         const uint8_t * restrict q  = x[i].qs;
         const uint8_t * restrict hm = x[i].hmask;

-        const float d1 = d_all * x[i].scales[0];
-        const float d2 = d_all * x[i].scales[1];
-        const float d3 = d_all * x[i].scales[2];
-        const float d4 = d_all * x[i].scales[3];
+        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
+        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);

         for (int l=0; l<8; ++l) {
             uint8_t h = hm[l];
@@ -1889,6 +1892,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
     int8x16x4_t q3bytes;

+    uint16_t aux16[2];
+    int8_t * scales = (int8_t *)aux16;
+
     float sum = 0;

     for (int i = 0; i < nb; ++i) {
@@ -1899,8 +1905,13 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         const uint8x16_t   q3bits  = vld1q_u8(x[i].qs);
         const int8x16x4_t  q8bytes = vld1q_s8_x4(y[i].qs);

-        const int8_t * restrict scales = x[i].scales;
-        int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[1] * y[i].bsums[1] + scales[2] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        for (int j = 0; j < 4; ++j) scales[j] -= 8;
+
+        int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);

         const float d = y[i].d * (float)x[i].d;

@@ -1917,8 +1928,8 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 #if defined(__ARM_FEATURE_DOTPROD)
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[1];
-        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[2];
+        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
+        isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
         isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
 #else
         const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@@ -1929,7 +1940,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
                                        vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
         const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
                                        vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[1] + vaddvq_s16(p2) * scales[2] + vaddvq_s16(p3) * scales[3];
+        isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3];
 #endif

         sum += d * isum;
@@ -1942,11 +1953,15 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri

     const __m256i m3 = _mm256_set1_epi8(3);
     const __m256i m1 = _mm256_set1_epi8(1);
+    const __m256i m8 = _mm256_set1_epi16(8);

     __m256 acc = _mm256_setzero_ps();

     uint64_t aux64;

+    uint16_t aux16[2];
+    const int8_t * aux8 = (const int8_t *)aux16;
+
     for (int i = 0; i < nb; ++i) {

         const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
@@ -1954,14 +1969,18 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         const uint8_t * restrict q3 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;

-        const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
-        const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8));
+        const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8));

-        // Set up scales
         memcpy(&aux64, x[i].hmask, 8);

-        __m256i q3h_0 = _mm256_set_m128i(_mm_set_epi64x(aux64 >> 3, aux64 >> 2), _mm_set_epi64x(aux64 >> 1, aux64 >> 0));
-        __m256i q3h_1 = _mm256_set_m128i(_mm_set_epi64x(aux64 >> 7, aux64 >> 6), _mm_set_epi64x(aux64 >> 5, aux64 >> 4));
+        const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
+        __m256i q3h_0 = _mm256_set_m128i(_mm_srli_epi16(haux, 2), haux);
+        __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4);
         q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2);
         q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2);

@@ -1969,8 +1988,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
         const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);

         // prepare low and high bits
-        const __m256i q3l_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits), m3);
-        const __m256i q3l_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q3bits, 6), _mm_srli_epi16(q3bits, 4)), m3);
+        const __m256i q3aux  = _mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits);
+        const __m256i q3l_0  = _mm256_and_si256(q3aux, m3);
+        const __m256i q3l_1  = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3);

         // load Q8 quants
         const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
@@ -2007,6 +2027,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
     int16_t aux16[8];
     float   sums [8];
     int32_t aux32[8];
+    int32_t scales[4];
     memset(sums, 0, 8*sizeof(float));

     float sumf = 0;
@@ -2026,14 +2047,18 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
             a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4);
         }

+        scales[0] = (x[i].scales[0] & 0xF) - 8;
+        scales[1] = (x[i].scales[0] >>  4) - 8;
+        scales[2] = (x[i].scales[1] & 0xF) - 8;
+        scales[3] = (x[i].scales[1] >>  4) - 8;
+
         memset(aux32, 0, 8*sizeof(int32_t));
         for (int j = 0; j < QK_K/16; ++j) {
             for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
             q8 += 8; a += 8;
             for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l];
             q8 += 8; a += 8;
-            int sc = x[i].scales[j];
-            for (int l = 0; l < 8; ++l) aux32[l] += sc * aux16[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l];
         }
         const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
         for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
diff --git a/k_quants.h b/k_quants.h
index 2f633389b6ec0..943b62ee57e0d 100644
--- a/k_quants.h
+++ b/k_quants.h
@@ -35,17 +35,23 @@ static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "w
 // weight is represented as x = a * q
 // 16 blocks of 16 elements each
 // Effectively 3.4375 bits per weight
+#ifdef GGML_QKK_64
 typedef struct {
     uint8_t hmask[QK_K/8]; // quants - high bit
     uint8_t qs[QK_K/4];    // quants - low 2 bits
-#ifdef GGML_QKK_64
-    int8_t  scales[K_SCALE_SIZE];
+    uint8_t scales[2];
+    ggml_fp16_t d;         // super-block scale
+} block_q3_K;
+static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
 #else
-    uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
-#endif
+typedef struct {
+    uint8_t hmask[QK_K/8]; // quants - high bit
+    uint8_t qs[QK_K/4];    // quants - low 2 bits
+    uint8_t scales[12];    // scales, quantized with 6 bits
     ggml_fp16_t d;         // super-block scale
 } block_q3_K;
-static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
+static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
+#endif

 // 4-bit quantization
 // 16 blocks of 32 elements each

From ce19b965f0d4dd0d26d706f37ef99dce8749de59 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sat, 24 Jun 2023 15:44:23 +0300
Subject: [PATCH 30/35] k_quants: switch Q4_K to 4-bit scales when QK_K = 64

Here the loss in accuracy is greater than for Q3_K, but the Q4_K points
still move further to the left on the perplexity vs size curve.
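To make the new layout concrete, a minimal dequantization sketch (an
illustration, not part of the diff; block_q4_K_64 is a stand-in name for
the QK_K == 64 variant of block_q4_K defined below, with the two fp16
super-block factors already converted to float):

    #include <stdint.h>

    typedef struct {
        float   d;         // super-block scale (x[i].d[0] in the diff)
        float   dmin;      // super-block min   (x[i].d[1] in the diff)
        uint8_t scales[2]; // per 32 weights: 4-bit scale (low), 4-bit min (high)
        uint8_t qs[32];    // 64 quants, 4 bits each, two per byte
    } block_q4_K_64;

    // x = d*sc*q - dmin*mn, applied per 32-weight block
    static void dequantize_q4_K_64(const block_q4_K_64 * x, float * y) {
        const float d1 = x->d * (x->scales[0] & 0xF), m1 = x->dmin * (x->scales[0] >> 4);
        const float d2 = x->d * (x->scales[1] & 0xF), m2 = x->dmin * (x->scales[1] >> 4);
        for (int l = 0; l < 32; ++l) {
            y[l +  0] = d1 * (x->qs[l] & 0xF) - m1; // low nibbles: first 32 weights
            y[l + 32] = d2 * (x->qs[l] >>  4) - m2; // high nibbles: last 32 weights
        }
    }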
---
 ggml-cuda.cu |  30 +++++++++-----
 k_quants.c   | 112 +++++++++++++++++++++++++++++++++------------------
 k_quants.h   |   7 ++--
 llama.cpp    |   3 ++
 4 files changed, 98 insertions(+), 54 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 22f13d3069d78..840c0a8633ea6 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -147,10 +147,11 @@ typedef struct {

 #ifdef GGML_QKK_64
 typedef struct {
-    half    d[2*QK_K/32]; // super-block scales/mins
+    half    d[2];         // super-block scales/mins
+    uint8_t scales[2];    // 4-bit block scales/mins
     uint8_t qs[QK_K/2];   // 4-bit quants
 } block_q4_K;
-static_assert(sizeof(block_q4_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2, "wrong q4_K block size/padding");
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
 #else
 typedef struct {
     half d;               // super-block scale for quantized scales
@@ -503,8 +504,10 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
     const int tid = threadIdx.x;
     const uint8_t * q = x[i].qs;
     float * y = yy + i*QK_K;
-    y[tid+ 0] = (float)x[i].d[0] * (q[tid] & 0xF) - (float)x[i].d[1];
-    y[tid+32] = (float)x[i].d[2] * (q[tid] >>  4) - (float)x[i].d[3];
+    const float d = (float)x[i].d[0];
+    const float m = (float)x[i].d[1];
+    y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
+    y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >>  4) - m * (x[i].scales[1] >> 4);
 #endif

 }
@@ -874,20 +877,25 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
 #else
     const int step = tid * K_QUANTS_PER_ITERATION;

+    uint16_t aux16[2];
+    const uint8_t * s = (const uint8_t *)aux16;
+
     float tmp = 0;

     for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
         const uint8_t * q = x[i].qs + step;
         const float   * y = yy + i*QK_K + step;
-        const half2   * d = (const half2 *)x[i].d;
-        float2 df1 = __half22float2(d[0]);
-        float2 df2 = __half22float2(d[1]);
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+        const float d = (float)x[i].d[0];
+        const float m = (float)x[i].d[1];
         float sum = 0.f;
         for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
-            sum += y[j+ 0] * (df1.x * (q[j+ 0] & 0xF) - df1.y)
-                 + y[j+16] * (df1.x * (q[j+16] & 0xF) - df1.y)
-                 + y[j+32] * (df2.x * (q[j+ 0] >>  4) - df2.y)
-                 + y[j+48] * (df2.x * (q[j+16] >>  4) - df2.y);
+            sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
+                 + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
+                 + y[j+32] * (d * s[1] * (q[j+ 0] >>  4) - m * s[3])
+                 + y[j+48] * (d * s[1] * (q[j+16] >>  4) - m * s[3]);
         }
         tmp += sum;
     }
diff --git a/k_quants.c b/k_quants.c
index c2c96d4a5da1e..3d395e59f4898 100644
--- a/k_quants.c
+++ b/k_quants.c
@@ -635,14 +635,11 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
     const int nb = k / QK_K;

     uint8_t L[QK_K];
-#if QK_K == 256
     float mins[QK_K/32];
     float scales[QK_K/32];
-#endif

     for (int i = 0; i < nb; i++) {

-#if QK_K == 256
         float max_scale = 0; // as we are deducting the min, scales are always positive
         float max_min = 0;
         for (int j = 0; j < QK_K/32; ++j) {
@@ -657,6 +654,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
             }
         }

+#if QK_K == 256
         float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
         float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
         for (int j = 0; j < QK_K/32; ++j) {
@@ -689,23 +687,37 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
             }
         }
 #else
-        for (int j = 0; j < QK_K/32; ++j) {
-            float min;
-            float scale = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &min, 5);
-            y[i].d[2*j+0] = ggml_fp32_to_fp16(scale);
-            y[i].d[2*j+1] = ggml_fp32_to_fp16(min);
-        }
+        const float s_factor = 15.f;
+        float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? s_factor/max_min   : 0.f;
+        int d1 = nearest_int(inv_scale*scales[0]);
+        int m1 = nearest_int(inv_min*mins[0]);
+        int d2 = nearest_int(inv_scale*scales[1]);
+        int m2 = nearest_int(inv_min*mins[1]);
+        y[i].scales[0] = d1 | (m1 << 4);
+        y[i].scales[1] = d2 | (m2 << 4);
+        y[i].d[0] = ggml_fp32_to_fp16(max_scale/s_factor);
+        y[i].d[1] = ggml_fp32_to_fp16(max_min/s_factor);
+
+        float sumlx = 0;
+        int   suml2 = 0;
         for (int j = 0; j < QK_K/32; ++j) {
-            const float d = ggml_fp16_to_fp32(y[i].d[2*j+0]);
+            const uint8_t sd = y[i].scales[j] & 0xF;
+            const uint8_t sm = y[i].scales[j] >>  4;
+            const float d = ggml_fp16_to_fp32(y[i].d[0]) * sd;
             if (!d) continue;
-            const float dm = ggml_fp16_to_fp32(y[i].d[2*j+1]);
+            const float m = ggml_fp16_to_fp32(y[i].d[1]) * sm;
             for (int ii = 0; ii < 32; ++ii) {
-                int l = nearest_int((x[32*j + ii] + dm)/d);
+                int l = nearest_int((x[32*j + ii] + m)/d);
                 l = MAX(0, MIN(15, l));
                 L[32*j + ii] = l;
+                sumlx += (x[32*j + ii] + m)*l*sd;
+                suml2 += l*l*sd*sd;
             }
         }
+        if (suml2) {
+            y[i].d[0] = ggml_fp32_to_fp16(sumlx/suml2);
+        }
 #endif
         uint8_t * q = y[i].qs;
         for (int j = 0; j < QK_K; j += 64) {
@@ -743,8 +755,10 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
             q += 32; is += 2;
         }
 #else
-        float d1 = ggml_fp16_to_fp32(x[i].d[0]), m1 = ggml_fp16_to_fp32(x[i].d[1]);
-        float d2 = ggml_fp16_to_fp32(x[i].d[2]), m2 = ggml_fp16_to_fp32(x[i].d[3]);
+        const float dall = ggml_fp16_to_fp32(x[i].d[0]);
+        const float mall = ggml_fp16_to_fp32(x[i].d[1]);
+        const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
+        const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
         for (int l = 0; l < 32; ++l) {
             y[l+ 0] = d1 * (q[l] & 0xF) - m1;
             y[l+32] = d2 * (q[l] >>  4) - m2;
@@ -1953,7 +1967,6 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri

     const __m256i m3 = _mm256_set1_epi8(3);
     const __m256i m1 = _mm256_set1_epi8(1);
-    const __m256i m8 = _mm256_set1_epi16(8);

     __m256 acc = _mm256_setzero_ps();

@@ -2182,7 +2195,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri

     for (int i = 0; i < nb; ++i) {

-#if QK_K == 256
         const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
         const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);

@@ -2192,10 +2204,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
         utmp[2] = uaux;
         utmp[0] &= kmask1;
-#else
-        // TODO
-        const float d = 0; const float dmin = 0;
-#endif

         const uint8_t * restrict q4 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
@@ -2326,15 +2334,22 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri

     float sum_mins = 0.f;

+    uint16_t aux16[2];
+    const uint8_t * restrict scales = (const uint8_t *)aux16;
+
    for (int i = 0; i < nb; ++i) {

         const uint8_t * restrict q4 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;

-        const float32x4_t dsc = vcvt_f32_f16(vld1_f16(x[i].d));
-        float summ = vgetq_lane_f32(dsc, 1) * (y[i].bsums[0] + y[i].bsums[1])
-                   + vgetq_lane_f32(dsc, 3) * (y[i].bsums[2] + y[i].bsums[3]);
-        sum_mins += y[i].d * summ;
+        const uint16_t * restrict a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]);
+        sum_mins += y[i].d * (float)x[i].d[1] * summi;
+
+        const float d = y[i].d * (float)x[i].d[0];

         const uint8x16x2_t q4bits = vld1q_u8_x2(q4);

@@ -2344,13 +2359,13 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
         q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));

         const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
-        const float sumf1 = vaddvq_s32(p1) * vgetq_lane_f32(dsc, 0);
+        const int32_t sumi1 = vaddvq_s32(p1) * scales[0];

         q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
         q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));

         const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
-        const float sumf2 = vaddvq_s32(p2) * vgetq_lane_f32(dsc, 2);
+        const int32_t sumi2 = vaddvq_s32(p2) * scales[1];

 #else
         q8bytes = vld1q_s8_x4(q8);
@@ -2360,7 +2375,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
                                        vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
         const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
                                        vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        float sumf1 = vaddvq_s16(vaddq_s16(p0, p1)) * vgetq_lane_f32(dsc, 0);
+        int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];

         q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
         q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
@@ -2368,10 +2383,10 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
                                        vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
         const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
                                        vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
-        float sumf2 = vaddvq_s16(vaddq_s16(p2, p3)) * vgetq_lane_f32(dsc, 2);
+        int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];

 #endif
-        sumf += y[i].d * (sumf1 + sumf2);
+        sumf += d * (sumi1 + sumi2);

     }

@@ -2380,16 +2395,25 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 #elif defined __AVX2__

     const __m256i m4 = _mm256_set1_epi8(0xF);
-    const __m256i m1 = _mm256_set1_epi16(1);

     __m256 acc = _mm256_setzero_ps();

     float summs = 0;

+    uint16_t aux16[2];
+    const uint8_t * scales = (const uint8_t *)aux16;
+
     for (int i = 0; i < nb; ++i) {

-        summs += y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) +
-                           ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3]));
+        const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
+        const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
+        const __m256 vd = _mm256_set1_ps(d);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));

         const uint8_t * restrict q4 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;

@@ -2404,11 +2428,11 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
         const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);

-        const __m256i p32l = _mm256_madd_epi16(m1, p16l);
-        const __m256i p32h = _mm256_madd_epi16(m1, p16h);
+        const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc);

-        acc = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * ggml_fp16_to_fp32(x[i].d[0])), _mm256_cvtepi32_ps(p32l), acc);
-        acc = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * ggml_fp16_to_fp32(x[i].d[2])), _mm256_cvtepi32_ps(p32h), acc);
+        const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc);

     }

@@ -2421,6 +2445,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
     float sums [8];
     memset(sums, 0, 8*sizeof(float));

+    uint16_t s16[2];
+    const uint8_t * restrict scales = (const uint8_t *)s16;
+
     float sumf = 0;
     for (int i = 0; i < nb; ++i) {
         const uint8_t * restrict q4 = x[i].qs;
@@ -2429,16 +2456,21 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
         for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF;
         for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >>  4;

-        sumf -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) +
-                          ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3]));
+        const uint16_t * restrict b = (const uint16_t *)x[i].scales;
+        s16[0] = b[0] & 0x0f0f;
+        s16[1] = (b[0] >> 4) & 0x0f0f;
+
+        sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);

         for (int j = 0; j < QK_K/32; ++j) {
-            const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[2*j]);
             for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
             q8 += 16; a += 16;
             for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
             q8 += 16; a += 16;
-            for (int l = 0; l < 8; ++l) sums[l] += d * (aux16[l] + aux16[l+8]);
+            const float dl = d * scales[j];
+            for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]);
         }
     }
     for (int l = 0; l < 8; ++l) sumf += sums[l];
diff --git a/k_quants.h b/k_quants.h
index 943b62ee57e0d..6256ae167c378 100644
--- a/k_quants.h
+++ b/k_quants.h
@@ -59,10 +59,11 @@ static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 +
 // Effectively 4.5 bits per weight
 #ifdef GGML_QKK_64
 typedef struct {
-    ggml_fp16_t d[2*QK_K/32]; // super-block scales/mins
+    ggml_fp16_t d[2];         // super-block scales/mins
+    uint8_t scales[2];        // 4-bit block scales/mins
     uint8_t qs[QK_K/2];       // 4-bit quants
 } block_q4_K;
-static_assert(sizeof(block_q4_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2, "wrong q4_K block size/padding");
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
 #else
 typedef struct {
     ggml_fp16_t d;            // super-block scale for quantized scales
@@ -70,8 +71,8 @@ typedef struct {
     uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
     uint8_t qs[QK_K/2];           // 4-bit quants
 } block_q4_K;
-#endif
 static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
+#endif

 // 5-bit quantization
 // 16 blocks of 32 elements each
diff --git a/llama.cpp b/llama.cpp
index 78eb8427c3e4a..c41c2a8a32992 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2533,12 +2533,15 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                 else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
                 else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                         use_more_bits(i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K;
+                else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) &&
+                        (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K;
                 ++i_attention_wv;
             } else if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) {
                 if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
                 else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
                 else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                         use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
+                //else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < n_feed_forward_w2/8) new_type = GGML_TYPE_Q6_K;
                 ++i_feed_forward_w2;
             } else if (tensor.name.find("attention.wo.weight") != std::string::npos) {
                 if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;

From 4f61506929843c7b082dddfd4bd31f1114603542 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sat, 24 Jun 2023 15:47:27 +0300
Subject: [PATCH 31/35] k_quants: forgot to add the Metal changes in last commit

---
 ggml-metal.metal | 30 ++++++++++++++++++++++--------
 1 file changed, 22 insertions(+), 8 deletions(-)

diff --git a/ggml-metal.metal b/ggml-metal.metal
index 7185b7f2cf188..0a170531b8036 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -808,7 +808,8 @@ typedef struct {

 #if QK_K == 64
 typedef struct {
-    half4   d;          // super-block scales/mins
+    half    d[2];       // super-block scales/mins
+    uint8_t scales[2];
     uint8_t qs[QK_K/2]; // 4-bit quants
 } block_q4_K;
 #else
@@ -996,7 +997,6 @@ static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, i
     assert(k % QK_K == 0);
     const int nb = k / QK_K;

-
     for (int i = 0; i < nb; i++) {

         device const uint8_t * q = x[i].qs;
@@ -1017,10 +1017,16 @@ static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, i
             q += 32; is += 2;
         }
 #else
-        float4 d4 = (float4)x[i].d;
+        device const uint8_t * s  = x[i].scales;
+        device const half2   * dh = (device const half2 *)x[i].d;
+        const float2 d = (float2)dh[0];
+        const float d1 = d[0] * (s[0] & 0xF);
+        const float d2 = d[0] * (s[1] & 0xF);
+        const float m1 = d[1] * (s[0] >> 4);
+        const float m2 = d[1] * (s[1] >> 4);
         for (int l = 0; l < 32; ++l) {
-            y[l+ 0] = d4[0] * (q[l] & 0xF) - d4[1];
-            y[l+32] = d4[2] * (q[l] >>  4) - d4[3];
+            y[l+ 0] = d1 * (q[l] & 0xF) - m1;
+            y[l+32] = d2 * (q[l] >>  4) - m2;
         }
         y += QK_K;
 #endif
@@ -1529,6 +1535,9 @@ kernel void kernel_mul_mat_q4_K_f32(
     }
 #else

+    uint16_t aux16[2];
+    thread const uint8_t * scales = (thread const uint8_t *)aux16;
+
     const int il = 4*tpitg.x;

     for (int i = tpitg.y; i < nb; i += tptg.y) {
@@ -1536,11 +1545,16 @@
         device const uint8_t * q = x[i].qs + il;
         device const float   * y = yy + i * QK_K + il;

-        const float4 d4 = (float4)x[i].d;
+        const float d = (float)x[i].d[0];
+        const float m = (float)x[i].d[1];
+
+        device const uint16_t * a = (device const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;

         for (int l = 0; l < 4; ++l) {
-            sumf += d4[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - d4[1] * (y[l+ 0] + y[l+16])
-                  + d4[2] * (y[l+32] * (q[l] >>  4) + y[l+48] * (q[l+16] >>  4)) - d4[3] * (y[l+32] + y[l+48]);
+            sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
+                  + d * scales[1] * (y[l+32] * (q[l] >>  4) + y[l+48] * (q[l+16] >>  4)) - m * scales[3] * (y[l+32] + y[l+48]);
         }
     }
 #endif

From ccf4901334b8abaeded589a75eb2ecb09c1b81dd Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sat, 24 Jun 2023 17:39:25 +0300
Subject: [PATCH 32/35] k_quants: change Q5_K to be type 0 when QK_K = 64

Still needs AVX2 implementation
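Here "type 0" means a weight is reconstructed from a block scale alone,
while "type 1" keeps a separate min, as in the Q4_K code above. A schematic
sketch (an illustration, not part of the diff):

    // type 0 (what QK_K == 64 Q5_K becomes): x = d * q, signed quants
    static float dequant_type0(float d, int q)          { return d * q; }

    // type 1 (what QK_K == 256 Q5_K remains): x = d * q - m, unsigned quants
    static float dequant_type1(float d, float m, int q) { return d * q - m; }

Dropping the per-block mins is what shrinks the QK_K == 64 q5_K block to a
single fp16 super-block scale plus QK_K/16 8-bit block scales.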
---
 ggml-cuda.cu     |  32 +++++++-------
 ggml-metal.metal |  46 ++++++++++---------
 k_quants.c       | 113 +++++++++++++++++++++++++----------------------
 k_quants.h       |   5 ++-
 4 files changed, 103 insertions(+), 93 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 840c0a8633ea6..a97d351a1dbf5 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -164,11 +164,12 @@ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2,

 #ifdef GGML_QKK_64
 typedef struct {
-    half    d[2*QK_K/32];    // super-block scales/mins
-    uint8_t qh[QK_K/8];      // quants, high bit
-    uint8_t qs[QK_K/2];      // quants, low 4 bits
+    half    d;               // super-block scale
+    int8_t  scales[QK_K/16]; // block scales
+    uint8_t qh[QK_K/8];      // quants, high bit
+    uint8_t qs[QK_K/2];      // quants, low 4 bits
 } block_q5_K;
-static_assert(sizeof(block_q5_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
 #else
 typedef struct {
     half d;                  // super-block scale for quantized scales
@@ -546,12 +547,14 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
 #else
     const int tid = threadIdx.x;
     const uint8_t q = x[i].qs[tid];
-    const int im = tid/8;  // 0...3
-    const int in = tid%8;  // 0...7
+    const int im = tid/8;  // 0...3
+    const int in = tid%8;  // 0...7
+    const int is = tid/16; // 0 or 1
     const uint8_t h = x[i].qh[in] >> im;
+    const float d = x[i].d;
     float * y = yy + i*QK_K + tid;
-    y[ 0] = (float)x[i].d[0] * ((q & 0xF) + ((h >> 0) & 1 ? 16 : 0)) - (float)x[i].d[1];
-    y[32] = (float)x[i].d[2] * ((q >>  4) + ((h >> 4) & 1 ? 16 : 0)) - (float)x[i].d[3];
+    y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
+    y[32] = d * x[i].scales[is+2] * ((q >>  4) - ((h >> 4) & 1 ? 0 : 16));
 #endif

 }
@@ -992,17 +995,16 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
     for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {

         const uint8_t * q = x[i].qs + step;
+        const int8_t  * s = x[i].scales;
         const float   * y = yy + i*QK_K + step;
-        const half2   * d = (const half2 *)x[i].d;
-        float2 df1 = __half22float2(d[0]);
-        float2 df2 = __half22float2(d[1]);
+        const float d = x[i].d;
         float sum = 0.f;
         for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
             const uint8_t h = x[i].qh[in+j] >> im;
-            sum += y[j+ 0] * (df1.x * ((q[j+ 0] & 0xF) + (((h >> 0) & 1) << 4)) - df1.y)
-                 + y[j+16] * (df1.x * ((q[j+16] & 0xF) + (((h >> 2) & 1) << 4)) - df1.y)
-                 + y[j+32] * (df2.x * ((q[j+ 0] >>  4) + (((h >> 4) & 1) << 4)) - df2.y)
-                 + y[j+48] * (df2.x * ((q[j+16] >>  4) + (((h >> 6) & 1) << 4)) - df2.y);
+            sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
+                 + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
+                 + y[j+32] * d * s[2] * ((q[j+ 0] >>  4) - ((h >> 4) & 1 ? 0 : 16))
+                 + y[j+48] * d * s[3] * ((q[j+16] >>  4) - ((h >> 6) & 1 ? 0 : 16));
         }
         tmp += sum;
     }
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 0a170531b8036..e62fe6842ea72 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -823,7 +823,8 @@ typedef struct {

 #if QK_K == 64
 typedef struct {
-    half4   d;               // super-block scales/mins
+    half    d;               // super-block scale
+    int8_t  scales[QK_K/16]; // 8-bit block scales
     uint8_t qh[QK_K/8];      // quants, high bit
     uint8_t qs[QK_K/2];      // quants, low 4 bits
 } block_q5_K;
@@ -1062,20 +1063,21 @@ static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, i
 #else
     for (int i = 0; i < nb; i++) {

-        const float4 d = (float4)x[i].d;
+        const float d = (float)x[i].d;

         device const uint8_t * ql = x[i].qs;
         device const uint8_t * qh = x[i].qh;
+        device const int8_t  * sc = x[i].scales;

         for (int l = 0; l < 8; ++l) {
-            y[l+ 0] = d[0] * ((ql[l+ 0] & 0xF) + (qh[l] & 0x01 ? 16 : 0)) - d[1];
-            y[l+ 8] = d[0] * ((ql[l+ 8] & 0xF) + (qh[l] & 0x02 ? 16 : 0)) - d[1];
-            y[l+16] = d[0] * ((ql[l+16] & 0xF) + (qh[l] & 0x04 ? 16 : 0)) - d[1];
-            y[l+24] = d[0] * ((ql[l+24] & 0xF) + (qh[l] & 0x08 ? 16 : 0)) - d[1];
-            y[l+32] = d[2] * ((ql[l+ 0] >>  4) + (qh[l] & 0x10 ? 16 : 0)) - d[3];
-            y[l+40] = d[2] * ((ql[l+ 8] >>  4) + (qh[l] & 0x20 ? 16 : 0)) - d[3];
-            y[l+48] = d[2] * ((ql[l+16] >>  4) + (qh[l] & 0x40 ? 16 : 0)) - d[3];
-            y[l+56] = d[2] * ((ql[l+24] >>  4) + (qh[l] & 0x80 ? 16 : 0)) - d[3];
+            y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
+            y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
+            y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
+            y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
+            y[l+32] = d * sc[2] * ((ql[l+ 0] >>  4) - (qh[l] & 0x10 ? 0 : 16));
+            y[l+40] = d * sc[2] * ((ql[l+ 8] >>  4) - (qh[l] & 0x20 ? 0 : 16));
+            y[l+48] = d * sc[3] * ((ql[l+16] >>  4) - (qh[l] & 0x40 ? 0 : 16));
+            y[l+56] = d * sc[3] * ((ql[l+24] >>  4) - (qh[l] & 0x80 ? 0 : 16));
         }
         y += QK_K;
     }
@@ -1336,12 +1338,6 @@ kernel void kernel_mul_mat_q3_K_f32(
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2  tptg[[threads_per_threadgroup]]) {

-    const uint16_t kmask1 = 0x0303;
-    const uint16_t kmask2 = 0x0f0f;
-
-    const uint8_t m3 = 3;
-    const int8_t  m4 = 4;
-
     const int nb = ne00/QK_K;

     const int64_t r0 = tgpig.x;
@@ -1355,6 +1351,12 @@ kernel void kernel_mul_mat_q3_K_f32(

 #if QK_K == 256

+    const uint8_t m3 = 3;
+    const int8_t  m4 = 4;
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
     const int tid = tpitg.y;        // expecting 16
     const int ip  = tid/8;          // 0 or 1
     const int il  = tid/2 - 4*ip;   // 0...3
@@ -1682,18 +1684,18 @@ kernel void kernel_mul_mat_q5_K_f32(

     for (int i = tpitg.y; i < nb; i += tptg.y) {

+        const float d = (float)x[i].d;
         device const uint8_t * q = x[i].qs + il;
         device const uint8_t * h = x[i].qh + in;
+        device const int8_t  * s = x[i].scales;
         device const float   * y = yy + i*QK_K + il;

-        const float4 d = (float4)x[i].d;
-
         for (int l = 0; l < 4; ++l) {
             const uint8_t hl = h[l] >> im;
-            sumf += y[l+ 0] * (d[0] * ((q[l+ 0] & 0xF) + (hl & 0x01 ? 16 : 0)) - d[1])
-                  + y[l+16] * (d[0] * ((q[l+16] & 0xF) + (hl & 0x04 ? 16 : 0)) - d[1])
-                  + y[l+32] * (d[2] * ((q[l+ 0] >>  4) + (hl & 0x10 ? 16 : 0)) - d[3])
-                  + y[l+48] * (d[2] * ((q[l+16] >>  4) + (hl & 0x40 ? 16 : 0)) - d[3]);
+            sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
+                  + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
+                  + y[l+32] * d * s[2] * ((q[l+ 0] >>  4) - (hl & 0x10 ? 0 : 16))
+                  + y[l+48] * d * s[3] * ((q[l+16] >>  4) - (hl & 0x40 ? 0 : 16));
         }
     }
 #endif
diff --git a/k_quants.c b/k_quants.c
index 3d395e59f4898..a6221df9e6101 100644
--- a/k_quants.c
+++ b/k_quants.c
@@ -792,10 +792,13 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
     assert(k % QK_K == 0);
     const int nb = k / QK_K;

-    uint8_t L[QK_K];
 #if QK_K == 256
+    uint8_t L[QK_K];
     float mins[QK_K/32];
     float scales[QK_K/32];
+#else
+    int8_t L[QK_K];
+    float scales[QK_K/16];
 #endif

     for (int i = 0; i < nb; i++) {
@@ -869,20 +872,30 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
             ql += 32;
         }
 #else
-        for (int j = 0; j < QK_K/32; ++j) {
-            float min;
-            float scale = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &min, 5);
-            y[i].d[2*j+0] = ggml_fp32_to_fp16(scale);
-            y[i].d[2*j+1] = ggml_fp32_to_fp16(min);
+        float max_scale = 0, amax = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
+            float abs_scale = fabsf(scales[j]);
+            if (abs_scale > amax) {
+                amax = abs_scale;
+                max_scale = scales[j];
+            }
         }
-        for (int j = 0; j < QK_K/32; ++j) {
-            const float d = ggml_fp16_to_fp32(y[i].d[2*j+0]);
+
+        float iscale = -128.f/max_scale;
+        for (int j = 0; j < QK_K/16; ++j) {
+            int l = nearest_int(iscale*scales[j]);
+            y[i].scales[j] = MAX(-128, MIN(127, l));
+        }
+        y[i].d = ggml_fp32_to_fp16(1/iscale);
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
             if (!d) continue;
-            const float dm = ggml_fp16_to_fp32(y[i].d[2*j+1]);
-            for (int ii = 0; ii < 32; ++ii) {
-                int l = nearest_int((x[32*j + ii] + dm)/d);
-                l = MAX(0, MIN(31, l));
-                L[32*j + ii] = l;
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-16, MIN(15, l));
+                L[16*j + ii] = l + 16;
             }
         }
@@ -938,17 +951,17 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
             u1 <<= 2; u2 <<= 2;
         }
 #else
-        float d1 = ggml_fp16_to_fp32(x[i].d[0]), m1 = ggml_fp16_to_fp32(x[i].d[1]);
-        float d2 = ggml_fp16_to_fp32(x[i].d[2]), m2 = ggml_fp16_to_fp32(x[i].d[3]);
+        float d = ggml_fp16_to_fp32(x[i].d);
+        const int8_t * restrict s = x[i].scales;
         for (int l = 0; l < 8; ++l) {
-            y[l+ 0] = d1 * ((ql[l+ 0] & 0xF) + (qh[l] & 0x01 ? 16 : 0)) - m1;
-            y[l+ 8] = d1 * ((ql[l+ 8] & 0xF) + (qh[l] & 0x02 ? 16 : 0)) - m1;
-            y[l+16] = d1 * ((ql[l+16] & 0xF) + (qh[l] & 0x04 ? 16 : 0)) - m1;
-            y[l+24] = d1 * ((ql[l+24] & 0xF) + (qh[l] & 0x08 ? 16 : 0)) - m1;
-            y[l+32] = d2 * ((ql[l+ 0] >>  4) + (qh[l] & 0x10 ? 16 : 0)) - m2;
-            y[l+40] = d2 * ((ql[l+ 8] >>  4) + (qh[l] & 0x20 ? 16 : 0)) - m2;
-            y[l+48] = d2 * ((ql[l+16] >>  4) + (qh[l] & 0x40 ? 16 : 0)) - m2;
-            y[l+56] = d2 * ((ql[l+24] >>  4) + (qh[l] & 0x80 ? 16 : 0)) - m2;
+            y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
+            y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
+            y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
+            y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
+            y[l+32] = d * s[2] * ((ql[l+ 0] >>  4) - (qh[l] & 0x10 ? 0 : 16));
+            y[l+40] = d * s[2] * ((ql[l+ 8] >>  4) - (qh[l] & 0x20 ? 0 : 16));
+            y[l+48] = d * s[3] * ((ql[l+16] >>  4) - (qh[l] & 0x40 ? 0 : 16));
+            y[l+56] = d * s[3] * ((ql[l+24] >>  4) - (qh[l] & 0x80 ? 0 : 16));
         }
         y += QK_K;
 #endif
@@ -2751,19 +2764,12 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri

     float sumf = 0;

-    float32x4_t acc1 = vdupq_n_f32(0.f);
-    float32x4_t acc2 = vdupq_n_f32(0.f);
-
     for (int i = 0; i < nb; ++i) {

-        const float16x4_t s16 = vld1_f16(x[i].d);
-        const float32x4_t s32 = vmulq_n_f32(vcvt_f32_f16(s16), y[i].d);
-        //const int16x4_t bi16 = vld1_s16(y[i].bsums);
-        //const int32x4_t bi32 = vmovl_s16(vpadd_s16(bi16, bi16));
-        //const float32x4_t bf32 = vcvtq_f32_s32(bi32);
-        //sumf -= (vgetq_lane_f32(s32, 1) * vgetq_lane_f32(bf32, 0) + vgetq_lane_f32(s32, 3) * vgetq_lane_f32(bf32, 1));
-        // The above is slightly slower than just this:
-        sumf -= (vgetq_lane_f32(s32, 1) * (y[i].bsums[0] + y[i].bsums[1]) + vgetq_lane_f32(s32, 3) * (y[i].bsums[2] + y[i].bsums[3]));
+        const float d = y[i].d * (float)x[i].d;
+        const int8_t * sc = x[i].scales;
+
+        sumf -= 16.f * d * (sc[0] * y[i].bsums[0] + sc[1] * y[i].bsums[1] + sc[2] * y[i].bsums[2] + sc[3] * y[i].bsums[3]);

         const uint8_t * restrict q5 = x[i].qs;
         const uint8_t * restrict qh = x[i].qh;
@@ -2787,34 +2793,35 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri

 #if defined(__ARM_FEATURE_DOTPROD)

-        acc1 = vmlaq_n_f32(acc1, vcvtq_f32_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])),
-                vgetq_lane_f32(s32, 0));
-        acc2 = vmlaq_n_f32(acc2, vcvtq_f32_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])),
-                vgetq_lane_f32(s32, 2));
+        int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
+        int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
+        int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
+        int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
+
+        sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
+
 #else

         const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
                                        vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
         const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
                                        vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
-        const int16x8_t p01_16 = vaddq_s16(p0, p1);
-        const int32x4_t p01_32 = vaddq_s32(vmovl_s16(vget_low_s16(p01_16)), vmovl_s16(vget_high_s16(p01_16)));
-        acc1 = vmlaq_n_f32(acc1, vcvtq_f32_s32(p01_32), vgetq_lane_f32(s32, 0));
+        int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1);

         const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
                                        vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
         const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
                                        vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
-        const int16x8_t p02_16 = vaddq_s16(p2, p3);
-        const int32x4_t p02_32 = vaddq_s32(vmovl_s16(vget_low_s16(p02_16)), vmovl_s16(vget_high_s16(p02_16)));
-        acc2 = vmlaq_n_f32(acc2, vcvtq_f32_s32(p02_32), vgetq_lane_f32(s32, 2));
+        sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3);
+
+        sumf += d*sumi;

 #endif
     }

-    *s = vaddvq_f32(vaddq_f32(acc1, acc2)) + sumf;
+    *s = sumf;

-#elif defined __AVX2__
+#elif defined z__AVX2__

     const __m256i m4 = _mm256_set1_epi8(0xF);
     const __m256i m1 = _mm256_set1_epi16(1);
@@ -2884,19 +2891,17 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         }
         for (int is = 0; is < 8; ++is) {
             uint8_t m = 1 << is;
-            for (int l = 0; l < 8; ++l) a[8*is + l] += (hm[l] & m ? 16 : 0);
+            for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16);
         }

-        sumf -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) +
-                          ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3]));
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const int8_t * restrict sc = x[i].scales;

-        for (int j = 0; j < QK_K/32; ++j) {
-            const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[2*j]);
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float dl = d * sc[j];
             for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l <  8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]);
             q8 += 16; a += 16;
-            for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l];
-            q8 += 16; a += 16;
-            for (int l = 0; l < 8; ++l) sums[l] += d * (aux16[l] + aux16[8+l]);
         }
     }
     for (int l = 0; l < 8; ++l) sumf += sums[l];
diff --git a/k_quants.h b/k_quants.h
index 6256ae167c378..6abe3d7b8f422 100644
--- a/k_quants.h
+++ b/k_quants.h
@@ -80,11 +80,12 @@ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/
 // Effectively 5.5 bits per weight
 #ifdef GGML_QKK_64
 typedef struct {
-    ggml_fp16_t d[2*QK_K/32];   // super-block scales/mins
+    ggml_fp16_t d;              // super-block scale
+    int8_t  scales[QK_K/16];    // 8-bit block scales
     uint8_t qh[QK_K/8];         // quants, high bit
     uint8_t qs[QK_K/2];         // quants, low 4 bits
 } block_q5_K;
-static_assert(sizeof(block_q5_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
 #else
 typedef struct {
     ggml_fp16_t d;              // super-block scale for quantized scales

From 2da3a597082ed7da9ab1a4363932518772024a41 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sat, 24 Jun 2023 18:17:35 +0300
Subject: [PATCH 33/35] k_quants: AVX2 implementation for new 64-weight Q5_K

---
 k_quants.c | 44 ++++++++++++++++++------------------------
 1 file changed, 20 insertions(+), 24 deletions(-)

diff --git a/k_quants.c b/k_quants.c
index a6221df9e6101..d913b8760ef00 100644
--- a/k_quants.c
+++ b/k_quants.c
@@ -2821,55 +2821,51 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri

     *s = sumf;

-#elif defined z__AVX2__
+#elif defined __AVX2__

     const __m256i m4 = _mm256_set1_epi8(0xF);
-    const __m256i m1 = _mm256_set1_epi16(1);
     const __m256i mone = _mm256_set1_epi8(1);

     __m256 acc = _mm256_setzero_ps();

-    float summs = 0.f;
-
     for (int i = 0; i < nb; ++i) {

         const uint8_t * restrict q5 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;

-        const float d1 = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);
-        const float d2 = y[i].d * ggml_fp16_to_fp32(x[i].d[2]);
-        summs -= y[i].d * (ggml_fp16_to_fp32(x[i].d[1]) * (y[i].bsums[0] + y[i].bsums[1]) +
-                           ggml_fp16_to_fp32(x[i].d[3]) * (y[i].bsums[2] + y[i].bsums[3]));
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);

-        const __m256i q5bits = _mm256_loadu_si256((const __m256i*)(q5+ 0));
+        const __m256i scale_l = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0]));
+        const __m256i scale_h = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2]));

-        const uint64_t * restrict sc = (const uint64_t *)x[i].qh;
-        const __m128i hbits_0 = _mm_set_epi64x(sc[0] >> 1, sc[0] >> 0);
-        const __m128i hbits_1 = _mm_set_epi64x(sc[0] >> 3, sc[0] >> 2);
-        const __m128i hbits_2 = _mm_set_epi64x(sc[0] >> 5, sc[0] >> 4);
-        const __m128i hbits_3 = _mm_set_epi64x(sc[0] >> 7, sc[0] >> 6);
+        int64_t aux64;
+        memcpy(&aux64, x[i].qh, 8);
+        const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64);
+        const __m256i haux256 = _mm256_set_m128i(_mm_srli_epi16(haux128, 2), haux128);

-        const __m256i q5h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(hbits_1, hbits_0), mone), 4);
-        const __m256i q5h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(hbits_3, hbits_2), mone), 4);
+        const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4);
+        const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4);

         const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
         const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);

-        const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
-        const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
-
         const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
         const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));

-        const __m256i p16_0 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q5_0, q8_0));
-        const __m256i p16_1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q5_1, q8_1));
+        const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0));
+        const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1));
+        const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0));
+        const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1));
+
+        const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1));

-        acc = _mm256_fmadd_ps(_mm256_set1_ps(d1), _mm256_cvtepi32_ps(p16_0), acc);
-        acc = _mm256_fmadd_ps(_mm256_set1_ps(d2), _mm256_cvtepi32_ps(p16_1), acc);
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc);

     }

-    *s = hsum_float_8(acc) + summs;
+    *s = hsum_float_8(acc);

 #else

From 53e81ca28944f0d86910a30ee959f2edfa24277b Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sat, 24 Jun 2023 18:32:38 +0300
Subject: [PATCH 34/35] k_quants: 10% faster ARM_NEON Q5_K dot product
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); + q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4)); + q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2)); + q5h.val[2] = vbicq_u8(mh, htmp); + q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2)); + + q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0])); + q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1])); + q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2])); + q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3])); #if defined(__ARM_FEATURE_DOTPROD) From 5fd83379ff1b37a1ebfd6e3cd9f5672990d30eb5 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 26 Jun 2023 13:46:24 +0300 Subject: [PATCH 35/35] k_quants: fixed issue caused by merging with master --- ggml-cuda.cu | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index a97d351a1dbf5..a8b4a2e368141 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -706,7 +706,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); } - if (tid == 0) { + if (threadIdx.x == 0) { dst[row] = tmp; } } @@ -823,9 +823,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float const int num_blocks_per_row = ncols / QK_K; const int ib0 = row*num_blocks_per_row; - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 - const block_q4_K * x = (const block_q4_K *)vx + ib0; #if QK_K == 256 @@ -833,6 +830,9 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 + const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 const int il = tid/step; // 0...3 @@ -878,6 +878,9 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float } #else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); + const int step = tid * K_QUANTS_PER_ITERATION; uint16_t aux16[2];
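
A note for readers following the QK_K == 64 Q5_K changes above: in the new block layout the high bit no longer adds 16 on top of the low nibble (which previously required the bsums-based correction that these patches delete); instead, 16 is subtracted wherever the high bit is clear, so every quant is a signed value in [-16, 15] and the four signed 8-bit block scales apply directly. The following is a minimal, self-contained C sketch of that reference logic for a single super-block. The type and function names are hypothetical stand-ins, and a plain float replaces the ggml_fp16_t super-block scale so the sketch compiles on its own; it mirrors the scalar path in the patches, not the SIMD kernels.

#include <stdint.h>

#define QK_K 64

/* Hypothetical stand-in for the QKK_64 block_q5_K layout. */
typedef struct {
    float   d;                // super-block scale (fp16 in the real struct)
    int8_t  scales[QK_K/16];  // four signed 8-bit block scales
    uint8_t qh[QK_K/8];       // high bits, one bit per quant
    uint8_t qs[QK_K/2];       // low 4 bits, two quants per byte
} q5_K_64_sketch;

/* Dot product of one 64-weight block with 64 int8 activations scaled by d8,
 * mirroring the patched reference path: quant = low4 - (high bit ? 0 : 16). */
static float q5_K_64_dot(const q5_K_64_sketch * x, const int8_t * q8, float d8) {
    int8_t a[QK_K];
    for (int l = 0; l < QK_K/2; ++l) {        // unpack the low nibbles:
        a[l]          = x->qs[l] & 0xF;       // quants  0..31
        a[l + QK_K/2] = x->qs[l] >> 4;        // quants 32..63
    }
    for (int is = 0; is < 8; ++is) {          // bit `is` of qh[l] is the
        const uint8_t m = 1 << is;            // high bit of quant 8*is + l
        for (int l = 0; l < 8; ++l) a[8*is + l] -= (x->qh[l] & m) ? 0 : 16;
    }
    float sumf = 0.f;
    for (int j = 0; j < QK_K/16; ++j) {       // one signed scale per 16 quants
        const float dl = d8 * x->d * x->scales[j];
        int sumi = 0;
        for (int l = 0; l < 16; ++l) sumi += q8[16*j + l] * a[16*j + l];
        sumf += dl * (float)sumi;
    }
    return sumf;
}

At QK_K == 64 this block is 2 + 4 + 8 + 32 = 46 bytes for 64 weights, i.e. 5.75 bits per weight, slightly above the 5.5 bits of the 256-weight layout because the fp16 scale and the block scales amortize over fewer weights.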
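
The AVX2 kernel in PATCH 33 rests on a small bit-layout fact: after memcpy-ing the eight qh bytes into a 64-bit integer, bit 0 of byte l of (aux64 >> is) is exactly bit `is` of qh[l], i.e. the high bit of quant 8*is + l. Two 64-bit shifts plus two per-lane 16-bit shifts therefore spread all 64 high bits into the same byte order the low nibbles use, and the andnot against the 0x01 mask turns "high bit clear" into the 16 to be subtracted. Below is a self-contained C check of that identity, offered as a sketch with arbitrary example bytes (little-endian byte order assumed, as the memcpy in the patch also assumes):

#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    const uint8_t qh[8] = {0x5a, 0xc3, 0x0f, 0xf0, 0x81, 0x7e, 0x33, 0xcc}; // arbitrary
    uint64_t aux64;
    memcpy(&aux64, qh, 8);                 // same view the kernel takes of x[i].qh
    for (int is = 0; is < 8; ++is) {
        for (int l = 0; l < 8; ++l) {
            const int via_shift = (int)((aux64 >> is) >> (8*l)) & 1; // bit 0 of byte l
            const int direct    = (qh[l] >> is) & 1;                 // high bit of quant 8*is + l
            if (via_shift != direct) { printf("mismatch at is=%d l=%d\n", is, l); return 1; }
        }
    }
    printf("byte-wise shifts recover all 64 high bits\n");
    return 0;
}

The two-madd structure (p16 minus s16) follows from _mm256_maddubs_epi16 multiplying an unsigned byte operand by a signed one: the kernel cannot feed it signed quants directly, so it accumulates scale*(q5l*q8) and scale*(q5h*q8) separately and subtracts the two in 32-bit precision.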
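
The ARM_NEON speedup in PATCH 34 appears to come from the same sign trick: the removed sumf -= 16.f * d * (sc[0]*bsums[0] + ...) repair term becomes unnecessary once vbicq_u8(mh, h), i.e. mh & ~h, yields 16 exactly where a quant's high bit is clear, and vsubq_s8 folds the subtraction into the bytes before the dot products run. A scalar sketch of why the two forms agree; the 0x10 constant stands in for the mh mask, whose definition sits above the quoted hunk and is assumed from context:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    for (int low = 0; low < 16; ++low) {
        for (int bit = 0; bit <= 1; ++bit) {
            const uint8_t h  = (uint8_t)(bit << 4);  // high bit pre-shifted to 0x10, as in htmp
            const int old_q5 = (low | h) - 16;       // old path: OR the bit in, repair -16 via bsums
            const int new_q5 = low - (0x10 & ~h);    // new path: BIC mask, subtract per byte
            if (old_q5 != new_q5) { printf("mismatch\n"); return 1; }
        }
    }
    printf("bic form matches or-plus-repair form for all 32 quant values\n");
    return 0;
}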