CUDA: quantized KV cache demo #7412

Closed
8 changes: 8 additions & 0 deletions ggml-cuda/dequantize.cuh
@@ -101,3 +101,11 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
v.y *= d;
#endif // GGML_CUDA_F16
}

static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const half * x = (const half *) vx;

// automatic half -> float type cast if dfloat == float
v.x = x[ib + iqs + 0];
v.y = x[ib + iqs + 1];
}
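
convert_f16 is moved into this header so that plain F16 data can be read through the same dequantize_kernel_t interface as the quantized block types. A minimal sketch of how a caller templated on that interface reads a pair of values — illustrative only, not part of this PR; it assumes the dfloat2 and dequantize_kernel_t definitions from common.cuh and an even element index i, which is how the flash-attention kernel below uses it:

// Illustrative helper, not in this PR: offset to the block that holds element i
// via typed pointer arithmetic, then let the callback pick the pair inside it.
// F16 uses the triple <half2, 2, 1, convert_f16>; q8_0 uses
// <block_q8_0, QK8_0, QR8_0, dequantize_q8_0>; and so on.
template <typename type_k, int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __device__ __forceinline__ dfloat2 read_pair(const type_k * x, const int i) {
    dfloat2 v;
    dequantize_kernel(x + i/qk, 0, (i % qk)/qr, v);
    return v; // the two values are adjacent if qr == 1 and qk/2 apart if qr == 2
}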
8 changes: 0 additions & 8 deletions ggml-cuda/dmmv.cu
@@ -565,14 +565,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
}
}

static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const half * x = (const half *) vx;

// automatic half -> float type cast if dfloat == float
v.x = x[ib + iqs + 0];
v.y = x[ib + iqs + 1];
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
// qk = quantized weights per x block
2 changes: 0 additions & 2 deletions ggml-cuda/fattn-common.cuh
@@ -94,8 +94,6 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
ggml_tensor * KQV = dst;

GGML_ASSERT(Q->type == GGML_TYPE_F32);
GGML_ASSERT(K->type == GGML_TYPE_F16);
GGML_ASSERT(V->type == GGML_TYPE_F16);
GGML_ASSERT(KQV->type == GGML_TYPE_F32);

GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
138 changes: 119 additions & 19 deletions ggml-cuda/fattn-tile-f16.cu
@@ -1,10 +1,13 @@
#include "common.cuh"
#include "dequantize.cuh"
#include "fattn-common.cuh"
#include "fattn-tile-f16.cuh"

#define FATTN_KQ_STRIDE_TILE_F16 64

template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
template<int D, int ncols, int nwarps, int parallel_blocks, // D == head size
typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k,
typename type_v, int qkv, int qrv, dequantize_kernel_t dequantize_v>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -47,12 +50,14 @@ static __global__ void flash_attn_tile_ext_f16(
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.

const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
const float2 * Q_f2 = (const float2 *) Q_f;
const type_k * K_h = (const type_k *) (K + nb12*(blockIdx.y / gqa_ratio));
const type_v * V_h = (const type_v *) (V + nb12*(blockIdx.y / gqa_ratio)*sizeof(type_v)*qkk/(sizeof(type_k)*qkv)); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;

const int stride_KV2 = nb11 / sizeof(half2);
const int stride_K = nb11/sizeof(type_k);
const int stride_V = stride_K*qkk/qkv;

const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
@@ -79,12 +84,26 @@ static __global__ void flash_attn_tile_ext_f16(
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;

if (qrk == 1) {
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;

const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
}
} else {
#pragma unroll
for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
const int i = i0 + 2*threadIdx.x;
const int iqs = (i%qkk)/qrk;
const int iybs = i - i%qkk;

float2 tmp;
tmp.x = Q_f[j*(nb01/sizeof(float)) + iybs + iqs + 0*qkk/2];
tmp.y = Q_f[j*(nb01/sizeof(float)) + iybs + iqs + 1*qkk/2];
Q_h2[j][i/2] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
}
}
}
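
For qrk > 1 the pair returned by dequantize_k is not two adjacent elements but elements iqs and iqs + qkk/2 of a block, so the else branch above loads Q with exactly the same pairing; each half2 product in the KQ accumulation then multiplies matching elements and the dot product is unchanged. A small host-side illustration of the mapping, assuming q4_0-style parameters (qk == 32, qr == 2); not part of this PR:

// Illustrative only: for even i, the arithmetic above pairs Q elements
// (iybs + iqs) and (iybs + iqs + qk/2) -- the same low/high-nibble pairing
// that dequantize_q4_0 produces for K.
#include <cstdio>

int main() {
    const int qk = 32, qr = 2;
    for (int i = 0; i < 8; i += 2) {
        const int iqs  = (i % qk) / qr; // byte index inside the block
        const int iybs = i - i % qk;    // first element of the block
        std::printf("i = %d -> Q elements (%d, %d)\n", i, iybs + iqs, iybs + iqs + qk/2);
    }
    return 0;
}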

@@ -108,7 +127,9 @@ static __global__ void flash_attn_tile_ext_f16(
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;

KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
half2 tmp;
dequantize_k(K_h + (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, 0, ((2*k_KQ)%qkk)/qrk, tmp);
KV_tmp[i_KQ][k_KQ] = tmp;
}
}

@@ -196,7 +217,9 @@ static __global__ void flash_attn_tile_ext_f16(
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;

KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
half2 tmp;
dequantize_v(V_h + (k_VKQ_0 + k)*stride_V + (2*i)/qkv, 0, ((2*i)%qkv)/qrv, tmp);
KV_tmp[k][i] = tmp;
}
}

Expand Down Expand Up @@ -250,8 +273,16 @@ static __global__ void flash_attn_tile_ext_f16(
dst_val /= __half2half2(kqsum_j);
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);

if (qrv == 1) {
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
} else {
const int iqs = (i0%qkv)/qrv;
const int iybs = i0 - i0%qkv;
dst[j_dst*D*gridDim.y + D*blockIdx.y + iybs + iqs + 0*(qkv/2)] = __low2float(dst_val);
dst[j_dst*D*gridDim.y + D*blockIdx.y + iybs + iqs + 1*(qkv/2)] = __high2float(dst_val);
}
}

if (parallel_blocks != 1 && threadIdx.x == 0) {
@@ -263,20 +294,24 @@ static __global__ void flash_attn_tile_ext_f16(
#endif // FP16_AVAILABLE
}
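
The same reasoning applies to the store at the end of the kernel: when V is quantized with qrv > 1, the VKQ accumulator at even index i0 holds output elements (iybs + iqs) and (iybs + iqs + qkv/2) rather than (i0, i0 + 1), and the else branch scatters them back to their true positions. A quick host-side check — illustrative only, assuming qkv == 32 and qrv == 2 — that this scatter writes every position of a block exactly once:

// Illustrative only: the pairing used by the quantized-V store is a permutation
// of one qkv-sized block, so no dst element is written twice or skipped.
#include <cstdio>

int main() {
    const int qkv = 32, qrv = 2;
    int hits[32] = {0};
    for (int i0 = 0; i0 < qkv; i0 += 2) {
        const int iqs  = (i0 % qkv) / qrv;
        const int iybs = i0 - i0 % qkv;
        ++hits[iybs + iqs + 0*(qkv/2)];
        ++hits[iybs + iqs + 1*(qkv/2)];
    }
    for (int j = 0; j < qkv; ++j) {
        if (hits[j] != 1) { std::printf("position %d written %d times\n", j, hits[j]); return 1; }
    }
    std::printf("all %d positions written exactly once\n", qkv);
    return 0;
}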

template <int cols_per_block, int parallel_blocks>
template <int cols_per_block, int parallel_blocks,
typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k,
typename type_v, int qkv, int qrv, dequantize_kernel_t dequantize_v>
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<
D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k, type_v, qkv, qrv, dequantize_v>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<
D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, qrk, dequantize_k, type_v, qkv, qrv, dequantize_v>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
} break;
default: {
@@ -285,6 +320,71 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
}
}

template <int cols_per_block, int parallel_blocks,
typename type_k, int qkk, int qrk, dequantize_kernel_t dequantize_k>
void launch_fattn_tile_f16_V_type(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * V = dst->src[2];

switch (V->type) {
case GGML_TYPE_Q4_0:
launch_fattn_tile_f16_64_128<
cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
break;
case GGML_TYPE_Q4_1:
launch_fattn_tile_f16_64_128<
cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q4_1, QK4_1, QR4_1, dequantize_q4_1>(ctx, dst);
break;
case GGML_TYPE_Q5_0:
launch_fattn_tile_f16_64_128<
cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q5_0, QK5_0, QR5_0, dequantize_q5_0>(ctx, dst);
break;
case GGML_TYPE_Q5_1:
launch_fattn_tile_f16_64_128<
cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q5_1, QK5_1, QR5_1, dequantize_q5_1>(ctx, dst);
break;
case GGML_TYPE_Q8_0:
launch_fattn_tile_f16_64_128<
cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
break;
case GGML_TYPE_F16:
launch_fattn_tile_f16_64_128<
cols_per_block, parallel_blocks, type_k, qkk, qrk, dequantize_k, half2, 2, 1, convert_f16>(ctx, dst);
break;
default:
GGML_ASSERT(false);
break;
}
}

template <int cols_per_block, int parallel_blocks>
void launch_fattn_tile_f16_K_type(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * K = dst->src[1];

switch (K->type) {
case GGML_TYPE_Q4_0:
launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, block_q4_0, QK4_0, QR4_0, dequantize_q4_0>(ctx, dst);
break;
case GGML_TYPE_Q4_1:
launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, block_q4_1, QK4_1, QR4_1, dequantize_q4_1>(ctx, dst);
break;
case GGML_TYPE_Q5_0:
launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, block_q5_0, QK5_0, QR5_0, dequantize_q5_0>(ctx, dst);
break;
case GGML_TYPE_Q5_1:
launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, block_q5_1, QK5_1, QR5_1, dequantize_q5_1>(ctx, dst);
break;
case GGML_TYPE_Q8_0:
launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, block_q8_0, QK8_0, QR8_0, dequantize_q8_0>(ctx, dst);
break;
case GGML_TYPE_F16:
launch_fattn_tile_f16_V_type<cols_per_block, parallel_blocks, half2, 2, 1, convert_f16>(ctx, dst);
break;
default:
GGML_ASSERT(false);
break;
}
}
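
The two switch functions above pick the K and V dequantization parameters at compile time, so each supported (K type, V type) pair gets its own kernel instantiation. As a sketch of what the dispatch resolves to, here is the instantiation reached for a q8_0 K cache and an F16 V cache with cols_per_block == 32 and parallel_blocks == 1; it is written out only to show the template parameters and is not an additional entry point in this PR:

// Illustrative only: equivalent direct instantiation for K == GGML_TYPE_Q8_0,
// V == GGML_TYPE_F16.
static void launch_fattn_tile_f16_q8_0_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    launch_fattn_tile_f16_64_128<
        32, 1,                                     // cols_per_block, parallel_blocks
        block_q8_0, QK8_0, QR8_0, dequantize_q8_0, // K: blocks of 32 int8 quants + F16 scale
        half2,      2,     1,     convert_f16      // V: raw F16, read two halves at a time
    >(ctx, dst);
}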

void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
@@ -295,18 +395,18 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
launch_fattn_tile_f16_K_type<cols_per_block, parallel_blocks>(ctx, dst);
return;
}

if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
launch_fattn_tile_f16_K_type<cols_per_block, parallel_blocks>(ctx, dst);
return;
}

constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
launch_fattn_tile_f16_K_type<cols_per_block, parallel_blocks>(ctx, dst);
}
7 changes: 7 additions & 0 deletions ggml-cuda/fattn.cu
@@ -457,11 +457,18 @@ void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];

ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int32_t precision = KQV->op_params[2];

if (true || ggml_is_quantized(K->type) || ggml_is_quantized(V->type)) {
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
return;
}

// On AMD the tile kernels perform poorly, use the vec kernel instead:
if (cc >= CC_OFFSET_AMD) {
if (precision == GGML_PREC_DEFAULT) {