diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 33c8ed1da8d83..510ca6281471e 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -141,6 +141,7 @@ #define CC_PASCAL 600 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products #define CC_VOLTA 700 +#define CC_AMPERE 800 #define CC_OFFSET_AMD 1000000 #define CC_RDNA1 (CC_OFFSET_AMD + 1010) #define CC_RDNA2 (CC_OFFSET_AMD + 1030) @@ -270,7 +271,6 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { return a; } -#ifdef GGML_CUDA_F16 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL #pragma unroll @@ -283,7 +283,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { NO_DEVICE_CODE; #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } -#endif // GGML_CUDA_F16 static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll @@ -293,18 +292,18 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { return x; } -//static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -//#pragma unroll -// for (int mask = 16; mask > 0; mask >>= 1) { -// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); -// } -// return x; -//#else -// GGML_UNUSED(x); -// NO_DEVICE_CODE; -//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -//} +static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +#else + GGML_UNUSED(x); + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +} #if defined(GGML_USE_HIPBLAS) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index bcf27fd794aaf..1d29346c7453b 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -1,441 +1,565 @@ +#include "common.cuh" #include "fattn.cuh" #include -static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32)); - } - return a; -#else - GGML_UNUSED(a); - NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL -} - -static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); - } - return x; -#else - GGML_UNUSED(x); - NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -} +#define FATTN_KQ_STRIDE 256 -#if __CUDA_ARCH__ >= CC_VOLTA -typedef nvcuda::wmma::fragment half16x16_a; -typedef nvcuda::wmma::fragment half16x16_b; -typedef nvcuda::wmma::fragment half16x16_bT; -typedef nvcuda::wmma::fragment half16x16_acc; -#endif - -// based on metal version -template // D head size, Q queries per block, C cache items per block -static __global__ void flash_attn_ext_f16( - const char * __restrict__ q, - const char * __restrict__ k, - const char * __restrict__ v, +template // D == head size +__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1) +static __global__ void flash_attn_vec_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, const char * __restrict__ mask, float * __restrict__ dst, - float scale, - int ne00, - int ne01, - int ne02, - int ne03, - int ne10, - int ne11, - int ne12, - int ne13, - int ne31, - int nb31, - int nb01, - int nb02, - int nb03, - int nb11, - int nb12, - int nb13, - int ne0, - int ne1, - int ne2, - int ne3) { -#if __CUDA_ARCH__ >= CC_VOLTA - const int warp_id = threadIdx.y; - const int lane_id = threadIdx.x; - - const int num_warps = blockDim.y; // number of warps - const int iq3 = blockIdx.z; - const int iq2 = blockIdx.y; - const int iq1 = blockIdx.x * Q; - - const int D16 = D/16; - const int Q16 = Q/16; - const int C16 = C/16; - - const int NW = WARP_SIZE; - const int SH = (C + Q); // shared memory per simdgroup in (half) - - const int T = D + num_warps*SH; // shared memory size per query in (half) - const int T2 = T/2; // shared memory size per query in (half2) - const int C2 = C/2; - const int D2 = D/2; - - extern __shared__ half __flash_attn_f16_shmem[]; - // pq - half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data - half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 - half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix - half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2 - - half16x16_acc zr; - half16x16_acc lo[Q16][D16]; - - // load heads from Q to shared memory -#pragma unroll - for (int j0 = 0; j0 < Q; j0 += num_warps) { - const int j = j0 + warp_id; - if (j >= Q) { - break; - } - - const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - -#pragma unroll - for (int i0 = 0; i0 < D2; i0 += NW) { - const int i = i0 + lane_id; - if (i >= D2) { - break; - } - - if (iq1 + j < ne01) { - sq2[j*T2 + i] = __float22half2_rn(q2[i]); - } else { - sq2[j*T2 + i] = make_half2(0.0, 0.0); - } - } + half2 * __restrict__ dst_meta, + const float scale, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y); + const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) mask; + + if (parallel_blocks == 1) { + Q_f2 += blockIdx.x*nb01/sizeof(float2); + maskh += blockIdx.x*ne11; } - nvcuda::wmma::fill_fragment(zr, 0.0); + const int stride_KV = nb11 / sizeof(half); + const int stride_KV2 = nb11 / sizeof(half2); - // zero out lo - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::fill_fragment(lo[j][i], 0.0); - } - } + constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < nwarps*WARP_SIZE); - // zero out shared memory SH - for (int j = 0; j < Q; ++j) { - for (int i0 = 0; i0 < SH; i0 += NW) { - const int i = i0 + lane_id; - if (i >= SH) { - break; - } + __shared__ half KQ[nwarps*WARP_SIZE]; + KQ[tid] = -INFINITY; + half2 * KQ2 = (half2 *) KQ; - ss[j*T + i] = 0.0; - } - } + half kqmax = -INFINITY; + half kqsum = 0.0f; + __shared__ half kqmax_shared[WARP_SIZE]; + __shared__ half kqsum_shared[WARP_SIZE]; + if (threadIdx.y == 0) { + kqmax_shared[threadIdx.x] = -INFINITY; + kqsum_shared[threadIdx.x] = 0.0f; + } __syncthreads(); - { - half S = __float2half(0.0f); - half M[Q]; - - for (int i = 0; i < Q; ++i) { - M[i] = CUDART_MIN_DENORM_FP16; + // Convert Q to half2 and store in registers: + half2 Q_h2[(D/2 + WARP_SIZE - 1) / WARP_SIZE]; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; } - // assume K and V are same shape - const int ne22 = ne12; - const int ne23 = ne13; + Q_h2[i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(Q_f2[i].x, Q_f2[i].y); + } - const int nb21 = nb11; - const int nb22 = nb12; - const int nb23 = nb13; + half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value. - // broadcast - const int rk2 = ne02/ne12; - const int rk3 = ne03/ne13; + const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*D; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { + // Calculate KQ tile and keep track of new maximum KQ values: + half kqmax_new = kqmax; +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + const int i_KQ = i_KQ_0 + threadIdx.y; - const int rv2 = ne02/ne22; - const int rv3 = ne03/ne23; + if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { + break; + } - // k indices - const int ik2 = iq2 / rk2; - const int ik3 = iq3 / rk3; + half2 sum2 = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + if (k_KQ_0 + WARP_SIZE > D/2 && k_KQ >= D/2) { + break; + } - // v indices - const int iv2 = iq2 / rv2; - const int iv3 = iq3 / rv3; + const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; + sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE]; + } - // load the queries from shared memory into local memory - half16x16_a mq[Q16][D16]; - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); + sum2 = warp_reduce_sum(sum2); + half sum = __low2half(sum2) + __high2half(sum2); + sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f); + kqmax_new = __hmax(kqmax_new, sum); + if (threadIdx.x == 0) { + KQ[i_KQ] = sum; } } - // pointer to the mask - const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr; - - // prepare diagonal scale matrix - half16x16_b mscale; - for (int i = 0; i < 16; ++i) { - ss[i*T + i] = __float2half(scale); + kqmax_new = warp_reduce_max(kqmax_new); + if (threadIdx.x == 0) { + kqmax_shared[threadIdx.y] = kqmax_new; } - nvcuda::wmma::load_matrix_sync(mscale, ss, T); - - // loop over the KV cache - // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) { - const int ic = ic0 + warp_id*C; - if (ic >= ne11) { - break; - } - - // Q*K^T - { -#pragma unroll - for (int cc = 0; cc < C16; ++cc) { - half16x16_acc mqk[Q16]; - for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::fill_fragment(mqk[j], 0); - } - - const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); - - for (int i = 0; i < D16; ++i) { - half16x16_bT mk; // transposed key - nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); + __syncthreads(); + kqmax_new = kqmax_shared[threadIdx.x]; + kqmax_new = warp_reduce_max(kqmax_new); - for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); - } - } + const half KQ_max_scale = hexp(kqmax - kqmax_new); + kqmax = kqmax_new; - // mqk = mqk*scale + mask - for (int j = 0; j < Q16; ++j) { - half16x16_a mqka; - half16x16_acc mm; + const half val = hexp(KQ[tid] - kqmax); + kqsum = kqsum*KQ_max_scale + val; + KQ[tid] = val; - if (mp) { - nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); - } + VKQ *= __half2half2(KQ_max_scale); - // convert accumulator to matrix_a - nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); - nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T); + __syncthreads(); - nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr); - nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); - } + if (tid < D) { +#pragma unroll + for (int k0 = 0; k0 < D; k0 += 2) { + if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) { + break; } + + half2 V_k; + reinterpret_cast(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; + reinterpret_cast(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; + VKQ += V_k*KQ2[k0/2]; } + } + } - // used to detect blocks full of -INF - half2 smax = make_half2(-INFINITY, -INFINITY); + if (tid >= D) { + kqsum = 0.0f; + } - // online softmax - for (int j = 0; j < Q; ++j) { - const half m = M[j]; + kqsum = warp_reduce_sum(kqsum); + if (threadIdx.x == 0) { + kqsum_shared[threadIdx.y] = kqsum; + } + __syncthreads(); + kqsum = kqsum_shared[threadIdx.x]; + kqsum = warp_reduce_sum(kqsum); - for (int p0 = 0; p0 < C2; p0 += NW) { - const int p = p0 + lane_id; + if (tid >= D) { + return; + } - const half2 s = ss2[j*T2 + p]; + if (parallel_blocks == 1) { + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum; + } else { + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)); - smax = __hmax2(smax, s); - M[j] = __hmax(M[j], __hmax(s.x, s.y)); - } + if (tid == 0) { + dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2(kqmax, kqsum); + } + } +#else + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +} - M[j] = warp_reduce_max(M[j]); +template // D == head size, VKQ_stride == num VKQ rows calculated in parallel +__launch_bounds__(nwarps*WARP_SIZE, 1) +static __global__ void flash_attn_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + half2 * __restrict__ dst_meta, + const float scale, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); + static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); + constexpr int frag_m = ncols == 8 ? 32 : 16; + constexpr int frag_n = ncols == 8 ? 8 : 16; + static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); + typedef nvcuda::wmma::fragment frag_a_K; + typedef nvcuda::wmma::fragment frag_a_V; + typedef nvcuda::wmma::fragment frag_b; + typedef nvcuda::wmma::fragment frag_c; + + constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. + constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. + static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); + + // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: + constexpr int D_padded = D + 8; + constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb02* blockIdx.y); + const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half2 * mask2 = (half2 *) mask; + + if (parallel_blocks == 1) { + Q_f += blockIdx.x * ncols*nb01/sizeof(float); + mask2 += blockIdx.x * ncols*ne11/2; + } - // local sum - half2 ls = make_half2(0.0f, 0.0f); - half2 M2 = make_half2(M[j], M[j]); + const int stride_Q = nb01 / sizeof(float); + const int stride_KV = nb11 / sizeof(half); - for (int p0 = 0; p0 < C2; p0 += NW) { - const int p = p0 + lane_id; + frag_b Q_b[D/16][ncols/frag_n]; - const half2 s = ss2[j*T2 + p]; + // A single buffer for temporarily holding tiles of KQ and VKQ parts: + constexpr int mem_KQ = ncols*kqs_padded; + constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; + __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; + half2 * KQ2 = (half2 *) KQ; - const half2 vs = h2exp(s - M2); + half2 KQ_rowsum[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max[ncols/nwarps] = {{-INFINITY, -INFINITY}}; + half2 KQ_max_scale[ncols/nwarps] = {{0.0f, 0.0f}}; - ls += vs; + __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. + half2 * VKQ2 = (half2 *) VKQ; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); + } + } - // the P matrix from the paper (Q rows, C columns) - ss2[j*T2 + p] = vs; - } + // Convert Q to half and apply scale, temporarily store in KQ: +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + if (parallel_blocks == 1) { + KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } else { + KQ[j*D_padded + i] = j == 0 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } + } + } - ls = warp_reduce_sum(ls); + __syncthreads(); - const half ms = hexp(m - M[j]); + // Load Q into tensor core fragments/registers since it will be used frequently: +#pragma unroll + for (int i0 = 0; i0 < D; i0 += 16) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + } + } - // create a QxQ diagonal matrix for rescaling the output - if (lane_id == j) { - ss[j*T + C + j] = ms; + __syncthreads(); - S = S*ms + ls.x + ls.y; + // Iterate over ne11 == previous tokens: + const int k_start = parallel_blocks == 1 ? 0 : blockIdx.x*FATTN_KQ_STRIDE; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { + frag_c KQ_c[ncols/frag_n]; +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); + } +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { + frag_a_K K_a; + nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); } } +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync(KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); + } + } + + __syncthreads(); - smax = warp_reduce_max(smax); + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; - // skip -INF blocks - if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) { - continue; + half2 KQ_max_new = KQ_max[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(kqs_padded/2) + k]); } + KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + KQ_max_scale[j0/nwarps] = h2exp(KQ_max[j0/nwarps] - KQ_max_new); + KQ_max[j0/nwarps] = KQ_max_new; - // O = diag(ms)*O - for (int j = 0; j < Q16; ++j) { - half16x16_a mm; - half16x16_b lob; + half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + half2 val = KQ2[j*(kqs_padded/2) + k]; + val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + val = h2exp(val - KQ_max[j0/nwarps]); + KQ_rowsum_add += val; + KQ2[j*(kqs_padded/2) + k] = val; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); - nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum[j0/nwarps] = KQ_max_scale[j0/nwarps]*KQ_rowsum[j0/nwarps] + KQ_rowsum_add; + } - for (int i = 0; i < D16; ++i) { - // convert accumulator to matrix_b - nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); - nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); + __syncthreads(); - nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); - } + frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + nvcuda::wmma::load_matrix_sync( + KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], + KQ + j0*kqs_padded + k, + kqs_padded); } + } - // restore zeros - for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); + frag_c VKQ_c[D/VKQ_stride][ncols/frag_n]; +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); } - // O = O + (Q*K^T)*V - { - for (int cc = 0; cc < C16; ++cc) { - const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); - - half16x16_b mv[D16]; - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half)); - } - - half16x16_a ms[Q16]; - for (int j = 0; j < Q16; ++j) { - nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T); - } - - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]); - } - } +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + + frag_a_V v_a; + nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); } } } - // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - if (lane_id < Q) { - ss[lane_id*T + 0] = S; - ss[lane_id*T + 1] = M[lane_id]; - } - } - - // reduce the warps sequentially - for (int sg = 1; sg < num_warps; ++sg) { __syncthreads(); - // each simdgroup stores its output to shared memory, reusing sq - if (warp_id == sg) { - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); - } + const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync( + KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], + D_padded, nvcuda::wmma::mem_col_major); } } __syncthreads(); - // the first simdgroup accumulates the results from the other simdgroups - if (warp_id == 0) { - for (int j = lane_id; j < Q; j += NW) { - const half S0 = ss[j*T + 0]; - const half S1 = ss[j*T + sg*SH + 0]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } - const half M0 = ss[j*T + 1]; - const half M1 = ss[j*T + sg*SH + 1]; + half2 VKQ_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int l = 0; l < VKQ_ratio; ++l) { + VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; + } + VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + VKQ_add; + } + } - const half M = __hmax(M0, M1); + __syncthreads(); + } - const half ms0 = hexp(M0 - M); - const half ms1 = hexp(M1 - M); + if (parallel_blocks == 1) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + if (ncols*blockIdx.x + j >= ne01) { + return; + } + const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]); +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + dst[D*gridDim.y*(ncols*blockIdx.x + j) + D*blockIdx.y + i] = __half2float(VKQ[j*D_padded + i]) / KQ_rowsum_j; + } + } + } else { +#pragma unroll + for (int i0 = 0; i0 < D; i0 += nwarps*WARP_SIZE) { + const int i = i0 + threadIdx.y*WARP_SIZE + threadIdx.x; + if (i0 + nwarps*WARP_SIZE > D && i >= D) { + return; + } + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + i] = VKQ[i]; + } - const half S = S0*ms0 + S1*ms1; + if (threadIdx.y == 0 && threadIdx.x == 0) { + dst_meta[blockIdx.y*parallel_blocks + blockIdx.x] = make_half2( + __low2half(KQ_max[0]), __low2half(KQ_rowsum[0]) + __high2half(KQ_rowsum[0])); + } + } +#else + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA +} - ss[j*T + 0] = S; - ss[j*T + 1] = M; +template // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_combine_results( + const float * __restrict__ VKQ_parts, + const half2 * __restrict__ VKQ_meta, + float * __restrict__ dst) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL - ss[j*T + C + j ] = ms0; - ss[j*T + C + j + sg*SH] = ms1; - } + const int tid = threadIdx.x; + __builtin_assume(tid < D); - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (int j = 0; j < Q16; ++j) { - half16x16_a ms0; - half16x16_a ms1; - half16x16_b t; - half16x16_acc t2; + __shared__ half2 meta[parallel_blocks]; + if (tid < parallel_blocks) { + meta[threadIdx.x] = VKQ_meta[blockIdx.y*parallel_blocks + tid]; + } - nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); - nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); + __syncthreads(); - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); - nvcuda::wmma::mma_sync(t2, ms1, t, zr); + half kqmax = __low2half(meta[0]); +#pragma unroll + for (int l = 1; l < parallel_blocks; ++l) { + kqmax = __hmax(kqmax, __low2half(meta[l])); + } - // convert accumulator to matrix_b - nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); - nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T); + float VKQ_numerator = 0.0f; + float VKQ_denominator = 0.0f; +#pragma unroll + for (int l = 0; l < parallel_blocks; ++l) { + float KQ_max_scale = hexp(__low2half(meta[l]) - kqmax); - nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2); - } - } - } + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; + VKQ_denominator += KQ_max_scale * __high2float(meta[l]); } - // store result to shared memory (reuse sq) - if (warp_id == 0) { - for (int j = 0; j < Q16; ++j) { - for (int i = 0; i < D16; ++i) { - nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); - } - } - } + dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; +#else + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +} - // final rescale with 1/S and store to global memory - if (warp_id == 0) { - for (int j = 0; j < Q && iq1 + j < ne01; ++j) { - const half S = ss[j*T + 0]; +constexpr int get_max_power_of_2(int x) { + return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; +} - for (int i0 = 0; i0 < D; i0 += NW) { - const int i = i0 + lane_id; - if (i >= D) { - break; - } +static_assert(get_max_power_of_2(1) == 1, "Test failed."); +static_assert(get_max_power_of_2(2) == 2, "Test failed."); +static_assert(get_max_power_of_2(4) == 4, "Test failed."); +static_assert(get_max_power_of_2(6) == 2, "Test failed."); - dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); - } - } - } -#else - NO_DEVICE_CODE; -#endif +// Number of VKQ rows calculated in parallel: +constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) { + return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m; } +static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed."); +static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); + +#define FATTN_SWITCH_CASE(D, ncols, nwarps) \ + case ncols: { \ + constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \ + flash_attn_ext_f16 \ + <<>> ( \ + (const char *) Q->data, \ + (const char *) K->data, \ + (const char *) V->data, \ + mask ? ((const char *) mask->data) : nullptr, \ + (float *) KQV->data, nullptr, \ + scale, \ + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \ + K->ne[0], K->ne[1], K->ne[2], K->ne[3], \ + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \ + Q->nb[1], Q->nb[2], Q->nb[3], \ + K->nb[1], K->nb[2], K->nb[3], \ + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \ + ); \ + } \ + break; \ + void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -461,133 +585,254 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst float scale; memcpy(&scale, KQV->op_params, sizeof(float)); -#define NQPB 16 -#define NCPW 128 + if (Q->ne[1] == 1) { + constexpr int parallel_blocks = 4; - const int nqpb = NQPB; // queries per block - const int ncpw = NCPW; // cache values per warp (does not work for other values) + ggml_cuda_pool_alloc dst_tmp(ctx.pool()); + ggml_cuda_pool_alloc dst_tmp_meta(ctx.pool()); - GGML_ASSERT(NQPB <= 32); + const int nwarps = (Q->ne[0] + WARP_SIZE - 1) / WARP_SIZE; + const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); + const dim3 block_dim(WARP_SIZE, nwarps, 1); + const int shmem = 0; - const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? - // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why - const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1; + // Performance of the vector kernel is very bad for head sizes 80 and 112, use the tensor core kernel instead: + constexpr int nwarps_tc = 4; + constexpr dim3 block_dim_tc(WARP_SIZE, nwarps_tc, 1); - dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); - dim3 block_dim(32, nwarps, 1); + const dim3 blocks_num_combine(1, blocks_num.y, blocks_num.z); + const dim3 block_dim_combine(Q->ne[0], 1, 1); + const int shmem_combine = 0; - const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } - // increase shared memory limit to 96KB - //const size_t shmem_max = 96*1024; - //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max); + switch (Q->ne[0]) { + case 64: + flash_attn_vec_ext_f16<64, parallel_blocks> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<64, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + break; + case 80: + flash_attn_ext_f16<80, 16, nwarps_tc, get_VKQ_stride(80, nwarps_tc, 16), parallel_blocks> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<80, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + break; + case 96: + flash_attn_vec_ext_f16<96, parallel_blocks> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<96, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + break; + case 112: + flash_attn_vec_ext_f16<112, parallel_blocks> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<112, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + break; + case 128: + flash_attn_vec_ext_f16<128, parallel_blocks> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<128, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + break; + case 256: + flash_attn_vec_ext_f16<256, parallel_blocks> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? ((const char *) mask->data) : nullptr, // Mask + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + if (parallel_blocks == 1) { + break; + } + CUDA_CHECK(cudaGetLastError()); + flash_attn_combine_results<256, parallel_blocks> + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + break; + default: + GGML_ASSERT(false); + break; + } + CUDA_CHECK(cudaGetLastError()); + return; + } + + int cols_per_block; + if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { + cols_per_block = 32; + } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { + cols_per_block = 16; + } else { + cols_per_block = 8; + } + constexpr int nwarps = 4; + const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); + const dim3 block_dim(WARP_SIZE, nwarps, 1); + const size_t shmem = 0; switch (Q->ne[0]) { - case 64: - flash_attn_ext_f16<64, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 80: - flash_attn_ext_f16<80, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 96: - flash_attn_ext_f16<96, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 112: - flash_attn_ext_f16<112, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 128: - flash_attn_ext_f16<128, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; - case 256: - flash_attn_ext_f16<256, NQPB, NCPW> - <<>> ( - (const char *) Q->data, // Query - (const char *) K->data, // Key - (const char *) V->data, // Value - mask ? (const char *) mask->data : nullptr, // Mask - (float *) KQV->data, // dst - scale, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] - ); - break; + case 64: switch (cols_per_block) { + FATTN_SWITCH_CASE(64, 8, nwarps); + FATTN_SWITCH_CASE(64, 16, nwarps); + FATTN_SWITCH_CASE(64, 32, nwarps); + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 80: switch (cols_per_block) { + // FATTN_SWITCH_CASE(80, 8, nwarps); + FATTN_SWITCH_CASE(80, 16, nwarps); + FATTN_SWITCH_CASE(80, 32, nwarps); + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 96: switch (cols_per_block) { + FATTN_SWITCH_CASE(96, 8, nwarps); + FATTN_SWITCH_CASE(96, 16, nwarps); + FATTN_SWITCH_CASE(96, 32, nwarps); + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 112: switch (cols_per_block) { + // FATTN_SWITCH_CASE(112, 8, nwarps); + FATTN_SWITCH_CASE(112, 16, nwarps); + FATTN_SWITCH_CASE(112, 32, nwarps); + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 128: switch (cols_per_block) { + FATTN_SWITCH_CASE(128, 8, nwarps); + FATTN_SWITCH_CASE(128, 16, nwarps); + FATTN_SWITCH_CASE(128, 32, nwarps); + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; + case 256: switch (cols_per_block) { + FATTN_SWITCH_CASE(256, 8, nwarps); + FATTN_SWITCH_CASE(256, 16, nwarps); + FATTN_SWITCH_CASE(256, 32, nwarps); + default: + fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block); + GGML_ASSERT(false); + break; + } break; default: + GGML_ASSERT(false); break; } - CUDA_CHECK(cudaGetLastError()); }