From 08d8a6b5286afdd9ff930c6a08abba8abef78902 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 19 May 2024 23:28:06 +0200 Subject: [PATCH 1/8] f16 still works --- ggml-cuda/dequantize.cuh | 8 ++++++++ ggml-cuda/dmmv.cu | 8 -------- ggml-cuda/fattn-common.cuh | 1 - ggml-cuda/fattn-tile-f16.cu | 14 +++++++++----- ggml-cuda/fattn.cu | 7 +++++++ 5 files changed, 24 insertions(+), 14 deletions(-) diff --git a/ggml-cuda/dequantize.cuh b/ggml-cuda/dequantize.cuh index bd3c2d9db9463..4c735e9774282 100644 --- a/ggml-cuda/dequantize.cuh +++ b/ggml-cuda/dequantize.cuh @@ -101,3 +101,11 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in v.y *= d; #endif // GGML_CUDA_F16 } + +static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ + const half * x = (const half *) vx; + + // automatic half -> float type cast if dfloat == float + v.x = x[ib + iqs + 0]; + v.y = x[ib + iqs + 1]; +} diff --git a/ggml-cuda/dmmv.cu b/ggml-cuda/dmmv.cu index 7313e3e175367..be02b688d175b 100644 --- a/ggml-cuda/dmmv.cu +++ b/ggml-cuda/dmmv.cu @@ -565,14 +565,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, } } -static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ - const half * x = (const half *) vx; - - // automatic half -> float type cast if dfloat == float - v.x = x[ib + iqs + 0]; - v.y = x[ib + iqs + 1]; -} - template static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { // qk = quantized weights per x block diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh index 1dd519bdee7f1..4adbcc6f4d4e8 100644 --- a/ggml-cuda/fattn-common.cuh +++ b/ggml-cuda/fattn-common.cuh @@ -94,7 +94,6 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern ggml_tensor * KQV = dst; GGML_ASSERT(Q->type == GGML_TYPE_F32); - GGML_ASSERT(K->type == GGML_TYPE_F16); GGML_ASSERT(V->type == GGML_TYPE_F16); GGML_ASSERT(KQV->type == GGML_TYPE_F32); diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu index 4a07ac6adad71..694ba81baf268 100644 --- a/ggml-cuda/fattn-tile-f16.cu +++ b/ggml-cuda/fattn-tile-f16.cu @@ -1,10 +1,11 @@ #include "common.cuh" +#include "dequantize.cuh" #include "fattn-common.cuh" #include "fattn-tile-f16.cuh" #define FATTN_KQ_STRIDE_TILE_F16 64 -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -48,10 +49,11 @@ static __global__ void flash_attn_tile_ext_f16( const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); + const type_k * K_h = (const type_k *) (K + nb12*(blockIdx.y / gqa_ratio)); const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape const half * maskh = (const half *) mask + ne11*ic0; + const int stride_K = nb11 / sizeof(type_k); const int stride_KV2 = nb11 / sizeof(half2); const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); @@ -108,7 +110,9 @@ static __global__ void flash_attn_tile_ext_f16( for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; - KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; + half2 tmp; + dequantize_k(K_h, (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, (2*k_KQ)%qkk, tmp); + KV_tmp[i_KQ][k_KQ] = tmp; } } @@ -270,13 +274,13 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * case 64: { constexpr int D = 64; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); } break; default: { diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index af7c95232ddf3..e6159ec69f425 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -457,11 +457,18 @@ void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const int32_t precision = KQV->op_params[2]; + if (ggml_is_quantized(K->type) || ggml_is_quantized(V->type)) { + ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); + return; + } + // On AMD the tile kernels perform poorly, use the vec kernel instead: if (cc >= CC_OFFSET_AMD) { if (precision == GGML_PREC_DEFAULT) { From 1dd185751ea6bbdee418b4d90b63d9a948f15db7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 20 May 2024 00:23:43 +0200 Subject: [PATCH 2/8] q8_0 k works --- ggml-cuda/fattn-tile-f16.cu | 36 +++++++++++++++++++++++++++--------- ggml-cuda/fattn.cu | 2 +- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu index 694ba81baf268..1765ef183412e 100644 --- a/ggml-cuda/fattn-tile-f16.cu +++ b/ggml-cuda/fattn-tile-f16.cu @@ -50,11 +50,11 @@ static __global__ void flash_attn_tile_ext_f16( const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0); const type_k * K_h = (const type_k *) (K + nb12*(blockIdx.y / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(half)*qkk/sizeof(type_k)); // K and V have same shape const half * maskh = (const half *) mask + ne11*ic0; const int stride_K = nb11 / sizeof(type_k); - const int stride_KV2 = nb11 / sizeof(half2); + const int stride_KV2 = nb11*qkk / (2*sizeof(type_k)); const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); @@ -111,7 +111,7 @@ static __global__ void flash_attn_tile_ext_f16( const int k_KQ = k_KQ_0 + threadIdx.x; half2 tmp; - dequantize_k(K_h, (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, (2*k_KQ)%qkk, tmp); + dequantize_k(K_h + (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, 0, (2*k_KQ)%qkk, tmp); KV_tmp[i_KQ][k_KQ] = tmp; } } @@ -267,20 +267,20 @@ static __global__ void flash_attn_tile_ext_f16( #endif // FP16_AVAILABLE } -template +template void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { case 64: { constexpr int D = 64; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); } break; default: { @@ -289,6 +289,24 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * } } + +template +void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * K = dst->src[1]; + + switch (K->type) { + case GGML_TYPE_Q8_0: + launch_fattn_tile_f16_64_128(ctx, dst); + break; + case GGML_TYPE_F16: + launch_fattn_tile_f16_64_128(ctx, dst); + break; + default: + GGML_ASSERT(false); + break; + } +} + void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; @@ -299,18 +317,18 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten if (Q->ne[1] <= 16) { constexpr int cols_per_block = 16; constexpr int parallel_blocks = 4; - launch_fattn_tile_f16_64_128(ctx, dst); + launch_fattn_tile_f16_K_type(ctx, dst); return; } if (Q->ne[1] <= 32) { constexpr int cols_per_block = 32; constexpr int parallel_blocks = 4; - launch_fattn_tile_f16_64_128(ctx, dst); + launch_fattn_tile_f16_K_type(ctx, dst); return; } constexpr int cols_per_block = 32; constexpr int parallel_blocks = 1; - launch_fattn_tile_f16_64_128(ctx, dst); + launch_fattn_tile_f16_K_type(ctx, dst); } diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index e6159ec69f425..046bed12986b3 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -464,7 +464,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const int32_t precision = KQV->op_params[2]; - if (ggml_is_quantized(K->type) || ggml_is_quantized(V->type)) { + if (true || ggml_is_quantized(K->type) || ggml_is_quantized(V->type)) { 
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); return; } From 1b49f47c2257038f8978c7bb7c85ea706619d15a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 20 May 2024 10:47:38 +0200 Subject: [PATCH 3/8] q4_0 works --- ggml-cuda/fattn-tile-f16.cu | 43 ++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu index 1765ef183412e..95e7022ee2275 100644 --- a/ggml-cuda/fattn-tile-f16.cu +++ b/ggml-cuda/fattn-tile-f16.cu @@ -5,7 +5,8 @@ #define FATTN_KQ_STRIDE_TILE_F16 64 -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -48,7 +49,8 @@ static __global__ void flash_attn_tile_ext_f16( const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0); + const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); + const float2 * Q_f2 = (const float2 *) Q_f; const type_k * K_h = (const type_k *) (K + nb12*(blockIdx.y / gqa_ratio)); const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(half)*qkk/sizeof(type_k)); // K and V have same shape const half * maskh = (const half *) mask + ne11*ic0; @@ -81,12 +83,26 @@ static __global__ void flash_attn_tile_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; + if (qrk == 1) { #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i]; - Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); + const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i]; + Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); + } + } else { +#pragma unroll + for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) { + const int i = i0 + 2*threadIdx.x; + const int iqs = (i%qkk)/qrk; + const int iybs = i - i%qkk; + + float2 tmp; + tmp.x = Q_f[j*(nb01/sizeof(float)) + iybs + iqs + 0*qkk/2]; + tmp.y = Q_f[j*(nb01/sizeof(float)) + iybs + iqs + 1*qkk/2]; + Q_h2[j][i/2] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); + } } } @@ -111,7 +127,7 @@ static __global__ void flash_attn_tile_ext_f16( const int k_KQ = k_KQ_0 + threadIdx.x; half2 tmp; - dequantize_k(K_h + (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, 0, (2*k_KQ)%qkk, tmp); + dequantize_k(K_h + (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, 0, ((2*k_KQ)%qkk)/qrk, tmp); KV_tmp[i_KQ][k_KQ] = tmp; } } @@ -267,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16( #endif // FP16_AVAILABLE } -template +template void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { case 64: { constexpr int D = 64; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; + fattn_kernel_t fattn_kernel = 
flash_attn_tile_ext_f16; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); } break; default: { @@ -295,11 +311,14 @@ void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor * const ggml_tensor * K = dst->src[1]; switch (K->type) { + case GGML_TYPE_Q4_0: + launch_fattn_tile_f16_64_128(ctx, dst); + break; case GGML_TYPE_Q8_0: - launch_fattn_tile_f16_64_128(ctx, dst); + launch_fattn_tile_f16_64_128(ctx, dst); break; case GGML_TYPE_F16: - launch_fattn_tile_f16_64_128(ctx, dst); + launch_fattn_tile_f16_64_128(ctx, dst); break; default: GGML_ASSERT(false); From ca6d82885cf8bcb87053519b6f8dfd6ea30aec47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 20 May 2024 10:54:42 +0200 Subject: [PATCH 4/8] FP16 V still works --- ggml-cuda/fattn-tile-f16.cu | 35 +++++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu index 95e7022ee2275..0bcfda7b464ec 100644 --- a/ggml-cuda/fattn-tile-f16.cu +++ b/ggml-cuda/fattn-tile-f16.cu @@ -5,8 +5,9 @@ #define FATTN_KQ_STRIDE_TILE_F16 64 -template // D == head size +template #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -283,20 +284,24 @@ static __global__ void flash_attn_tile_ext_f16( #endif // FP16_AVAILABLE } -template +template void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { case 64: { constexpr int D = 64; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16< + D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k, type_v, qkv, qrv, dequantize_v>; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16< + D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k, type_v, qkv, qrv, dequantize_v>; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); } break; default: { @@ -305,6 +310,20 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * } } +template +void launch_fattn_tile_f16_V_type(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * V = dst->src[2]; + + switch (V->type) { + case GGML_TYPE_F16: + launch_fattn_tile_f16_64_128(ctx, dst); + break; + default: + GGML_ASSERT(false); + break; + } +} template void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -312,13 +331,13 @@ void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor * switch (K->type) { case GGML_TYPE_Q4_0: - launch_fattn_tile_f16_64_128(ctx, dst); + launch_fattn_tile_f16_V_type(ctx, dst); break; case GGML_TYPE_Q8_0: - launch_fattn_tile_f16_64_128(ctx, dst); + launch_fattn_tile_f16_V_type(ctx, dst); break; case GGML_TYPE_F16: - launch_fattn_tile_f16_64_128(ctx, dst); + launch_fattn_tile_f16_V_type(ctx, dst); break; default: GGML_ASSERT(false); From 8a10e5c03cd6cf1be463ab57652b9e8f195663d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 20 May 2024 11:22:11 +0200 Subject: [PATCH 5/8] FP16 v still works --- ggml-cuda/fattn-tile-f16.cu | 10 
++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu index 0bcfda7b464ec..2e3baa793434e 100644 --- a/ggml-cuda/fattn-tile-f16.cu +++ b/ggml-cuda/fattn-tile-f16.cu @@ -53,11 +53,11 @@ static __global__ void flash_attn_tile_ext_f16( const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); const float2 * Q_f2 = (const float2 *) Q_f; const type_k * K_h = (const type_k *) (K + nb12*(blockIdx.y / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(half)*qkk/sizeof(type_k)); // K and V have same shape + const type_v * V_h = (const type_v *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(type_v)*qkk/(sizeof(type_k)*qkv)); // K and V have same shape const half * maskh = (const half *) mask + ne11*ic0; - const int stride_K = nb11 / sizeof(type_k); - const int stride_KV2 = nb11*qkk / (2*sizeof(type_k)); + const int stride_K = nb11 / sizeof(type_k); + const int stride_V = nb11*qkk / (sizeof(type_v)*qkv); const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); @@ -217,7 +217,9 @@ static __global__ void flash_attn_tile_ext_f16( for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i]; + half2 tmp; + dequantize_v(V_h + (k_VKQ_0 + k)*stride_V + (2*i)/qkv, 0, ((2*i)%qkv)/qrv, tmp); + KV_tmp[k][i] = tmp; } } From 75096c6e6ee51baffaa4de7c7f2fc889065fab9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 20 May 2024 11:29:29 +0200 Subject: [PATCH 6/8] q8_0 works --- ggml-cuda/fattn-common.cuh | 1 - ggml-cuda/fattn-tile-f16.cu | 11 ++++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh index 4adbcc6f4d4e8..e08119c14c3ab 100644 --- a/ggml-cuda/fattn-common.cuh +++ b/ggml-cuda/fattn-common.cuh @@ -94,7 +94,6 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern ggml_tensor * KQV = dst; GGML_ASSERT(Q->type == GGML_TYPE_F32); - GGML_ASSERT(V->type == GGML_TYPE_F16); GGML_ASSERT(KQV->type == GGML_TYPE_F32); GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu index 2e3baa793434e..734d676c2d809 100644 --- a/ggml-cuda/fattn-tile-f16.cu +++ b/ggml-cuda/fattn-tile-f16.cu @@ -56,8 +56,8 @@ static __global__ void flash_attn_tile_ext_f16( const type_v * V_h = (const type_v *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(type_v)*qkk/(sizeof(type_k)*qkv)); // K and V have same shape const half * maskh = (const half *) mask + ne11*ic0; - const int stride_K = nb11 / sizeof(type_k); - const int stride_V = nb11*qkk / (sizeof(type_v)*qkv); + const int stride_K = nb11/sizeof(type_k); + const int stride_V = stride_K*qkk/qkv; const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); @@ -318,8 +318,13 @@ void launch_fattn_tile_f16_V_type(ggml_backend_cuda_context & ctx, ggml_tensor * const ggml_tensor * V = dst->src[2]; switch (V->type) { + case GGML_TYPE_Q8_0: + launch_fattn_tile_f16_64_128< + cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst); + break; case GGML_TYPE_F16: - launch_fattn_tile_f16_64_128(ctx, dst); + launch_fattn_tile_f16_64_128< + cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, half2, 2, 1, 
convert_f16>(ctx, dst); break; default: GGML_ASSERT(false); From 14e80c413bb411143f82f47f34b7e86c1575a7a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 20 May 2024 11:46:02 +0200 Subject: [PATCH 7/8] q4_0 works --- ggml-cuda/fattn-tile-f16.cu | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu index 734d676c2d809..c6925c2a6a1d6 100644 --- a/ggml-cuda/fattn-tile-f16.cu +++ b/ggml-cuda/fattn-tile-f16.cu @@ -273,8 +273,16 @@ static __global__ void flash_attn_tile_ext_f16( dst_val /= __half2half2(kqsum_j); } const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val); - dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val); + + if (qrv == 1) { + dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val); + dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val); + } else { + const int iqs = (i0%qkv)/qrv; + const int iybs = i0 - i0%qkv; + dst[j_dst*D*gridDim.y + D*blockIdx.y + iybs + iqs + 0*(qkv/2)] = __low2float(dst_val); + dst[j_dst*D*gridDim.y + D*blockIdx.y + iybs + iqs + 1*(qkv/2)] = __high2float(dst_val); + } } if (parallel_blocks != 1 && threadIdx.x == 0) { @@ -318,6 +326,10 @@ void launch_fattn_tile_f16_V_type(ggml_backend_cuda_context & ctx, ggml_tensor * const ggml_tensor * V = dst->src[2]; switch (V->type) { + case GGML_TYPE_Q4_0: + launch_fattn_tile_f16_64_128< + cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst); + break; case GGML_TYPE_Q8_0: launch_fattn_tile_f16_64_128< cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst); From 75ef7619a258a95d26f8d68cbc81216bd4cee419 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 20 May 2024 11:52:10 +0200 Subject: [PATCH 8/8] add q4_1 q5_0 q5_1 support --- ggml-cuda/fattn-tile-f16.cu | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu index c6925c2a6a1d6..90746a7111913 100644 --- a/ggml-cuda/fattn-tile-f16.cu +++ b/ggml-cuda/fattn-tile-f16.cu @@ -330,6 +330,18 @@ void launch_fattn_tile_f16_V_type(ggml_backend_cuda_context & ctx, ggml_tensor * launch_fattn_tile_f16_64_128< cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst); break; + case GGML_TYPE_Q4_1: + launch_fattn_tile_f16_64_128< + cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q4_1, QK4_1, QR4_1, dequantize_q4_1>(ctx, dst); + break; + case GGML_TYPE_Q5_0: + launch_fattn_tile_f16_64_128< + cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q5_0, QK5_0, QR5_0, dequantize_q5_0>(ctx, dst); + break; + case GGML_TYPE_Q5_1: + launch_fattn_tile_f16_64_128< + cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q5_1, QK5_1, QR5_1, dequantize_q5_1>(ctx, dst); + break; case GGML_TYPE_Q8_0: launch_fattn_tile_f16_64_128< cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst); @@ -352,6 +364,15 @@ void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor * case GGML_TYPE_Q4_0: launch_fattn_tile_f16_V_type(ctx, dst); break; + case GGML_TYPE_Q4_1: + launch_fattn_tile_f16_V_type(ctx, dst); + break; + case GGML_TYPE_Q5_0: + 
launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, block_q5_0, QK5_0, QR5_0, dequantize_q5_0>(ctx, dst);
+ break;
+ case GGML_TYPE_Q5_1:
+ launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, block_q5_1, QK5_1, QR5_1, dequantize_q5_1>(ctx, dst);
+ break;
case GGML_TYPE_Q8_0:
launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
break;
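
Taken together, the patches above template flash_attn_tile_ext_f16 over the storage types of K and V: each supported type contributes its block struct (type_k/type_v), its weights-per-block count (qkk/qkv, the QK* constants), its quantization ratio (qrk/qrv, the QR* constants), and a dequantize_kernel_t that expands one quant pair into a half2 (convert_f16 covers the unquantized half2/2/1 case). The snippet below is only an illustrative sketch, not part of the patch series: it restates the index arithmetic used for the K and V loads, assuming the existing ggml-cuda typedefs (dequantize_kernel_t, and dfloat2 == half2 under GGML_CUDA_F16) and a hypothetical helper name, load_kv_pair. A flat element index i within a row is split into a block index i/qk and an intra-block quant index (i % qk)/qr. For qr == 2 types the two values a dequantizer returns come from one packed byte and sit qk/2 apart inside the block, so consecutive dequantized pairs are not consecutive elements; that is why the Q load (patch 3) and the dst write-back (patch 7) special-case qr == 1 and otherwise gather/scatter with the same iybs/iqs split.

#include "common.cuh"     // dequantize_kernel_t, half2 (assumed to live here)
#include "dequantize.cuh" // dequantize_q4_0 ... dequantize_q8_0, convert_f16

// Hypothetical helper, for illustration only: load one dequantized value pair
// from row `row` of a quantized K or V tensor whose row stride is `stride`
// blocks (nb11/sizeof(block_t)). `i` is the flat index of the first element.
template <typename block_t, int qk, int qr, dequantize_kernel_t dequantize>
static __device__ __forceinline__ half2 load_kv_pair(
        const block_t * x, const int row, const int stride, const int i) {
    const int ib  = i / qk;        // block within the row
    const int iqs = (i % qk) / qr; // quant (pair) index within the block
    half2 v;                       // assumes dfloat2 == half2 (GGML_CUDA_F16)
    dequantize(x + row*stride + ib, 0, iqs, v);
    return v;                      // halves are adjacent if qr == 1, qk/2 apart if qr == 2
}

For comparison, the K tile load in the kernel mirrors
    KV_tmp[i_KQ][k_KQ] = load_kv_pair<type_k, qkk, qrk, dequantize_k>(K_h, k_VKQ_0 + i_KQ, stride_K, 2*k_KQ);
and the V load uses the same arithmetic with the *_v parameters and stride_V.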