diff --git a/ggml-cuda/dequantize.cuh b/ggml-cuda/dequantize.cuh
index bd3c2d9db9463..4c735e9774282 100644
--- a/ggml-cuda/dequantize.cuh
+++ b/ggml-cuda/dequantize.cuh
@@ -101,3 +101,11 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
     v.y *= d;
 #endif // GGML_CUDA_F16
 }
+
+static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const half * x = (const half *) vx;
+
+    // automatic half -> float type cast if dfloat == float
+    v.x = x[ib + iqs + 0];
+    v.y = x[ib + iqs + 1];
+}
diff --git a/ggml-cuda/dmmv.cu b/ggml-cuda/dmmv.cu
index 7313e3e175367..be02b688d175b 100644
--- a/ggml-cuda/dmmv.cu
+++ b/ggml-cuda/dmmv.cu
@@ -565,14 +565,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
     }
 }
 
-static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
-    const half * x = (const half *) vx;
-
-    // automatic half -> float type cast if dfloat == float
-    v.x = x[ib + iqs + 0];
-    v.y = x[ib + iqs + 1];
-}
-
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
     // qk = quantized weights per x block
diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh
index 1dd519bdee7f1..e08119c14c3ab 100644
--- a/ggml-cuda/fattn-common.cuh
+++ b/ggml-cuda/fattn-common.cuh
@@ -94,8 +94,6 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
     ggml_tensor * KQV = dst;
 
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
-    GGML_ASSERT(K->type == GGML_TYPE_F16);
-    GGML_ASSERT(V->type == GGML_TYPE_F16);
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu
index 4a07ac6adad71..90746a7111913 100644
--- a/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml-cuda/fattn-tile-f16.cu
@@ -1,10 +1,13 @@
 #include "common.cuh"
+#include "dequantize.cuh"
 #include "fattn-common.cuh"
 #include "fattn-tile-f16.cuh"
 
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks,
+         typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k,
+         typename type_v, int qkv, int qrv, dequantize_kernel_t dequantize_v>
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -47,12 +50,14 @@ static __global__ void flash_attn_tile_ext_f16(
     const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
-    const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
-    const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
+    const float2 * Q_f2 = (const float2 *) Q_f;
+    const type_k * K_h = (const type_k *) (K + nb12*(blockIdx.y / gqa_ratio));
+    const type_v * V_h = (const type_v *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(type_v)*qkk/(sizeof(type_k)*qkv)); // K and V have same shape
     const half * maskh = (const half *) mask + ne11*ic0;
 
-    const int stride_KV2 = nb11 / sizeof(half2);
+    const int stride_K = nb11/sizeof(type_k);
+    const int stride_V = stride_K*qkk/qkv;
 
     const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
     const half slopeh = __float2half(slopef);
 
@@ -79,12 +84,26 @@ static __global__ void flash_attn_tile_ext_f16(
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
 
+        if (qrk == 1) {
 #pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
 
-            const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
-            Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+                const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
+                Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+            }
+        } else {
+#pragma unroll
+            for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
+                const int i = i0 + 2*threadIdx.x;
+                const int iqs = (i%qkk)/qrk;
+                const int iybs = i - i%qkk;
+
+                float2 tmp;
+                tmp.x = Q_f[j*(nb01/sizeof(float)) + iybs + iqs + 0*qkk/2];
+                tmp.y = Q_f[j*(nb01/sizeof(float)) + iybs + iqs + 1*qkk/2];
+                Q_h2[j][i/2] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+            }
         }
     }
 
@@ -108,7 +127,9 @@ static __global__ void flash_attn_tile_ext_f16(
             for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
                 const int k_KQ = k_KQ_0 + threadIdx.x;
 
-                KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+                half2 tmp;
+                dequantize_k(K_h + (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, 0, ((2*k_KQ)%qkk)/qrk, tmp);
+                KV_tmp[i_KQ][k_KQ] = tmp;
             }
         }
 
@@ -196,7 +217,9 @@ static __global__ void flash_attn_tile_ext_f16(
            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
 
-                KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
+                half2 tmp;
+                dequantize_v(V_h + (k_VKQ_0 + k)*stride_V + (2*i)/qkv, 0, ((2*i)%qkv)/qrv, tmp);
+                KV_tmp[k][i] = tmp;
             }
         }
 
@@ -250,8 +273,16 @@ static __global__ void flash_attn_tile_ext_f16(
                 dst_val /= __half2half2(kqsum_j);
             }
             const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
-            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
+
+            if (qrv == 1) {
+                dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
+                dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
+            } else {
+                const int iqs = (i0%qkv)/qrv;
+                const int iybs = i0 - i0%qkv;
+                dst[j_dst*D*gridDim.y + D*blockIdx.y + iybs + iqs + 0*(qkv/2)] = __low2float(dst_val);
+                dst[j_dst*D*gridDim.y + D*blockIdx.y + iybs + iqs + 1*(qkv/2)] = __high2float(dst_val);
+            }
         }
 
         if (parallel_blocks != 1 && threadIdx.x == 0) {
@@ -263,20 +294,24 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks,
+          typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k,
+          typename type_v, int qkv, int qrv, dequantize_kernel_t dequantize_v>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case 64: {
             constexpr int D = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<
+                D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k, type_v, qkv, qrv, dequantize_v>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<
+                D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k, type_v, qkv, qrv, dequantize_v>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         } break;
         default: {
@@ -285,6 +320,71 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
     }
 }
 
+template <int cols_per_block, int parallel_blocks,
+          typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k>
+void launch_fattn_tile_f16_V_type(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * V = dst->src[2];
+
+    switch (V->type) {
+        case GGML_TYPE_Q4_0:
+            launch_fattn_tile_f16_64_128<
+                cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
+            break;
+        case GGML_TYPE_Q4_1:
+            launch_fattn_tile_f16_64_128<
+                cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q4_1, QK4_1, QR4_1, dequantize_q4_1>(ctx, dst);
+            break;
+        case GGML_TYPE_Q5_0:
+            launch_fattn_tile_f16_64_128<
+                cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q5_0, QK5_0, QR5_0, dequantize_q5_0>(ctx, dst);
+            break;
+        case GGML_TYPE_Q5_1:
+            launch_fattn_tile_f16_64_128<
+                cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q5_1, QK5_1, QR5_1, dequantize_q5_1>(ctx, dst);
+            break;
+        case GGML_TYPE_Q8_0:
+            launch_fattn_tile_f16_64_128<
+                cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
+            break;
+        case GGML_TYPE_F16:
+            launch_fattn_tile_f16_64_128<
+                cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, half2, 2, 1, convert_f16>(ctx, dst);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
+
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * K = dst->src[1];
+
+    switch (K->type) {
+        case GGML_TYPE_Q4_0:
+            launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, block_q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
+            break;
+        case GGML_TYPE_Q4_1:
+            launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, block_q4_1, QK4_1, QR4_1, dequantize_q4_1>(ctx, dst);
+            break;
+        case GGML_TYPE_Q5_0:
+            launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, block_q5_0, QK5_0, QR5_0, dequantize_q5_0>(ctx, dst);
+            break;
+        case GGML_TYPE_Q5_1:
+            launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, block_q5_1, QK5_1, QR5_1, dequantize_q5_1>(ctx, dst);
+            break;
+        case GGML_TYPE_Q8_0:
+            launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
+            break;
+        case GGML_TYPE_F16:
+            launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, half2, 2, 1, convert_f16>(ctx, dst);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
+
 void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
@@ -295,18 +395,18 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        launch_fattn_tile_f16_K_type<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        launch_fattn_tile_f16_K_type<cols_per_block, parallel_blocks>(ctx, dst);
        return;
     }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+    launch_fattn_tile_f16_K_type<cols_per_block, parallel_blocks>(ctx, dst);
 }
diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index af7c95232ddf3..046bed12986b3 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -457,11 +457,18 @@ void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int32_t precision = KQV->op_params[2];
 
+    if (true || ggml_is_quantized(K->type) || ggml_is_quantized(V->type)) {
+        ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+        return;
+    }
+
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
         if (precision == GGML_PREC_DEFAULT) {
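Note: all of the dequantize callbacks wired in above share the dequantize_kernel_t signature from common.cuh; moving convert_f16 into dequantize.cuh lets the FlashAttention kernel pick it up through the same template parameter as the quantized formats. Below is a minimal sketch, not part of the patch, of the index mapping that the qkk/qrk and qkv/qrv parameters encode; the helper name load_2_vals is made up for illustration and assumes the dequantize_kernel_t and dfloat2 typedefs from common.cuh.

// Sketch only: qk is the number of weights per block (e.g. QK8_0 == 32, or 2 when F16 data
// is read as half2), qr is the quantization ratio. A flat element index i maps onto a block
// and an in-block quant index; the callback then writes the pair v.x/v.y. For qr == 1 the
// pair is elements iqs and iqs+1 of the block, for qr == 2 it is elements iqs and iqs + qk/2,
// which is why the kernel scatters results with iybs + iqs + {0,1}*(qkv/2) when qrv != 1.
template <typename type_kv, int qk, int qr, dequantize_kernel_t dequantize>
static __device__ __forceinline__ void load_2_vals(const type_kv * x, const int i, dfloat2 & v) {
    const int ib  = i / qk;        // quantization block holding element i
    const int iqs = (i % qk) / qr; // quant index inside that block
    dequantize(x + ib, 0, iqs, v); // same call pattern as dequantize_k/dequantize_v above
}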