From f1a35d05bf31759a70ab27b5114aecdb550633f7 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sat, 1 Apr 2023 17:27:26 -0700 Subject: [PATCH 01/24] init --- csrc/attention_kernels.cu | 242 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 242 insertions(+) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 5b24120eadfb..1f71b7a17fb8 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -252,6 +252,248 @@ __global__ void single_query_cached_kv_attention_kernel( } } + +// Grid: (num_heads, num_seqs). +template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS> +__device__ void multi_query_cached_kv_attention_kernel_1xN_( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq) { + constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE; + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int seq_idx = blockIdx.y; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread group + // fetch or comput 16 bytes at a time. + // For example, if the size of a thread group is 4 and the data type is half, + // then the vector size is 16 / (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... + // th vectors of the query, and so on. + const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + Q_vec q_vecs[NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 logits and accumulation. + float *logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(scalar_t); + float qk_max = -FLT_MAX; + + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + const int context_len = context_lens[seq_idx]; + const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + const int physical_block_offset = thread_group_idx % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th + // vectors of the key, and so on. + K_vec k_vecs[NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { + const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE + + head_idx * HEAD_SIZE * BLOCK_SIZE + + physical_block_offset * x; + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + k_vecs[i] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + const float qk = scale * Qk_dot::dot(q_vecs, k_vecs); + const bool mask = token_idx >= context_len; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + logits[token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // Each thread will fetch 16 bytes from the value cache at a time. + constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t); + using V_vec = typename Vec::Type; + using L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; + + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec = *reinterpret_cast(logits + token_idx); + + const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE + + head_idx * HEAD_SIZE * BLOCK_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec = *reinterpret_cast(v_ptr + offset); + accs[i] += dot(logits_vec, cast_to_float(v_vec)); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += __shfl_xor_sync(uint32_t(-1), acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for logits + // is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. + if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + __syncthreads(); + } + + // Write the final output. + if (warp_idx == 0) { + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + convert_from_float(*(out_ptr + row_idx), accs[i]); + } + } + } +} + } // namespace cacheflow #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ From 9341034e3b87a0105bcdb3ee3b327d273e98d444 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 2 Apr 2023 03:20:46 -0700 Subject: [PATCH 02/24] update --- csrc/attention_kernels.cu | 94 +++++++++++++++++++++++++++++++++------ 1 file changed, 80 insertions(+), 14 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 1f71b7a17fb8..afb1110bf85d 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -265,8 +265,8 @@ __device__ void multi_query_cached_kv_attention_kernel_1xN_( const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq] + const int __restrict__ context_len, // [num_seqs] const int max_num_blocks_per_seq) { constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; @@ -318,35 +318,51 @@ __device__ void multi_query_cached_kv_attention_kernel_1xN_( constexpr int x = 16 / sizeof(scalar_t); float qk_max = -FLT_MAX; - const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - const int context_len = context_lens[seq_idx]; + // const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + // const int context_len = context_lens[seq_idx]; const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + // constexpr int MAX_CONTEXT_LEN = 4096; // FIXME(suquark): make this configurable + // __shared__ int physical_block_numbers[MAX_CONTEXT_LEN / BLOCK_SIZE]; + + // int n_blocks_to_load = (num_blocks - warp_idx - 1) / NUM_WARPS + 1; + // if (thread_idx < n_blocks_to_load) { + // physical_block_numbers[thread_idx] = block_table[warp_idx + thread_idx * NUM_WARPS]; + // } + // for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + // // TODO(suquark): we start physical_block_number to shared memory. + // const int physical_block_number = block_table[block_idx]; + // } + + // TODO(suquark): we may increase the share memory size to reduce synchronization. + __shared__ K_vec k_vecs[NUM_VECS_PER_THREAD]; + // Iterate over the key blocks. // Each warp fetches a block of keys for each iteration. // Each thread group in a warp fetches a key from the block, and computes // dot product with the query. for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + // TODO(suquark): we start physical_block_number to shared memory. const int physical_block_number = block_table[block_idx]; const int physical_block_offset = thread_group_idx % BLOCK_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - // Load a key to registers. + // Load a key to shared memory. // Each thread in a thread group has a different part of the key. // For example, if the the thread group size is 4, then the first thread in the group // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th // vectors of the key, and so on. - K_vec k_vecs[NUM_VECS_PER_THREAD]; -#pragma unroll - for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { - const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE + const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE + head_idx * HEAD_SIZE * BLOCK_SIZE + physical_block_offset * x; - const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + // TODO(suquark): currently, gridDim.x = num_heads > NUM_VECS_PER_THREAD. but it is not always true. + if (thread_idx < NUM_VECS_PER_THREAD) { + const int vec_idx = thread_group_offset + thread_idx * THREAD_GROUP_SIZE; const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset2 = (vec_idx * VEC_SIZE) % x; - k_vecs[i] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[thread_idx] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); } + __syncthreads(); // Compute dot product. // This includes a reduction across the threads in the same thread group. @@ -423,13 +439,26 @@ __device__ void multi_query_cached_kv_attention_kernel_1xN_( const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE + head_idx * HEAD_SIZE * BLOCK_SIZE; + + // TODO(suquark): we may reuse the k_vecs shared memory here. + // TODO(suquark): We may use matrix multiplication here. + __shared__ V_vec k_vecs[NUM_ROWS_PER_THREAD]; + + // TODO(suquark): currently, gridDim.x = num_heads > NUM_VECS_PER_THREAD. but it is not always true. + if (thread_idx < NUM_ROWS_PER_THREAD) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + thread_idx * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + k_vecs[thread_idx] = *reinterpret_cast(k_ptr + offset); + } + } + __syncthreads(); + // TODO(suquark): We may use matrix multiplication here. #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; if (row_idx < HEAD_SIZE) { - const int offset = row_idx * BLOCK_SIZE + physical_block_offset; - V_vec v_vec = *reinterpret_cast(v_ptr + offset); - accs[i] += dot(logits_vec, cast_to_float(v_vec)); + accs[i] += dot(logits_vec, cast_to_float(k_vecs[i])); } } } @@ -494,6 +523,43 @@ __device__ void multi_query_cached_kv_attention_kernel_1xN_( } } +// Grid: (num_heads, num_seqs). +template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS> +__global__ void multi_query_cached_kv_attention_kernel_1xN_( + const int* cu_query_lens, + const int num_queries, // len(cu_query_lens) - 1 + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq) { + for (int i = 0; i < num_queries; i++) { + const int start_idx = cu_query_lens[i]; + const int end_idx = cu_query_lens[i + 1]; + const int query_len = end_idx - start_idx; + const int context_len = context_lens[i]; + const int* block_table = block_tables + i * max_num_blocks_per_seq; + + multi_query_cached_kv_attention_kernel_1xN_< + scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>( + out, + q, + k_cache, + v_cache, + scale, + block_table, + context_len, + max_num_blocks_per_seq); + } + + } // namespace cacheflow #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ From 479b5e29c983b1cc208ea90b244792683d7dd4b0 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 2 Apr 2023 14:19:07 -0700 Subject: [PATCH 03/24] update --- csrc/attention_kernels.cu | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index afb1110bf85d..0178baf06d65 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -547,10 +547,14 @@ __global__ void multi_query_cached_kv_attention_kernel_1xN_( const int context_len = context_lens[i]; const int* block_table = block_tables + i * max_num_blocks_per_seq; + const scalar_t* query_ptr = q + start_idx * num_heads * HEAD_SIZE; + scalar_t* out_ptr = out + start_idx * num_heads * HEAD_SIZE; + // NOTE: we do not need to adjust the kv cache, since the block table is + // already adjusted. multi_query_cached_kv_attention_kernel_1xN_< scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>( - out, - q, + out_ptr, + query_ptr, k_cache, v_cache, scale, @@ -559,7 +563,6 @@ __global__ void multi_query_cached_kv_attention_kernel_1xN_( max_num_blocks_per_seq); } - } // namespace cacheflow #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ From b0391ba977b0f0c1e95f4c34e93117488769a2c9 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 2 Apr 2023 15:40:17 -0700 Subject: [PATCH 04/24] update --- csrc/attention_kernels.cu | 199 ++++++++++++++++++++++++++++++++++---- 1 file changed, 178 insertions(+), 21 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 0178baf06d65..9eb1347eca42 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -529,44 +529,36 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS> -__global__ void multi_query_cached_kv_attention_kernel_1xN_( - const int* cu_query_lens, - const int num_queries, // len(cu_query_lens) - 1 +__global__ void multi_query_cached_kv_attention_kernel( + const int* seq_prompt_mapping, // [num_seqs] mapping from seq_idx to prompt_idx scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ block_tables, // [num_prompts, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_prompts] const int max_num_blocks_per_seq) { - for (int i = 0; i < num_queries; i++) { - const int start_idx = cu_query_lens[i]; - const int end_idx = cu_query_lens[i + 1]; - const int query_len = end_idx - start_idx; - const int context_len = context_lens[i]; - const int* block_table = block_tables + i * max_num_blocks_per_seq; - - const scalar_t* query_ptr = q + start_idx * num_heads * HEAD_SIZE; - scalar_t* out_ptr = out + start_idx * num_heads * HEAD_SIZE; - // NOTE: we do not need to adjust the kv cache, since the block table is - // already adjusted. - multi_query_cached_kv_attention_kernel_1xN_< + const int seq_idx = blockIdx.y; + const int prompt_idx = seq_prompt_mapping[seq_idx]; + const int* block_table = block_tables + prompt_idx * max_num_blocks_per_seq; + const int context_len = context_lens[prompt_idx]; + multi_query_cached_kv_attention_kernel_1xN_< scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>( - out_ptr, - query_ptr, + out, + q, k_cache, v_cache, scale, block_table, context_len, max_num_blocks_per_seq); - } +} } // namespace cacheflow #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ - cacheflow::single_query_cached_kv_attention_kernel \ + cacheflow::multi_query_cached_kv_attention_kernel \ <<>>( \ out_ptr, \ query_ptr, \ @@ -709,4 +701,169 @@ void single_query_cached_kv_attention( } } + +#define LAUNCH_MULTI_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ + cacheflow::single_query_cached_kv_attention_kernel \ + <<>>( \ + seq_prompt_mapping_ptr, \ + out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq); + + +// TODO(woosuk): Tune NUM_THREADS. +template< + typename T, + int BLOCK_SIZE, + int NUM_THREADS = 128> +void multi_query_cached_kv_attention_launcher( + const int* seq_prompt_mapping_ptr, + torch::Tensor& out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int max_context_len) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_context_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs); + dim3 block(NUM_THREADS); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + case 32: + LAUNCH_MULTI_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); + break; + case 64: + LAUNCH_MULTI_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); + break; + case 80: + LAUNCH_MULTI_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); + break; + case 96: + LAUNCH_MULTI_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); + break; + case 128: + LAUNCH_MULTI_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); + break; + case 160: + LAUNCH_MULTI_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); + break; + case 192: + LAUNCH_MULTI_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); + break; + case 256: + LAUNCH_MULTI_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); + break; + default: + assert(false); + break; + } +} + +void multi_query_cached_kv_attention( + torch::Tensor& cu_query_lens, + torch::Tensor& out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int block_size, + int max_context_len) { + + int num_queries = cu_query_lens.size(0) - 1; + int num_seqs = query.size(0); + + int seq_prompt_mapping[num_queries]; + for (int i = 0, query_cursor = 0; i < num_seqs; ++i) { + if (i >= cu_query_lens[query_cursor + 1]) { + ++query_cursor; + } + seq_prompt_mapping[query_cursor] = i; + } + + // TODO(woosuk): Support BF16. + if (query.element_size() == 2) { + // Half. + if (block_size == 8) { + multi_query_cached_kv_attention_launcher( + seq_prompt_mapping, + out, + query, + key_cache, + value_cache, + scale, + block_tables, + context_lens, + max_context_len); + } else if (block_size == 16) { + multi_query_cached_kv_attention_launcher( + seq_prompt_mapping, + out, + query, + key_cache, + value_cache, + scale, + block_tables, + context_lens, + max_context_len); + } else { + assert(false); + } + } else if (query.element_size() == 4) { + // Float. + if (block_size == 8) { + multi_query_cached_kv_attention_launcher( + seq_prompt_mapping, + out, + query, + key_cache, + value_cache, + scale, + block_tables, + context_lens, + max_context_len); + } else if (block_size == 16) { + multi_query_cached_kv_attention_launcher( + seq_prompt_mapping, + out, + query, + key_cache, + value_cache, + scale, + block_tables, + context_lens, + max_context_len); + } else { + assert(false); + } + } else { + assert(false); + } +} + #undef WARP_SIZE From aa0c2e24cc8c4e295e5aef5d6ee4e201d2890225 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 2 Apr 2023 15:50:04 -0700 Subject: [PATCH 05/24] add binding and tests --- csrc/attention.cpp | 16 +++++ tests/kernels/attention.py | 139 +++++++++++++++++++++++++++++++++++++ 2 files changed, 155 insertions(+) diff --git a/csrc/attention.cpp b/csrc/attention.cpp index bb2766c1d6b6..57dff9dc0b2b 100644 --- a/csrc/attention.cpp +++ b/csrc/attention.cpp @@ -11,9 +11,25 @@ void single_query_cached_kv_attention( int block_size, int max_context_len); +void multi_query_cached_kv_attention( + torch::Tensor& cu_query_lens, + torch::Tensor& out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int block_size, + int max_context_len); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( "single_query_cached_kv_attention", &single_query_cached_kv_attention, "Compute the attention between an input query and the cached key/value tensors"); + m.def( + "multi_query_cached_kv_attention", + &multi_query_cached_kv_attention, + "Compute the attention between multiple input queries and the cached key/value tensors"); } diff --git a/tests/kernels/attention.py b/tests/kernels/attention.py index b6766e1eddc2..4d39d9075db4 100644 --- a/tests/kernels/attention.py +++ b/tests/kernels/attention.py @@ -97,6 +97,61 @@ def ref_multi_query_kv_attention( return ref_output +def ref_multi_query_cached_kv_attention( + cu_query_lens: List[int], + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + dtype: torch.dtype, +) -> torch.Tensor: + num_heads = value_cache.shape[1] + head_size = value_cache.shape[2] + block_size = value_cache.shape[3] + scale = 1.0 / (head_size ** 0.5) + + num_queries = len(cu_query_lens) - 1 + ref_outputs = [] + for i in range(num_queries): + start_idx = cu_query_lens[i] + end_idx = cu_query_lens[i + 1] + query_len = end_idx - start_idx + context_len = int(context_lens[i]) + block_table = block_tables[i] + + # Create attention mask + attn_mask = torch.triu( + torch.ones(query_len, context_len), diagonal=context_len - query_len + 1) * -1e5 + attn_mask = attn_mask.to(dtype=dtype, device='cuda') + + keys = [] + values = [] + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_heads, head_size) + keys.append(k) + + v = value_cache[block_number, :, :, block_offset] + values.append(v) + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + + ref_output = ref_masked_attention( + query[start_idx:end_idx], + keys, + values, + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output) + ref_output = torch.cat(ref_outputs, dim=0) + return ref_output + + def test_single_query_cached_kv_attention( num_tokens: int, num_heads: int, @@ -200,6 +255,75 @@ def test_multi_query_kv_attention( assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) +def test_multi_query_cached_kv_attention( + num_queries: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, +) -> None: + query_lens = random.sample(range(1, MAX_SEQ_LEN), num_queries) + cu_query_lens = [0] + for query_len in query_lens: + cu_query_lens.append(cu_query_lens[-1] + query_len) + num_total_tokens = cu_query_lens[-1] + + query = torch.randn( + num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda') + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_block_shape = (num_heads, head_size // x, block_size, x) + key_cache = torch.randn( + size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda') + value_block_shape = (num_heads, head_size, block_size) + value_cache = torch.randn( + size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda') + + cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda') + context_lens = [ + query_len + random.randint(0, MAX_SEQ_LEN - query_len) + for query_len in query_lens + ] + max_context_len = max(context_lens) + + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_queries): + block_table = [ + random.randint(0, num_blocks - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda') + + scale = float(1.0 / (head_size ** 0.5)) + output = torch.empty_like(query) + + attention_ops.multi_query_cached_kv_attention( + cu_query_lens, + output, + query, + key_cache, + value_cache, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + ) + + ref_output = ref_multi_query_cached_kv_attention( + cu_query_lens, + query, + key_cache, + value_cache, + block_tables, + context_lens, + dtype, + ) + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + @torch.inference_mode() def test_attention(seed: int) -> None: # NOTE(woosuk): Even when the seed is fixed, there is a chance that @@ -221,6 +345,21 @@ def test_attention(seed: int) -> None: dtype=dtype, ) + for dtype in [torch.half, torch.float]: + for block_size in [8, 16]: + for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: + print(f'Testing multi_query_cached_kv_attention with ' + f'dtype={dtype}, block_size={block_size}, ' + f'head_size={head_size}') + test_multi_query_cached_kv_attention( + num_queries=11, + num_heads=3, + head_size=head_size, + block_size=block_size, + num_blocks=1024, + dtype=dtype, + ) + # NOTE(woosuk): FlashAttention does not support FP32. for dtype in [torch.half]: # NOTE(woosuk): FlashAttention does not support head_size > 128. From a32161510352b915c60ed9e78c683e2c0be8b34f Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 2 Apr 2023 16:41:49 -0700 Subject: [PATCH 06/24] fix --- csrc/attention_kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 9eb1347eca42..10a732eb45d5 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -558,7 +558,7 @@ __global__ void multi_query_cached_kv_attention_kernel( } // namespace cacheflow #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ - cacheflow::multi_query_cached_kv_attention_kernel \ + cacheflow::single_query_cached_kv_attention_kernel \ <<>>( \ out_ptr, \ query_ptr, \ @@ -703,7 +703,7 @@ void single_query_cached_kv_attention( #define LAUNCH_MULTI_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ - cacheflow::single_query_cached_kv_attention_kernel \ + cacheflow::multi_query_cached_kv_attention_kernel \ <<>>( \ seq_prompt_mapping_ptr, \ out_ptr, \ From 6ad9ef43ed71c6cbc19f97ca8bae23bd0bd0eda5 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 2 Apr 2023 16:44:40 -0700 Subject: [PATCH 07/24] fix --- csrc/attention_kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 10a732eb45d5..f46bc91ef7cc 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -265,8 +265,8 @@ __device__ void multi_query_cached_kv_attention_kernel_1xN_( const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] const float scale, - const int __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq] - const int __restrict__ context_len, // [num_seqs] + const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq] + const int context_len, // [num_seqs] const int max_num_blocks_per_seq) { constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; From 178b00ca4857734603a23213b2e4694b38d1828e Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 2 Apr 2023 16:49:59 -0700 Subject: [PATCH 08/24] fix --- csrc/attention_kernels.cu | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index f46bc91ef7cc..dd07ed9be759 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -442,14 +442,14 @@ __device__ void multi_query_cached_kv_attention_kernel_1xN_( // TODO(suquark): we may reuse the k_vecs shared memory here. // TODO(suquark): We may use matrix multiplication here. - __shared__ V_vec k_vecs[NUM_ROWS_PER_THREAD]; + __shared__ V_vec v_vecs[NUM_ROWS_PER_THREAD]; // TODO(suquark): currently, gridDim.x = num_heads > NUM_VECS_PER_THREAD. but it is not always true. if (thread_idx < NUM_ROWS_PER_THREAD) { const int row_idx = lane / NUM_V_VECS_PER_ROW + thread_idx * NUM_ROWS_PER_ITER; if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; - k_vecs[thread_idx] = *reinterpret_cast(k_ptr + offset); + v_vecs[thread_idx] = *reinterpret_cast(v_ptr + offset); } } __syncthreads(); @@ -796,11 +796,12 @@ void multi_query_cached_kv_attention( int max_context_len) { int num_queries = cu_query_lens.size(0) - 1; + const int* cu_query_lens_ptr = cu_query_lens.data_ptr(); int num_seqs = query.size(0); int seq_prompt_mapping[num_queries]; for (int i = 0, query_cursor = 0; i < num_seqs; ++i) { - if (i >= cu_query_lens[query_cursor + 1]) { + if (i >= cu_query_lens_ptr[query_cursor + 1]) { ++query_cursor; } seq_prompt_mapping[query_cursor] = i; From 987f0e9b8da7a357648f29f15d855edd656229c9 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 2 Apr 2023 16:58:43 -0700 Subject: [PATCH 09/24] fix --- csrc/attention_kernels.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index dd07ed9be759..3d86374ce86f 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -431,6 +431,10 @@ __device__ void multi_query_cached_kv_attention_kernel_1xN_( accs[i] = 0.f; } + // TODO(suquark): we may reuse the k_vecs shared memory here. + // TODO(suquark): We may use matrix multiplication here. + __shared__ V_vec v_vecs[NUM_ROWS_PER_THREAD]; + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { const int physical_block_number = block_table[block_idx]; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; @@ -440,10 +444,6 @@ __device__ void multi_query_cached_kv_attention_kernel_1xN_( const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE + head_idx * HEAD_SIZE * BLOCK_SIZE; - // TODO(suquark): we may reuse the k_vecs shared memory here. - // TODO(suquark): We may use matrix multiplication here. - __shared__ V_vec v_vecs[NUM_ROWS_PER_THREAD]; - // TODO(suquark): currently, gridDim.x = num_heads > NUM_VECS_PER_THREAD. but it is not always true. if (thread_idx < NUM_ROWS_PER_THREAD) { const int row_idx = lane / NUM_V_VECS_PER_ROW + thread_idx * NUM_ROWS_PER_ITER; @@ -458,7 +458,7 @@ __device__ void multi_query_cached_kv_attention_kernel_1xN_( for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; if (row_idx < HEAD_SIZE) { - accs[i] += dot(logits_vec, cast_to_float(k_vecs[i])); + accs[i] += dot(logits_vec, cast_to_float(v_vecs[i])); } } } From 11319610d5139797a05ea0c27ca894ecba79f386 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 2 Apr 2023 17:55:15 -0700 Subject: [PATCH 10/24] fix --- tests/kernels/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/kernels/attention.py b/tests/kernels/attention.py index 4d39d9075db4..b35a36f4fe67 100644 --- a/tests/kernels/attention.py +++ b/tests/kernels/attention.py @@ -285,6 +285,7 @@ def test_multi_query_cached_kv_attention( for query_len in query_lens ] max_context_len = max(context_lens) + context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda') max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size block_tables = [] From 5842d76a98f4639cc91bae66e5993f7b80a533b2 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 2 Apr 2023 17:57:03 -0700 Subject: [PATCH 11/24] update --- tests/kernels/attention.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/kernels/attention.py b/tests/kernels/attention.py index b35a36f4fe67..9beb2c312e18 100644 --- a/tests/kernels/attention.py +++ b/tests/kernels/attention.py @@ -327,18 +327,14 @@ def test_multi_query_cached_kv_attention( @torch.inference_mode() def test_attention(seed: int) -> None: - # NOTE(woosuk): Even when the seed is fixed, there is a chance that - # the test fails due to the precision issue. Re-run the test if it fails. - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) for dtype in [torch.half, torch.float]: for block_size in [8, 16]: for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: - print(f'Testing single_query_cached_kv_attention with ' + print(f'Testing multi_query_cached_kv_attention with ' f'dtype={dtype}, block_size={block_size}, ' f'head_size={head_size}') - test_single_query_cached_kv_attention( - num_tokens=37, + test_multi_query_cached_kv_attention( + num_queries=11, num_heads=3, head_size=head_size, block_size=block_size, @@ -346,14 +342,18 @@ def test_attention(seed: int) -> None: dtype=dtype, ) + # NOTE(woosuk): Even when the seed is fixed, there is a chance that + # the test fails due to the precision issue. Re-run the test if it fails. + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) for dtype in [torch.half, torch.float]: for block_size in [8, 16]: for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: - print(f'Testing multi_query_cached_kv_attention with ' + print(f'Testing single_query_cached_kv_attention with ' f'dtype={dtype}, block_size={block_size}, ' f'head_size={head_size}') - test_multi_query_cached_kv_attention( - num_queries=11, + test_single_query_cached_kv_attention( + num_tokens=37, num_heads=3, head_size=head_size, block_size=block_size, From 0ff95f87366dd351f6254fb2a69b805447922c20 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 2 Apr 2023 19:09:04 -0700 Subject: [PATCH 12/24] fix cuda memory error --- csrc/attention_kernels.cu | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 3d86374ce86f..e31dadaadbf8 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -799,14 +799,17 @@ void multi_query_cached_kv_attention( const int* cu_query_lens_ptr = cu_query_lens.data_ptr(); int num_seqs = query.size(0); - int seq_prompt_mapping[num_queries]; + torch::Tensor cpu_tensor = torch::empty({num_seqs}, torch::dtype(torch::kInt32)); + auto accessor = cpu_tensor.accessor(); for (int i = 0, query_cursor = 0; i < num_seqs; ++i) { if (i >= cu_query_lens_ptr[query_cursor + 1]) { ++query_cursor; } - seq_prompt_mapping[query_cursor] = i; + accessor[i] = query_cursor; } + torch::Tensor seq_prompt_mapping = cpu_tensor.to(torch::kCUDA); + // TODO(woosuk): Support BF16. if (query.element_size() == 2) { // Half. From 45093b4845b3f93ee23ded37d8257df9db194d7c Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 2 Apr 2023 19:12:32 -0700 Subject: [PATCH 13/24] fix --- csrc/attention_kernels.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index e31dadaadbf8..ee968192f569 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -722,7 +722,7 @@ template< int BLOCK_SIZE, int NUM_THREADS = 128> void multi_query_cached_kv_attention_launcher( - const int* seq_prompt_mapping_ptr, + torch::Tensor& seq_prompt_mapping, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, @@ -736,6 +736,7 @@ void multi_query_cached_kv_attention_launcher( int head_size = query.size(2); int max_num_blocks_per_seq = block_tables.size(1); + int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr(); T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); From ae7c1de086c59393e27835cb9bbb6f9d28d57010 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 2 Apr 2023 20:57:28 -0700 Subject: [PATCH 14/24] update --- csrc/attention_kernels.cu | 244 +++++++++++++++++++++++++++++++++++++- 1 file changed, 243 insertions(+), 1 deletion(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index ee968192f569..2573241074ae 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -253,6 +253,247 @@ __global__ void single_query_cached_kv_attention_kernel( } +// Grid: (num_heads, num_seqs). +template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS> +__global__ void multi_query_cached_kv_attention_kernel_unoptimized_1xN_( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] + const float scale, + const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq] + const int context_len, // [num_seqs] + const int max_num_blocks_per_seq) { + constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE; + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int seq_idx = blockIdx.y; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread group + // fetch or comput 16 bytes at a time. + // For example, if the size of a thread group is 4 and the data type is half, + // then the vector size is 16 / (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... + // th vectors of the query, and so on. + const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + Q_vec q_vecs[NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 logits and accumulation. + float *logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(scalar_t); + float qk_max = -FLT_MAX; + + const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + const int physical_block_offset = thread_group_idx % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th + // vectors of the key, and so on. + K_vec k_vecs[NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { + const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE + + head_idx * HEAD_SIZE * BLOCK_SIZE + + physical_block_offset * x; + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + k_vecs[i] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + const float qk = scale * Qk_dot::dot(q_vecs, k_vecs); + const bool mask = token_idx >= context_len; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + logits[token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // Each thread will fetch 16 bytes from the value cache at a time. + constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t); + using V_vec = typename Vec::Type; + using L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; + + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec = *reinterpret_cast(logits + token_idx); + + const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE + + head_idx * HEAD_SIZE * BLOCK_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec = *reinterpret_cast(v_ptr + offset); + accs[i] += dot(logits_vec, cast_to_float(v_vec)); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += __shfl_xor_sync(uint32_t(-1), acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for logits + // is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. + if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + __syncthreads(); + } + + // Write the final output. + if (warp_idx == 0) { + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + convert_from_float(*(out_ptr + row_idx), accs[i]); + } + } + } +} + + + // Grid: (num_heads, num_seqs). template< typename scalar_t, @@ -543,7 +784,8 @@ __global__ void multi_query_cached_kv_attention_kernel( const int prompt_idx = seq_prompt_mapping[seq_idx]; const int* block_table = block_tables + prompt_idx * max_num_blocks_per_seq; const int context_len = context_lens[prompt_idx]; - multi_query_cached_kv_attention_kernel_1xN_< + // multi_query_cached_kv_attention_kernel_1xN_< + multi_query_cached_kv_attention_kernel_unoptimized_1xN_< scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>( out, q, From 2ff01e5fc55bfc2dcd98bdf7fac87aed5ffabe5c Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 2 Apr 2023 21:01:08 -0700 Subject: [PATCH 15/24] update --- csrc/attention_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 2573241074ae..502812c5388a 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -259,7 +259,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS> -__global__ void multi_query_cached_kv_attention_kernel_unoptimized_1xN_( +__device__ void multi_query_cached_kv_attention_kernel_unoptimized_1xN_( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] From c46376a7a30f2d64e920959266d076d5490ce55a Mon Sep 17 00:00:00 2001 From: Siyuan Date: Mon, 3 Apr 2023 03:33:35 -0700 Subject: [PATCH 16/24] fix attention mask --- csrc/attention_kernels.cu | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 502812c5388a..0265a29b9771 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -262,6 +262,7 @@ template< __device__ void multi_query_cached_kv_attention_kernel_unoptimized_1xN_( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const int query_len, const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] const float scale, @@ -349,8 +350,8 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_1xN_( // Compute dot product. // This includes a reduction across the threads in the same thread group. const float qk = scale * Qk_dot::dot(q_vecs, k_vecs); - const bool mask = token_idx >= context_len; - + const bool mask = token_idx >= context_len - query_len; + if (thread_group_offset == 0) { // Store the partial reductions to shared memory. // NOTE(woosuk): It is required to zero out the masked logits. @@ -771,6 +772,7 @@ template< int BLOCK_SIZE, int NUM_THREADS> __global__ void multi_query_cached_kv_attention_kernel( + const int* cu_query_lens, // [num_prompts+1] const int* seq_prompt_mapping, // [num_seqs] mapping from seq_idx to prompt_idx scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -782,6 +784,7 @@ __global__ void multi_query_cached_kv_attention_kernel( const int max_num_blocks_per_seq) { const int seq_idx = blockIdx.y; const int prompt_idx = seq_prompt_mapping[seq_idx]; + const int query_len = cu_query_lens[prompt_idx + 1] - cu_query_lens[prompt_idx]; const int* block_table = block_tables + prompt_idx * max_num_blocks_per_seq; const int context_len = context_lens[prompt_idx]; // multi_query_cached_kv_attention_kernel_1xN_< @@ -789,6 +792,7 @@ __global__ void multi_query_cached_kv_attention_kernel( scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>( out, q, + query_len, k_cache, v_cache, scale, @@ -945,8 +949,9 @@ void single_query_cached_kv_attention( #define LAUNCH_MULTI_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ - cacheflow::multi_query_cached_kv_attention_kernel \ + cacheflow::multi_query_cached_kv_attention_kernel \ <<>>( \ + cu_query_lens_ptr, \ seq_prompt_mapping_ptr, \ out_ptr, \ query_ptr, \ @@ -964,6 +969,7 @@ template< int BLOCK_SIZE, int NUM_THREADS = 128> void multi_query_cached_kv_attention_launcher( + torch::Tensor& cu_query_lens, torch::Tensor& seq_prompt_mapping, torch::Tensor& out, torch::Tensor& query, @@ -978,6 +984,7 @@ void multi_query_cached_kv_attention_launcher( int head_size = query.size(2); int max_num_blocks_per_seq = block_tables.size(1); + int* cu_query_lens_ptr = cu_query_lens.data_ptr(); int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr(); T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); @@ -1037,15 +1044,17 @@ void multi_query_cached_kv_attention( torch::Tensor& context_lens, int block_size, int max_context_len) { + + torch::Tensor query_lens = cu_query_lens.to(torch::kCPU); - int num_queries = cu_query_lens.size(0) - 1; - const int* cu_query_lens_ptr = cu_query_lens.data_ptr(); + int num_queries = query_lens.size(0) - 1; + const int* query_lens_ptr = query_lens.data_ptr(); int num_seqs = query.size(0); torch::Tensor cpu_tensor = torch::empty({num_seqs}, torch::dtype(torch::kInt32)); auto accessor = cpu_tensor.accessor(); for (int i = 0, query_cursor = 0; i < num_seqs; ++i) { - if (i >= cu_query_lens_ptr[query_cursor + 1]) { + if (i >= query_lens_ptr[query_cursor + 1]) { ++query_cursor; } accessor[i] = query_cursor; @@ -1058,6 +1067,7 @@ void multi_query_cached_kv_attention( // Half. if (block_size == 8) { multi_query_cached_kv_attention_launcher( + cu_query_lens, seq_prompt_mapping, out, query, @@ -1069,6 +1079,7 @@ void multi_query_cached_kv_attention( max_context_len); } else if (block_size == 16) { multi_query_cached_kv_attention_launcher( + cu_query_lens, seq_prompt_mapping, out, query, @@ -1085,6 +1096,7 @@ void multi_query_cached_kv_attention( // Float. if (block_size == 8) { multi_query_cached_kv_attention_launcher( + cu_query_lens, seq_prompt_mapping, out, query, @@ -1096,6 +1108,7 @@ void multi_query_cached_kv_attention( max_context_len); } else if (block_size == 16) { multi_query_cached_kv_attention_launcher( + cu_query_lens, seq_prompt_mapping, out, query, From 16e0238d2975ecf4da9bee3d81174f68a85947f4 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Mon, 3 Apr 2023 03:55:15 -0700 Subject: [PATCH 17/24] update --- csrc/attention_kernels.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 0265a29b9771..72e68f2ba2a5 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -262,7 +262,7 @@ template< __device__ void multi_query_cached_kv_attention_kernel_unoptimized_1xN_( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const int query_len, + const int start_seq_idx, const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] const float scale, @@ -350,7 +350,7 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_1xN_( // Compute dot product. // This includes a reduction across the threads in the same thread group. const float qk = scale * Qk_dot::dot(q_vecs, k_vecs); - const bool mask = token_idx >= context_len - query_len; + const bool mask = token_idx >= context_len + (seq_idx - start_seq_idx); if (thread_group_offset == 0) { // Store the partial reductions to shared memory. @@ -784,7 +784,7 @@ __global__ void multi_query_cached_kv_attention_kernel( const int max_num_blocks_per_seq) { const int seq_idx = blockIdx.y; const int prompt_idx = seq_prompt_mapping[seq_idx]; - const int query_len = cu_query_lens[prompt_idx + 1] - cu_query_lens[prompt_idx]; + const int start_seq_idx = cu_query_lens[prompt_idx]; const int* block_table = block_tables + prompt_idx * max_num_blocks_per_seq; const int context_len = context_lens[prompt_idx]; // multi_query_cached_kv_attention_kernel_1xN_< @@ -792,7 +792,7 @@ __global__ void multi_query_cached_kv_attention_kernel( scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>( out, q, - query_len, + start_seq_idx, k_cache, v_cache, scale, From b42fad00657c8f4c3261e5a4f20bce04e1a95974 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Mon, 3 Apr 2023 04:35:31 -0700 Subject: [PATCH 18/24] fix --- csrc/attention_kernels.cu | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 72e68f2ba2a5..331b7cc83de9 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -262,7 +262,8 @@ template< __device__ void multi_query_cached_kv_attention_kernel_unoptimized_1xN_( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const int start_seq_idx, + const int seq_start_idx, + const int seq_len, const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] const float scale, @@ -350,7 +351,7 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_1xN_( // Compute dot product. // This includes a reduction across the threads in the same thread group. const float qk = scale * Qk_dot::dot(q_vecs, k_vecs); - const bool mask = token_idx >= context_len + (seq_idx - start_seq_idx); + const bool mask = token_idx >= context_len - seq_len + 1 + (seq_idx - seq_start_idx); if (thread_group_offset == 0) { // Store the partial reductions to shared memory. @@ -784,7 +785,8 @@ __global__ void multi_query_cached_kv_attention_kernel( const int max_num_blocks_per_seq) { const int seq_idx = blockIdx.y; const int prompt_idx = seq_prompt_mapping[seq_idx]; - const int start_seq_idx = cu_query_lens[prompt_idx]; + const int seq_start_idx = cu_query_lens[prompt_idx]; + const int seq_len = cu_query_lens[prompt_idx + 1] - seq_start_idx; const int* block_table = block_tables + prompt_idx * max_num_blocks_per_seq; const int context_len = context_lens[prompt_idx]; // multi_query_cached_kv_attention_kernel_1xN_< @@ -793,6 +795,7 @@ __global__ void multi_query_cached_kv_attention_kernel( out, q, start_seq_idx, + seq_len, k_cache, v_cache, scale, From 7879e5495115246ca80fef1a04d9b32b00a41d76 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Mon, 3 Apr 2023 04:37:28 -0700 Subject: [PATCH 19/24] fix --- csrc/attention_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 331b7cc83de9..05dfe6ce588e 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -794,7 +794,7 @@ __global__ void multi_query_cached_kv_attention_kernel( scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>( out, q, - start_seq_idx, + seq_start_idx, seq_len, k_cache, v_cache, From 07f4ff846a1f6054c637b6c35e52b01497a807e9 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Mon, 3 Apr 2023 05:24:50 -0700 Subject: [PATCH 20/24] cleanup --- csrc/attention_kernels.cu | 271 -------------------------------------- 1 file changed, 271 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 05dfe6ce588e..4cf80d37ba46 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -495,277 +495,6 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_1xN_( } - -// Grid: (num_heads, num_seqs). -template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS> -__device__ void multi_query_cached_kv_attention_kernel_1xN_( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] - const float scale, - const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq] - const int context_len, // [num_seqs] - const int max_num_blocks_per_seq) { - constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE; - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - const int thread_idx = threadIdx.x; - const int warp_idx = thread_idx / WARP_SIZE; - const int lane = thread_idx % WARP_SIZE; - - const int head_idx = blockIdx.x; - const int num_heads = gridDim.x; - const int seq_idx = blockIdx.y; - - // A vector type to store a part of a key or a query. - // The vector size is configured in such a way that the threads in a thread group - // fetch or comput 16 bytes at a time. - // For example, if the size of a thread group is 4 and the data type is half, - // then the vector size is 16 / (4 * sizeof(half)) == 2. - constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)); - using K_vec = typename Vec::Type; - using Q_vec = typename Vec::Type; - - constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; - constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; - - const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; - const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; - - // Load the query to registers. - // Each thread in a thread group has a different part of the query. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... - // th vectors of the query, and so on. - const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - Q_vec q_vecs[NUM_VECS_PER_THREAD]; -#pragma unroll - for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { - const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; - q_vecs[i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); - } - - // Memory planning. - extern __shared__ char shared_mem[]; - // NOTE(woosuk): We use FP32 logits and accumulation. - float *logits = reinterpret_cast(shared_mem); - // Workspace for reduction. - __shared__ float red_smem[2 * NUM_WARPS]; - - // x == THREAD_GROUP_SIZE * VEC_SIZE - // Each thread group fetches x elements from the key at a time. - constexpr int x = 16 / sizeof(scalar_t); - float qk_max = -FLT_MAX; - - // const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - // const int context_len = context_lens[seq_idx]; - const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; - - // constexpr int MAX_CONTEXT_LEN = 4096; // FIXME(suquark): make this configurable - // __shared__ int physical_block_numbers[MAX_CONTEXT_LEN / BLOCK_SIZE]; - - // int n_blocks_to_load = (num_blocks - warp_idx - 1) / NUM_WARPS + 1; - // if (thread_idx < n_blocks_to_load) { - // physical_block_numbers[thread_idx] = block_table[warp_idx + thread_idx * NUM_WARPS]; - // } - // for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { - // // TODO(suquark): we start physical_block_number to shared memory. - // const int physical_block_number = block_table[block_idx]; - // } - - // TODO(suquark): we may increase the share memory size to reduce synchronization. - __shared__ K_vec k_vecs[NUM_VECS_PER_THREAD]; - - // Iterate over the key blocks. - // Each warp fetches a block of keys for each iteration. - // Each thread group in a warp fetches a key from the block, and computes - // dot product with the query. - for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { - // TODO(suquark): we start physical_block_number to shared memory. - const int physical_block_number = block_table[block_idx]; - const int physical_block_offset = thread_group_idx % BLOCK_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - - // Load a key to shared memory. - // Each thread in a thread group has a different part of the key. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th - // vectors of the key, and so on. - const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE - + head_idx * HEAD_SIZE * BLOCK_SIZE - + physical_block_offset * x; - // TODO(suquark): currently, gridDim.x = num_heads > NUM_VECS_PER_THREAD. but it is not always true. - if (thread_idx < NUM_VECS_PER_THREAD) { - const int vec_idx = thread_group_offset + thread_idx * THREAD_GROUP_SIZE; - const int offset1 = (vec_idx * VEC_SIZE) / x; - const int offset2 = (vec_idx * VEC_SIZE) % x; - k_vecs[thread_idx] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); - } - __syncthreads(); - - // Compute dot product. - // This includes a reduction across the threads in the same thread group. - const float qk = scale * Qk_dot::dot(q_vecs, k_vecs); - const bool mask = token_idx >= context_len; - - if (thread_group_offset == 0) { - // Store the partial reductions to shared memory. - // NOTE(woosuk): It is required to zero out the masked logits. - logits[token_idx] = mask ? 0.f : qk; - // Update the max value. - qk_max = mask ? qk_max : fmaxf(qk_max, qk); - } - } - - // Perform reduction across the threads in the same warp to get the - // max qk value for each "warp" (not across the thread block yet). - // The 0-th thread of each thread group already has its max qk value. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - if (lane == 0) { - red_smem[warp_idx] = qk_max; - } - __syncthreads(); - - // TODO(woosuk): Refactor this part. - // Get the max qk value for the sequence. - qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - // Broadcast the max qk value to all threads. - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - - // Get the sum of the exp values. - float exp_sum = 0.f; - for (int i = thread_idx; i < context_len; i += NUM_THREADS) { - float val = __expf(logits[i] - qk_max); - logits[i] = val; - exp_sum += val; - } - exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); - - // Compute softmax. - const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); - for (int i = thread_idx; i < context_len; i += NUM_THREADS) { - logits[i] *= inv_sum; - } - __syncthreads(); - - // Each thread will fetch 16 bytes from the value cache at a time. - constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t); - using V_vec = typename Vec::Type; - using L_vec = typename FloatVec::Type; - - constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; - constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; - - float accs[NUM_ROWS_PER_THREAD]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - accs[i] = 0.f; - } - - // TODO(suquark): we may reuse the k_vecs shared memory here. - // TODO(suquark): We may use matrix multiplication here. - __shared__ V_vec v_vecs[NUM_ROWS_PER_THREAD]; - - for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { - const int physical_block_number = block_table[block_idx]; - const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - L_vec logits_vec = *reinterpret_cast(logits + token_idx); - - const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE - + head_idx * HEAD_SIZE * BLOCK_SIZE; - - // TODO(suquark): currently, gridDim.x = num_heads > NUM_VECS_PER_THREAD. but it is not always true. - if (thread_idx < NUM_ROWS_PER_THREAD) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + thread_idx * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE) { - const int offset = row_idx * BLOCK_SIZE + physical_block_offset; - v_vecs[thread_idx] = *reinterpret_cast(v_ptr + offset); - } - } - __syncthreads(); - // TODO(suquark): We may use matrix multiplication here. -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE) { - accs[i] += dot(logits_vec, cast_to_float(v_vecs[i])); - } - } - } - - // Perform reduction within each warp. -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - float acc = accs[i]; -#pragma unroll - for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { - acc += __shfl_xor_sync(uint32_t(-1), acc, mask); - } - accs[i] = acc; - } - - // NOTE(woosuk): A barrier is required because the shared memory space for logits - // is reused for the output. - __syncthreads(); - - // Perform reduction across warps. - float* out_smem = reinterpret_cast(shared_mem); -#pragma unroll - for (int i = NUM_WARPS; i > 1; i /= 2) { - int mid = i / 2; - // Upper warps write to shared memory. - if (warp_idx >= mid && warp_idx < i) { - float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - dst[row_idx] = accs[i]; - } - } - } - __syncthreads(); - - // Lower warps update the output. - if (warp_idx < mid) { - const float* src = &out_smem[warp_idx * HEAD_SIZE]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - accs[i] += src[row_idx]; - } - } - } - __syncthreads(); - } - - // Write the final output. - if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - convert_from_float(*(out_ptr + row_idx), accs[i]); - } - } - } -} - // Grid: (num_heads, num_seqs). template< typename scalar_t, From d72606a4e514887f62768181991bd77e3e711366 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Mon, 3 Apr 2023 12:04:24 -0700 Subject: [PATCH 21/24] fix --- csrc/attention_kernels.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 4cf80d37ba46..2a442403e528 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -321,6 +321,7 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_1xN_( float qk_max = -FLT_MAX; const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int mask_boundary = context_len - seq_len + 1 + (seq_idx - seq_start_idx); // Iterate over the key blocks. // Each warp fetches a block of keys for each iteration. @@ -351,7 +352,7 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_1xN_( // Compute dot product. // This includes a reduction across the threads in the same thread group. const float qk = scale * Qk_dot::dot(q_vecs, k_vecs); - const bool mask = token_idx >= context_len - seq_len + 1 + (seq_idx - seq_start_idx); + const bool mask = token_idx >= mask_boundary; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. @@ -386,7 +387,7 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_1xN_( // Get the sum of the exp values. float exp_sum = 0.f; - for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + for (int i = thread_idx; i < mask_boundary; i += NUM_THREADS) { float val = __expf(logits[i] - qk_max); logits[i] = val; exp_sum += val; From 99840f5d0cf4ebd1155623c28e72ed66c0fe2e7d Mon Sep 17 00:00:00 2001 From: Siyuan Date: Mon, 3 Apr 2023 12:34:51 -0700 Subject: [PATCH 22/24] update --- tests/kernels/attention.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/kernels/attention.py b/tests/kernels/attention.py index 9beb2c312e18..95e57a91b82d 100644 --- a/tests/kernels/attention.py +++ b/tests/kernels/attention.py @@ -327,14 +327,18 @@ def test_multi_query_cached_kv_attention( @torch.inference_mode() def test_attention(seed: int) -> None: + # NOTE(woosuk): Even when the seed is fixed, there is a chance that + # the test fails due to the precision issue. Re-run the test if it fails. + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) for dtype in [torch.half, torch.float]: for block_size in [8, 16]: for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: - print(f'Testing multi_query_cached_kv_attention with ' + print(f'Testing single_query_cached_kv_attention with ' f'dtype={dtype}, block_size={block_size}, ' f'head_size={head_size}') - test_multi_query_cached_kv_attention( - num_queries=11, + test_single_query_cached_kv_attention( + num_tokens=37, num_heads=3, head_size=head_size, block_size=block_size, @@ -342,18 +346,17 @@ def test_attention(seed: int) -> None: dtype=dtype, ) - # NOTE(woosuk): Even when the seed is fixed, there is a chance that - # the test fails due to the precision issue. Re-run the test if it fails. - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + # NOTE(siyuan): Same as above. Re-run the test if it fails. Also + # note that the test is also more likely to fail due to the much + # larger amount of tokens in the input may increase the variance. for dtype in [torch.half, torch.float]: for block_size in [8, 16]: for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: - print(f'Testing single_query_cached_kv_attention with ' + print(f'Testing multi_query_cached_kv_attention with ' f'dtype={dtype}, block_size={block_size}, ' f'head_size={head_size}') - test_single_query_cached_kv_attention( - num_tokens=37, + test_multi_query_cached_kv_attention( + num_queries=11, num_heads=3, head_size=head_size, block_size=block_size, From 5bfa1eefea34d0edd1ca730080182fb0d8f6919b Mon Sep 17 00:00:00 2001 From: "Siyuan (Ryans) Zhuang" Date: Mon, 3 Apr 2023 15:32:00 -0700 Subject: [PATCH 23/24] Apply suggestions from code review Co-authored-by: Woosuk Kwon --- csrc/attention_kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 75cf54c60132..9e1cb4ea5093 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -270,7 +270,7 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_1xN_( const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] const float scale, const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq] - const int context_len, // [num_seqs] + const int context_len, const int max_num_blocks_per_seq) { constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; @@ -539,7 +539,7 @@ __global__ void multi_query_cached_kv_attention_kernel( } // namespace cacheflow #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ - cacheflow::single_query_cached_kv_attention_kernel \ + cacheflow::single_query_cached_kv_attention_kernel \ <<>>( \ out_ptr, \ query_ptr, \ From 2d8f5c8c7f5ba0273ff07dc9462f671135802318 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Mon, 3 Apr 2023 15:37:11 -0700 Subject: [PATCH 24/24] fix comments --- csrc/attention_kernels.cu | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 9e1cb4ea5093..73c29b745309 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -255,13 +255,13 @@ __global__ void single_query_cached_kv_attention_kernel( } -// Grid: (num_heads, num_seqs). +// Grid: (num_heads, num_query_tokens). template< typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS> -__device__ void multi_query_cached_kv_attention_kernel_unoptimized_1xN_( +__device__ void multi_query_cached_kv_attention_kernel_unoptimized_( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const int seq_start_idx, @@ -498,7 +498,7 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_1xN_( } -// Grid: (num_heads, num_seqs). +// Grid: (num_heads, num_query_tokens). template< typename scalar_t, int HEAD_SIZE, @@ -521,8 +521,7 @@ __global__ void multi_query_cached_kv_attention_kernel( const int seq_len = cu_query_lens[prompt_idx + 1] - seq_start_idx; const int* block_table = block_tables + prompt_idx * max_num_blocks_per_seq; const int context_len = context_lens[prompt_idx]; - // multi_query_cached_kv_attention_kernel_1xN_< - multi_query_cached_kv_attention_kernel_unoptimized_1xN_< + multi_query_cached_kv_attention_kernel_unoptimized_< scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>( out, q, @@ -797,6 +796,9 @@ void multi_query_cached_kv_attention( accessor[i] = query_cursor; } + // TODO(suquark): This can be slow, as it to(torch::kCPU) and to(torch::kCUDA) + // implicitly synchronizes the CPU and GPU. And we can avoid this issue by giving + // the mapping as an input parameter. Let's do this optimization in a later PR. torch::Tensor seq_prompt_mapping = cpu_tensor.to(torch::kCUDA); // TODO(woosuk): Support BF16.