Skip to content

Commit 743cfdf

Browse files
vulkan: initial support for IQ4_XS quantization
1 parent eb7cf15 commit 743cfdf

13 files changed

+169
-13
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 25 additions & 0 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
1212
#endif
1313

1414
void main() {
15-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
15+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
1616
init_iq_shmem(gl_WorkGroupSize);
1717
if (gl_LocalInvocationIndex.x != 0) {
1818
return;

ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ void quantize(uint dst_idx, uint src_idx)
217217
#endif
218218

219219
void main() {
220-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
220+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
221221
init_iq_shmem(gl_WorkGroupSize);
222222
if (gl_LocalInvocationIndex.x != 0) {
223223
return;

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,42 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
304304
}
305305
#endif
306306

307+
#if defined(DATA_A_IQ4_XS)
308+
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
309+
const uint ib32 = iqs / 32;
310+
const uint iq = 16 * ib32 + (iqs % 16);
311+
312+
const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
313+
const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
314+
const uint qshift = (iqs & 16) >> 2;
315+
u8vec2 qs = u8vec2(data_a[a_offset + ib].qs[iq], data_a[a_offset + ib].qs[iq + 1]);
316+
qs = (qs >> qshift) & uint8_t(0xF);
317+
318+
const float dl = float(int(sl | (sh << 4)) - 32);
319+
return dl * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
320+
}
321+
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
322+
const uint ib32 = iqs / 32;
323+
const uint iq = 16 * ib32 + (iqs % 16);
324+
325+
const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
326+
const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
327+
const uint qshift = (iqs & 16) >> 2;
328+
u8vec4 qs = u8vec4(
329+
data_a[a_offset + ib].qs[iq + 0],
330+
data_a[a_offset + ib].qs[iq + 1],
331+
data_a[a_offset + ib].qs[iq + 2],
332+
data_a[a_offset + ib].qs[iq + 3]
333+
);
334+
qs = (qs >> qshift) & uint8_t(0xF);
335+
336+
const float dl = float(int(sl | (sh << 4)) - 32);
337+
return dl * vec4(
338+
kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y],
339+
kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]);
340+
}
341+
#endif
342+
307343
#if defined(DATA_A_IQ4_NL)
308344
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
309345
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@@ -321,7 +357,7 @@ vec2 get_dm(uint ib, uint a_offset) {
321357
}
322358
#endif
323359

324-
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
360+
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
325361
vec2 get_dm(uint ib, uint a_offset) {
326362
return vec2(float(data_a[a_offset + ib].d), 0);
327363
}

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,27 @@ float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords
454454
}
455455
#endif
456456

457+
#if defined(DATA_A_IQ4_XS)
458+
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_XS {
459+
block_iq4_xs block;
460+
};
461+
462+
float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
463+
{
464+
const float16_t d = bl.block.d;
465+
const uint idx = coordInBlock[1];
466+
467+
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
468+
469+
const uint sl = (bl.block.scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
470+
const uint sh = ((bl.block.scales_h) >> (2 * ib32)) & 3;
471+
const uint qshift = (idx & 16) >> 2;
472+
const uint q = (bl.block.qs[16 * ib32 + (idx % 16)] >> qshift) & 0xF;
473+
474+
float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]);
475+
return ret;
476+
}
477+
#endif
457478

458479
#if defined(DATA_A_IQ4_NL)
459480
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
@@ -504,6 +525,8 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
504525
#define dequantFuncA dequantFuncIQ3_XXS
505526
#elif defined(DATA_A_IQ3_S)
506527
#define dequantFuncA dequantFuncIQ3_S
528+
#elif defined(DATA_A_IQ4_XS)
529+
#define dequantFuncA dequantFuncIQ4_XS
507530
#elif defined(DATA_A_IQ4_NL)
508531
#define dequantFuncA dequantFuncIQ4_NL
509532
#endif
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#version 450
2+
3+
#include "dequant_head.comp"
4+
5+
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
6+
7+
layout (binding = 0) readonly buffer A {block_iq4_xs data_a[];};
8+
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
9+
10+
void main() {
11+
// Each thread handles 1 subblock (1 scale and 32 quantized values)
12+
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
13+
14+
init_iq_shmem(gl_WorkGroupSize);
15+
16+
if (ib >= p.nel / 256) {
17+
return;
18+
}
19+
20+
const uint ib32 = gl_LocalInvocationID.x % 8;
21+
22+
const float d = float(data_a[ib].d);
23+
// Scales are 6 bits
24+
const uint scale = ((data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF)
25+
| (((data_a[ib].scales_h >> (2 * ib32)) & 3) << 4);
26+
const float dl = d * (int(scale) - 32);
27+
28+
const uint b_idx = 256 * ib + 32 * ib32;
29+
const uint q_idx = 16 * ib32;
30+
[[unroll]] for (uint l = 0; l < 16; ++l) {
31+
data_b[b_idx + l + 0] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]);
32+
data_b[b_idx + l + 16] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]);
33+
}
34+
}

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
104104
#endif
105105

106106
void main() {
107-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
107+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
108108
init_iq_shmem(gl_WorkGroupSize);
109109
#endif
110110

ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ void main() {
1212
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
1313
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
1414

15-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
15+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
1616
init_iq_shmem(gl_WorkGroupSize);
1717
#endif
1818

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
133133
void main() {
134134
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
135135

136-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
136+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
137137
init_iq_shmem(gl_WorkGroupSize);
138138
#endif
139139

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
9595
#endif
9696

9797
void main() {
98-
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
98+
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
9999
init_iq_shmem(gl_WorkGroupSize);
100100
#endif
101101

@@ -547,6 +547,25 @@ void main() {
547547
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
548548
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
549549

550+
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
551+
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
552+
#elif defined(DATA_A_IQ4_XS)
553+
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
554+
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
555+
556+
const uint ib = idx / 128; // 2 values per idx
557+
const uint ib32 = (idx % 128) / 16; // 0..7
558+
const uint iq = 16 * ib32 + 2 * (idx % 8);
559+
560+
const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
561+
const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
562+
const uint qshift = (idx & 8) >> 1;
563+
u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]);
564+
qs = (qs >> qshift) & uint8_t(0xF);
565+
566+
const float d = float(data_a[ib].d);
567+
const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
568+
550569
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
551570
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
552571
#elif defined(DATA_A_IQ4_NL)

0 commit comments

Comments
 (0)