From e21845e1f036c054950ae82d2d9f8dba0a56b64c Mon Sep 17 00:00:00 2001 From: Siyuan Date: Wed, 5 Apr 2023 10:28:40 -0700 Subject: [PATCH 01/16] optimize optimize with shared memory better number of threads update test temp disable test update --- benchmark/benchmark_attention.py | 50 ++++-- csrc/cache.cpp | 12 ++ csrc/cache_kernels.cu | 270 +++++++++++++++++++++++++++++++ tests/kernels/attention.py | 161 ++++++++++++------ 4 files changed, 433 insertions(+), 60 deletions(-) diff --git a/benchmark/benchmark_attention.py b/benchmark/benchmark_attention.py index ac43ddb36e54..f75a0f6e58a8 100644 --- a/benchmark/benchmark_attention.py +++ b/benchmark/benchmark_attention.py @@ -6,7 +6,7 @@ from flash_attn.flash_attn_interface import _flash_attn_forward import torch -from cacheflow import attention_ops +from cacheflow import attention_ops, cache_ops def benchmark(name, f, num_warmup = 10, num_iters = 100): @@ -43,7 +43,7 @@ def benchmark_multi_query_cached_kv_attention( num_total_tokens = cu_query_lens[-1] qkv = torch.randn( num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') - query, _, _ = qkv.unbind(dim=1) + query, key, value = qkv.unbind(dim=1) # NOTE: this will not make a copy. # Create key and value cache. x = 16 // torch.tensor([], dtype=dtype).element_size() @@ -72,21 +72,53 @@ def benchmark_multi_query_cached_kv_attention( scale = float(1.0 / (head_size ** 0.5)) output = torch.empty( num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda') + + num_kv_tokens = sum(context_lens) + cu_context_lens = [0] + for context_len in context_lens: + cu_context_lens.append(cu_context_lens[-1] + context_len) + cpu_context_lens = torch.tensor(cu_context_lens, dtype=torch.int, device='cpu') + cu_context_lens = cpu_context_lens.cuda() + ref_output = torch.empty_like(output) # Run our implementation. def run_ours(): - attention_ops.multi_query_cached_kv_attention( - cu_query_lens, - output, - query, + cache_ops.gather_cached_kv( + qkv, key_cache, value_cache, - scale, + cu_context_lens, + cpu_context_lens, block_tables, - context_len_tensor, - block_size, + ) + + _flash_attn_forward( + query, + key, + value, + ref_output, + cu_query_lens, + cu_context_lens, + max(query_lens), max_context_len, + dropout_p=0.0, + softmax_scale=scale, + causal=True, + return_softmax=False, ) + + # attention_ops.multi_query_cached_kv_attention( + # cu_query_lens, + # output, + # query, + # key_cache, + # value_cache, + # scale, + # block_tables, + # context_len_tensor, + # block_size, + # max_context_len, + # ) benchmark('Ours', run_ours) # Upper bound: Flash attention. 
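(Editorial aside, not part of the patch.) The benchmark change above swaps the fused multi-query cached-KV attention call for a two-step path: gather_cached_kv first copies the paged key/value blocks into contiguous tensors, and _flash_attn_forward then consumes them together with cumulative-length ("cu_") tensors for the queries and contexts. Below is a minimal, standalone sketch of the cumulative-length construction mirrored from the benchmark's loops; make_cu_seqlens is an illustrative helper name, not an API from the patch.

    import torch

    def make_cu_seqlens(lengths):
        # Prefix sums with a leading zero: cu[i] is the start offset of
        # sequence i in the packed token dimension, and cu[-1] is the total
        # number of tokens, matching how cu_query_lens / cu_context_lens
        # are built in the benchmark above.
        cu = [0]
        for n in lengths:
            cu.append(cu[-1] + n)
        return torch.tensor(cu, dtype=torch.int32)

    # Example: context lengths [7, 3, 12] -> tensor([0, 7, 10, 22], dtype=torch.int32)
    print(make_cu_seqlens([7, 3, 12]))

The benchmark keeps both a CPU and a GPU copy of these offsets: the host-side copy lets gather_cached_kv compute the total block count (and hence the launch grid) on the CPU, while the device copy is what the kernel and _flash_attn_forward read.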
diff --git a/csrc/cache.cpp b/csrc/cache.cpp index 907736a981c9..ca9b932e9455 100644 --- a/csrc/cache.cpp +++ b/csrc/cache.cpp @@ -20,6 +20,14 @@ void reshape_and_cache( torch::Tensor& value_cache, torch::Tensor& slot_mapping); +void gather_cached_kv( + torch::Tensor& qkv_out, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& cu_seqlens_k, + torch::Tensor& seqlens_k, + torch::Tensor& block_tables); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( "swap_blocks", @@ -33,4 +41,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "reshape_and_cache", &reshape_and_cache, "Reshape the key and value tensors and cache them"); + m.def( + "gather_cached_kv", + &gather_cached_kv, + "Gather key and value from the cache into contiguous QKV tensors"); } diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 3a34ba578980..9cd47b41c835 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -176,6 +176,183 @@ __global__ void reshape_and_cache_kernel( } } +// Grid: (num_blocks, num_heads). +template +__global__ void gather_cached_kv_kernel( + scalar_t* __restrict__ out, // [cu_seqlens_k[-1], 3(QKV), num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ cu_seqlens_k, // aka 'cu_seqlens_k' in '_flash_attn_forward', or 'context_lens' in cacheflow + const int num_seqs, + const int max_num_blocks_per_seq, + const int head_size, + const int block_size) { + // Each CUDA gird is mapped to (num_blocks, num_heads). + const int block_idx = blockIdx.x; + const int num_blocks = gridDim.x; + const int head_idx = blockIdx.y; + const int num_heads = gridDim.y; + // Each CUDA block is responsible for (head_size, block_size). + const int thread_idx = threadIdx.x; + const int num_threads = blockDim.x; + // in the original attention kernel, each thread group fetch x elements at a time. + constexpr int x = 16 / sizeof(scalar_t); + + // the index of the sequence this thread is working on. + int seq_idx; + // the index of the block in the sequence this thread is working on. + int local_block_idx; + // calculate the sequence index and block index in the sequence. + int num_total_blocks = 0; +#pragma unroll + for (int i = 0; i < num_seqs; ++i) { + int context_len = cu_seqlens_k[i + 1] - cu_seqlens_k[i]; + int num_blocks = (context_len + block_size - 1) / block_size; + num_total_blocks += num_blocks; + if (num_total_blocks > block_idx) { + seq_idx = i; + local_block_idx = block_idx - (num_total_blocks - num_blocks); + break; + } + } + // const int context_len = cu_seqlens_k[seq_idx]; + // const int num_blocks = (context_len + block_size - 1) / block_size; + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + const int physical_block_number = block_table[local_block_idx]; + + // number of chunks handled by a CUDA block. 
+ const int n_chunks = (head_size * block_size + (num_threads - 1)) / num_threads; + const int physical_cache_offset = (physical_block_number * num_heads + head_idx) * head_size * block_size; + + // The common output pointer base used by both key and value: + scalar_t* common_out = out + (block_idx * block_size) * 3 * num_heads * head_size + + head_idx * head_size; + // key is the second tensor in QKV, so qkv_offset = 1 + scalar_t* key_out = common_out + 1 * num_heads * head_size; + // value is the third tensor in QKV, so qkv_offset = 2 + scalar_t* value_out = common_out + 2 * num_heads * head_size; + + // process key in chunks +#pragma unroll + for (int chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) { + const int offset = chunk_idx * num_threads + thread_idx; + if (offset >= head_size * block_size) { + break; + } + // calculate offsets in [head_size/x, block_size, x] + const int head_offset = offset / x / block_size; + const int block_offset = offset / x % block_size; + const int x_offset = offset % x; + + const scalar_t* k_ptr = k_cache + physical_cache_offset + offset; + scalar_t* out_ptr = key_out + block_offset * 3 * num_heads * head_size + + head_offset * x + x_offset; + *out_ptr = __ldg(k_ptr); + } + + // process value in chunks +#pragma unroll + for (int chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) { + const int offset = chunk_idx * num_threads + thread_idx; + if (offset >= head_size * block_size) { + break; + } + // calculate offsets in [head_size, block_size] + const int head_offset = offset / block_size; + const int block_offset = offset % block_size; + + const scalar_t* v_ptr = v_cache + physical_cache_offset + offset; + scalar_t* out_ptr = value_out + block_offset * 3 * num_heads * head_size + head_offset; + *out_ptr = __ldg(v_ptr); + } +} + + +// Grid: (num_blocks, block_size). +template +__global__ void gather_cached_kv_kernel_2( + scalar_t* __restrict__ out, // [cu_seqlens_k[-1], 3(QKV), num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ cu_seqlens_k, // aka 'cu_seqlens_k' in '_flash_attn_forward', or 'context_lens' in cacheflow + const int num_seqs, + const int max_num_blocks_per_seq, + const int num_heads, + const int head_size) { + // Each CUDA gird is mapped to (num_blocks, num_heads). + const int block_idx = blockIdx.x; + const int num_blocks = gridDim.x; + const int block_offset = blockIdx.y; + const int block_size = gridDim.y; + // Each CUDA block is responsible for (head_size, block_size). + const int thread_idx = threadIdx.x; + const int num_threads = blockDim.x; + // in the original attention kernel, each thread group fetch x elements at a time. + constexpr int x = 16 / sizeof(scalar_t); + + // the index of the sequence this thread is working on. + int seq_idx; + // the index of the block in the sequence this thread is working on. + int local_block_idx; + // calculate the sequence index and block index in the sequence. 
+ int num_total_blocks = 0; +#pragma unroll + for (int i = 0; i < num_seqs; ++i) { + int context_len = cu_seqlens_k[i + 1] - cu_seqlens_k[i]; + int num_blocks = (context_len + block_size - 1) / block_size; + num_total_blocks += num_blocks; + if (num_total_blocks > block_idx) { + seq_idx = i; + local_block_idx = block_idx - (num_total_blocks - num_blocks); + break; + } + } + + // const int context_len = cu_seqlens_k[seq_idx]; + // const int num_blocks = (context_len + block_size - 1) / block_size; + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + const int physical_block_number = block_table[local_block_idx]; + const int physical_cache_offset = physical_block_number * num_heads * head_size * block_size; + + // The common output pointer base used by both key and value: + scalar_t* common_out = out + (block_idx * block_size + block_offset) * 3 * num_heads * head_size; + // key is the second tensor in QKV, so qkv_offset = 1 + scalar_t* key_out = common_out + 1 * num_heads * head_size; + // value is the third tensor in QKV, so qkv_offset = 2 + scalar_t* value_out = common_out + 2 * num_heads * head_size; + + // process key in chunks +#pragma unroll + for (int i = threadIdx.x; i < num_heads * head_size; i += blockDim.x) { + // calculate offsets in [num_heads, head_size/x, x] + const int head_idx = i / x / (head_size / x); + const int head_offset = i / x % (head_size / x); + const int x_offset = i % x; + + const scalar_t* k_ptr = k_cache + physical_cache_offset + + head_idx * (head_size/x) * block_size * x + + head_offset * block_size * x + + block_offset * x + + x_offset; + key_out[head_idx * head_size + head_offset * x + x_offset] = __ldg(k_ptr); + } + + // process value in chunks +#pragma unroll + for (int i = threadIdx.x; i < num_heads * head_size; i += blockDim.x) { + // calculate offsets in [num_heads, head_size] + const int head_idx = i / head_size; + const int head_offset = i % head_size; + + const scalar_t* v_ptr = v_cache + physical_cache_offset + + i * block_size // equal to (head_idx * head_size + head_offset) * block_size + + block_offset; + value_out[i] = __ldg(v_ptr); + } +} + } // namespace cacheflow void reshape_and_cache( @@ -215,3 +392,96 @@ void reshape_and_cache( x); }); } + +/* +// same group of threads will be working on the same block +void gather_cached_kv( + torch::Tensor& qkv_out, // [cu_seqlens_k[-1], 3(QKV), num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& cu_seqlens_k, // aka 'cu_seqlens_k' in '_flash_attn_forward', or 'context_lens' in cacheflow + torch::Tensor& seqlens_k, // CPU version of 'cu_seqlens_k' + torch::Tensor& block_tables) { // [num_seqs, max_num_blocks_per_seq] + const int num_seqs = cu_seqlens_k.size(0) - 1; + const int num_heads = value_cache.size(1); + const int head_size = value_cache.size(2); + const int block_size = value_cache.size(3); + // const int x = key_cache.size(4); + const int max_num_blocks_per_seq = block_tables.size(1); + const int* context_lens_ptr = cu_seqlens_k.data_ptr(); + const int* cpu_context_lens_ptr = seqlens_k.data_ptr(); + + // calculate the total amount of blocks + int num_total_blocks = 0; + for (int i = 0; i < num_seqs; ++i) { + int context_len = cpu_context_lens_ptr[i + 1] - cpu_context_lens_ptr[i]; + int num_blocks = (context_len + block_size - 1) / block_size; + num_total_blocks += num_blocks; + } + + constexpr int NUM_THREADS = 256; + 
dim3 grid(num_total_blocks, num_heads); + dim3 block(NUM_THREADS); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + key_cache.scalar_type(), + "gather_cached_kv_kernel", + [&] { + cacheflow::gather_cached_kv_kernel<<>>( + qkv_out.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + block_tables.data_ptr(), + cu_seqlens_k.data_ptr(), + num_seqs, + max_num_blocks_per_seq, + head_size, + block_size); + }); +} +*/ + +// same group of threads will be working on the same block +void gather_cached_kv( + torch::Tensor& qkv_out, // [cu_seqlens_k[-1], 3(QKV), num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& cu_seqlens_k, // aka 'cu_seqlens_k' in '_flash_attn_forward', or 'context_lens' in cacheflow + torch::Tensor& seqlens_k, // CPU version of 'cu_seqlens_k' + torch::Tensor& block_tables) { // [num_seqs, max_num_blocks_per_seq] + const int num_seqs = cu_seqlens_k.size(0) - 1; + const int num_heads = value_cache.size(1); + const int head_size = value_cache.size(2); + const int block_size = value_cache.size(3); + // const int x = key_cache.size(4); + const int max_num_blocks_per_seq = block_tables.size(1); + const int* context_lens_ptr = cu_seqlens_k.data_ptr(); + const int* cpu_context_lens_ptr = seqlens_k.data_ptr(); + + // calculate the total amount of blocks + int num_total_blocks = 0; + for (int i = 0; i < num_seqs; ++i) { + int context_len = cpu_context_lens_ptr[i + 1] - cpu_context_lens_ptr[i]; + int num_blocks = (context_len + block_size - 1) / block_size; + num_total_blocks += num_blocks; + } + + dim3 grid(num_total_blocks, block_size); + dim3 block(std::min(num_heads * head_size, 512)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + key_cache.scalar_type(), + "gather_cached_kv_kernel_2", + [&] { + cacheflow::gather_cached_kv_kernel_2<<>>( + qkv_out.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + block_tables.data_ptr(), + cu_seqlens_k.data_ptr(), + num_seqs, + max_num_blocks_per_seq, + num_heads, + head_size); + }); +} diff --git a/tests/kernels/attention.py b/tests/kernels/attention.py index a66f2c3daca7..e63d7f814be3 100644 --- a/tests/kernels/attention.py +++ b/tests/kernels/attention.py @@ -4,7 +4,7 @@ from flash_attn.flash_attn_interface import _flash_attn_forward import torch -from cacheflow import attention_ops +from cacheflow import attention_ops, cache_ops MAX_SEQ_LEN = 4096 @@ -65,7 +65,8 @@ def ref_single_query_cached_kv_attention( def ref_multi_query_kv_attention( - cu_seq_lens: List[int], + cu_query_lens: List[int], + cu_context_lens: List[int], query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -74,21 +75,25 @@ def ref_multi_query_kv_attention( head_size = query.shape[-1] scale = 1.0 / (head_size ** 0.5) - num_seqs = len(cu_seq_lens) - 1 + num_seqs = len(cu_query_lens) - 1 ref_outputs = [] for i in range(num_seqs): - start_idx = cu_seq_lens[i] - end_idx = cu_seq_lens[i + 1] - seq_len = end_idx - start_idx + query_start_idx = cu_query_lens[i] + query_end_idx = cu_query_lens[i + 1] + query_len = query_end_idx - query_start_idx + + context_start_idx = cu_context_lens[i] + context_end_idx = cu_context_lens[i + 1] + context_len = context_end_idx - context_start_idx # Create attention mask - attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5 + attn_mask = 
torch.triu(torch.ones(query_len, context_len), diagonal=1) * -1e5 attn_mask = attn_mask.to(dtype=dtype, device='cuda') ref_output = ref_masked_attention( - query[start_idx:end_idx], - key[start_idx:end_idx], - value[start_idx:end_idx], + query[query_start_idx:query_end_idx], + key[context_start_idx:context_end_idx], + value[context_start_idx:context_end_idx], scale, attn_mask=attn_mask, ) @@ -227,42 +232,58 @@ def test_multi_query_kv_attention( head_size: int, dtype: torch.dtype, ) -> None: - seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs) - max_seq_len = max(seq_lens) - num_tokens = sum(seq_lens) + query_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs) + max_query_len = max(query_lens) + num_query_tokens = sum(query_lens) + cu_query_lens = [0] + for seq_len in query_lens: + cu_query_lens.append(cu_query_lens[-1] + seq_len) + cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda') - cu_seq_lens = [0] - for seq_len in seq_lens: - cu_seq_lens.append(cu_seq_lens[-1] + seq_len) - cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda') + context_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs) + max_context_len = max(context_lens) + num_context_tokens = sum(context_lens) + cu_context_lens = [0] + for seq_len in context_lens: + cu_context_lens.append(cu_context_lens[-1] + seq_len) + cu_context_lens = torch.tensor( + cu_context_lens, dtype=torch.int, device='cuda') scale = float(1.0 / (head_size ** 0.5)) qkv = torch.randn( - num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') + num_query_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') + query, _, _ = qkv.unbind(dim=1) + qkv = torch.randn( + num_context_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') + _, key, value = qkv.unbind(dim=1) + # Adjust the range of the values to reduce precision errors. 
- qkv = qkv / (head_size ** 0.5) + query = query / (head_size ** 0.5) + key = key / (head_size ** 0.5) + value = value / (head_size ** 0.5) - query, key, value = qkv.unbind(dim=1) output = torch.empty( - num_tokens, num_heads, head_size, dtype=dtype, device='cuda') + num_query_tokens, num_heads, head_size, dtype=dtype, device='cuda') _flash_attn_forward( query, key, value, output, - cu_seq_lens, - cu_seq_lens, - max_seq_len, - max_seq_len, + cu_query_lens, + cu_context_lens, + max_query_len, + max_context_len, dropout_p=0.0, softmax_scale=scale, causal=True, return_softmax=False, ) - cu_seq_lens = cu_seq_lens.cpu().tolist() + cu_query_lens = cu_query_lens.cpu().tolist() + cu_context_lens = cu_context_lens.cpu().tolist() ref_output = ref_multi_query_kv_attention( - cu_seq_lens, + cu_query_lens, + cu_context_lens, query, key, value, @@ -287,7 +308,7 @@ def test_multi_query_cached_kv_attention( qkv = torch.randn( num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') - query, _, _ = qkv.unbind(dim=1) + query, key, value = qkv.unbind(dim=1) x = 16 // torch.tensor([], dtype=dtype).element_size() key_block_shape = (num_heads, head_size // x, block_size, x) key_cache = torch.randn( @@ -302,7 +323,13 @@ def test_multi_query_cached_kv_attention( for query_len in query_lens ] max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda') + cu_context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda') + + cu_seqlens_k = [0] + for seq_len in context_lens: + cu_seqlens_k.append(cu_seqlens_k[-1] + seq_len) + cu_seqlens_k = torch.tensor( + cu_seqlens_k, dtype=torch.int, device='cuda') max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size block_tables = [] @@ -318,17 +345,49 @@ def test_multi_query_cached_kv_attention( output = torch.empty( num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda') - attention_ops.multi_query_cached_kv_attention( - cu_query_lens, - output, - query, + # attention_ops.multi_query_cached_kv_attention( + # cu_query_lens, + # output, + # query, + # key_cache, + # value_cache, + # scale, + # block_tables, + # context_lens, + # block_size, + # max_context_len, + # ) + + old_key = key.clone() + old_value = value.clone() + + cache_ops.gather_cached_kv( + qkv, key_cache, value_cache, - scale, + cu_seqlens_k, + cu_seqlens_k.cpu(), block_tables, - context_lens, - block_size, + ) + + # test if key and value are updated + assert not torch.allclose(key, old_key, atol=1e-3, rtol=1e-5) + assert not torch.allclose(value, old_value, atol=1e-3, rtol=1e-5) + + + _flash_attn_forward( + query, + key, + value, + output, + cu_query_lens, + cu_seqlens_k, + num_total_tokens, max_context_len, + dropout_p=0.0, + softmax_scale=scale, + causal=True, + return_softmax=False, ) ref_output = ref_multi_query_cached_kv_attention( @@ -337,7 +396,7 @@ def test_multi_query_cached_kv_attention( key_cache, value_cache, block_tables, - context_lens, + cu_context_lens, dtype, ) assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) @@ -349,20 +408,20 @@ def test_attention(seed: int) -> None: # the test fails due to the precision issue. Re-run the test if it fails. 
torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - for dtype in [torch.half, torch.float]: - for block_size in [8, 16]: - for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: - print(f'Testing single_query_cached_kv_attention with ' - f'dtype={dtype}, block_size={block_size}, ' - f'head_size={head_size}') - test_single_query_cached_kv_attention( - num_tokens=37, - num_heads=3, - head_size=head_size, - block_size=block_size, - num_blocks=1024, - dtype=dtype, - ) + # for dtype in [torch.half, torch.float]: + # for block_size in [8, 16]: + # for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: + # print(f'Testing single_query_cached_kv_attention with ' + # f'dtype={dtype}, block_size={block_size}, ' + # f'head_size={head_size}') + # test_single_query_cached_kv_attention( + # num_tokens=37, + # num_heads=3, + # head_size=head_size, + # block_size=block_size, + # num_blocks=1024, + # dtype=dtype, + # ) # NOTE(siyuan): Same as above. Re-run the test if it fails. Also # note that the test is also more likely to fail due to the much From 82fc4f42d48726eb7aafc86d918d2b00cc6a3b9a Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 9 Apr 2023 01:49:30 -0700 Subject: [PATCH 02/16] update --- csrc/cache_kernels.cu | 230 +++++++++++++++++++++--------------------- 1 file changed, 115 insertions(+), 115 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 9cd47b41c835..b744d8634c73 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -272,84 +272,44 @@ __global__ void gather_cached_kv_kernel( // Grid: (num_blocks, block_size). template __global__ void gather_cached_kv_kernel_2( - scalar_t* __restrict__ out, // [cu_seqlens_k[-1], 3(QKV), num_heads, head_size] + scalar_t* __restrict__ key, // [num_tokens, [stride], num_heads, head_size] + scalar_t* __restrict__ value, // [num_tokens, [stride], num_heads, head_size] const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ cu_seqlens_k, // aka 'cu_seqlens_k' in '_flash_attn_forward', or 'context_lens' in cacheflow - const int num_seqs, - const int max_num_blocks_per_seq, + const int* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, + const int value_stride, const int num_heads, - const int head_size) { - // Each CUDA gird is mapped to (num_blocks, num_heads). - const int block_idx = blockIdx.x; - const int num_blocks = gridDim.x; - const int block_offset = blockIdx.y; - const int block_size = gridDim.y; - // Each CUDA block is responsible for (head_size, block_size). - const int thread_idx = threadIdx.x; - const int num_threads = blockDim.x; - // in the original attention kernel, each thread group fetch x elements at a time. - constexpr int x = 16 / sizeof(scalar_t); - - // the index of the sequence this thread is working on. - int seq_idx; - // the index of the block in the sequence this thread is working on. - int local_block_idx; - // calculate the sequence index and block index in the sequence. 
- int num_total_blocks = 0; -#pragma unroll - for (int i = 0; i < num_seqs; ++i) { - int context_len = cu_seqlens_k[i + 1] - cu_seqlens_k[i]; - int num_blocks = (context_len + block_size - 1) / block_size; - num_total_blocks += num_blocks; - if (num_total_blocks > block_idx) { - seq_idx = i; - local_block_idx = block_idx - (num_total_blocks - num_blocks); - break; - } - } - - // const int context_len = cu_seqlens_k[seq_idx]; - // const int num_blocks = (context_len + block_size - 1) / block_size; - const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - const int physical_block_number = block_table[local_block_idx]; - const int physical_cache_offset = physical_block_number * num_heads * head_size * block_size; - - // The common output pointer base used by both key and value: - scalar_t* common_out = out + (block_idx * block_size + block_offset) * 3 * num_heads * head_size; - // key is the second tensor in QKV, so qkv_offset = 1 - scalar_t* key_out = common_out + 1 * num_heads * head_size; - // value is the third tensor in QKV, so qkv_offset = 2 - scalar_t* value_out = common_out + 2 * num_heads * head_size; - - // process key in chunks -#pragma unroll - for (int i = threadIdx.x; i < num_heads * head_size; i += blockDim.x) { - // calculate offsets in [num_heads, head_size/x, x] - const int head_idx = i / x / (head_size / x); - const int head_offset = i / x % (head_size / x); - const int x_offset = i % x; - - const scalar_t* k_ptr = k_cache + physical_cache_offset - + head_idx * (head_size/x) * block_size * x - + head_offset * block_size * x - + block_offset * x - + x_offset; - key_out[head_idx * head_size + head_offset * x + x_offset] = __ldg(k_ptr); - } - - // process value in chunks -#pragma unroll - for (int i = threadIdx.x; i < num_heads * head_size; i += blockDim.x) { - // calculate offsets in [num_heads, head_size] + const int head_size, + const int block_size, + const int x) { + const int token_idx = blockIdx.x; + const int slot_idx = slot_mapping[token_idx]; + const int block_idx = slot_idx / block_size; + const int block_offset = slot_idx % block_size; + + const int num_tokens = num_heads * head_size; + for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) { + const int tgt_key_idx = token_idx * key_stride + i; + const int tgt_value_idx = token_idx * value_stride + i; + const int head_idx = i / head_size; const int head_offset = i % head_size; + const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension + const int x_offset = head_offset % x; + + const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + + x_idx * block_size * x + + block_offset * x + + x_offset; + const int src_value_idx = block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + + head_offset * block_size + + block_offset; - const scalar_t* v_ptr = v_cache + physical_cache_offset - + i * block_size // equal to (head_idx * head_size + head_offset) * block_size - + block_offset; - value_out[i] = __ldg(v_ptr); + key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]); + value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]); } } @@ -393,6 +353,45 @@ void reshape_and_cache( }); } + +void gather_cached_kv( + torch::Tensor& key, // [out] [num_tokens, num_heads, head_size] + torch::Tensor& value, // [out] [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [in] [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [in] [num_blocks, 
num_heads, head_size, block_size] + torch::Tensor& slot_mapping) // [in] [num_tokens] +{ + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + key.scalar_type(), + "gather_cached_kv_kernel_2", + [&] { + cacheflow::gather_cached_kv_kernel_2<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + slot_mapping.data_ptr(), + key_stride, + value_stride, + num_heads, + head_size, + block_size, + x); + }); +} + /* // same group of threads will be working on the same block void gather_cached_kv( @@ -441,47 +440,48 @@ void gather_cached_kv( } */ -// same group of threads will be working on the same block -void gather_cached_kv( - torch::Tensor& qkv_out, // [cu_seqlens_k[-1], 3(QKV), num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - torch::Tensor& cu_seqlens_k, // aka 'cu_seqlens_k' in '_flash_attn_forward', or 'context_lens' in cacheflow - torch::Tensor& seqlens_k, // CPU version of 'cu_seqlens_k' - torch::Tensor& block_tables) { // [num_seqs, max_num_blocks_per_seq] - const int num_seqs = cu_seqlens_k.size(0) - 1; - const int num_heads = value_cache.size(1); - const int head_size = value_cache.size(2); - const int block_size = value_cache.size(3); - // const int x = key_cache.size(4); - const int max_num_blocks_per_seq = block_tables.size(1); - const int* context_lens_ptr = cu_seqlens_k.data_ptr(); - const int* cpu_context_lens_ptr = seqlens_k.data_ptr(); - - // calculate the total amount of blocks - int num_total_blocks = 0; - for (int i = 0; i < num_seqs; ++i) { - int context_len = cpu_context_lens_ptr[i + 1] - cpu_context_lens_ptr[i]; - int num_blocks = (context_len + block_size - 1) / block_size; - num_total_blocks += num_blocks; - } - dim3 grid(num_total_blocks, block_size); - dim3 block(std::min(num_heads * head_size, 512)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - key_cache.scalar_type(), - "gather_cached_kv_kernel_2", - [&] { - cacheflow::gather_cached_kv_kernel_2<<>>( - qkv_out.data_ptr(), - key_cache.data_ptr(), - value_cache.data_ptr(), - block_tables.data_ptr(), - cu_seqlens_k.data_ptr(), - num_seqs, - max_num_blocks_per_seq, - num_heads, - head_size); - }); -} +// // same group of threads will be working on the same block +// void gather_cached_kv( +// torch::Tensor& qkv_out, // [cu_seqlens_k[-1], 3(QKV), num_heads, head_size] +// torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] +// torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] +// torch::Tensor& cu_seqlens_k, // aka 'cu_seqlens_k' in '_flash_attn_forward', or 'context_lens' in cacheflow +// torch::Tensor& seqlens_k, // CPU version of 'cu_seqlens_k' +// torch::Tensor& block_tables) { // [num_seqs, max_num_blocks_per_seq] +// const int num_seqs = cu_seqlens_k.size(0) - 1; +// const int num_heads = value_cache.size(1); +// const int head_size = value_cache.size(2); +// const int block_size = value_cache.size(3); +// // const int x = key_cache.size(4); +// const int 
max_num_blocks_per_seq = block_tables.size(1); +// const int* context_lens_ptr = cu_seqlens_k.data_ptr(); +// const int* cpu_context_lens_ptr = seqlens_k.data_ptr(); + +// // calculate the total amount of blocks +// int num_total_blocks = 0; +// for (int i = 0; i < num_seqs; ++i) { +// int context_len = cpu_context_lens_ptr[i + 1] - cpu_context_lens_ptr[i]; +// int num_blocks = (context_len + block_size - 1) / block_size; +// num_total_blocks += num_blocks; +// } + +// dim3 grid(num_total_blocks, block_size); +// dim3 block(std::min(num_heads * head_size, 512)); +// const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); +// AT_DISPATCH_FLOATING_TYPES_AND_HALF( +// key_cache.scalar_type(), +// "gather_cached_kv_kernel_2", +// [&] { +// cacheflow::gather_cached_kv_kernel_2<<>>( +// qkv_out.data_ptr(), +// key_cache.data_ptr(), +// value_cache.data_ptr(), +// block_tables.data_ptr(), +// cu_seqlens_k.data_ptr(), +// num_seqs, +// max_num_blocks_per_seq, +// num_heads, +// head_size); +// }); +// } From 6057f9fe3aad64121057a637012c92f54f1d5f7b Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 9 Apr 2023 01:54:48 -0700 Subject: [PATCH 03/16] update test --- tests/kernels/cache.py | 44 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/kernels/cache.py b/tests/kernels/cache.py index 89f14cca82a2..f444ac16a49d 100644 --- a/tests/kernels/cache.py +++ b/tests/kernels/cache.py @@ -99,6 +99,47 @@ def test_reshape_and_cache( assert torch.allclose(value_cache, cloned_value_cache) +def test_gather_cached_kv( + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, +) -> None: + num_slots = block_size * num_blocks + slot_mapping = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda') + + qkv = torch.randn( + num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') + _, key, value = qkv.unbind(dim=1) + + qkv_clone = qkv.clone() + _, cloned_key, cloned_value = qkv_clone.unbind(dim=1) + + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda') + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_cache = torch.randn( + size=value_cache_shape, dtype=dtype, device='cuda') + + cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping) + + # Reference implementation. 
+ for i in range(num_tokens): + reshaped_key = cloned_key.reshape(num_tokens, num_heads, head_size // x, x) + block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor') + block_offset = slot_mapping[i] % block_size + reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :] + cloned_value[i] = value_cache[block_idx, :, :, block_offset] + + assert torch.allclose(key, cloned_key) + assert torch.allclose(value, cloned_value) + + @torch.inference_mode() def test_cache() -> None: test_copy_blocks( @@ -107,6 +148,9 @@ def test_cache() -> None: test_reshape_and_cache( num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2, dtype=torch.half) + test_gather_cached_kv( + num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2, + dtype=torch.half) if __name__ == '__main__': From 075b48a8e1fea7ccf78271ab1edf4d8f5267b0fa Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 9 Apr 2023 01:57:06 -0700 Subject: [PATCH 04/16] update --- csrc/cache_kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index b744d8634c73..bc0091dc6ffa 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -274,8 +274,8 @@ template __global__ void gather_cached_kv_kernel_2( scalar_t* __restrict__ key, // [num_tokens, [stride], num_heads, head_size] scalar_t* __restrict__ value, // [num_tokens, [stride], num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] + const scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] const int* __restrict__ slot_mapping, // [num_tokens] const int key_stride, const int value_stride, From a6157508a35fc6d444688a24b99a0274f5c3087f Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 9 Apr 2023 01:59:37 -0700 Subject: [PATCH 05/16] update API --- csrc/cache.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/csrc/cache.cpp b/csrc/cache.cpp index ca9b932e9455..9ae17bb2985c 100644 --- a/csrc/cache.cpp +++ b/csrc/cache.cpp @@ -21,12 +21,11 @@ void reshape_and_cache( torch::Tensor& slot_mapping); void gather_cached_kv( - torch::Tensor& qkv_out, + torch::Tensor& key, + torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, - torch::Tensor& cu_seqlens_k, - torch::Tensor& seqlens_k, - torch::Tensor& block_tables); + torch::Tensor& slot_mapping); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( From 3f3991d08f3ef394ac863d6a94550111f9e57612 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 9 Apr 2023 02:08:43 -0700 Subject: [PATCH 06/16] update --- csrc/cache_kernels.cu | 187 +----------------------------------------- 1 file changed, 1 insertion(+), 186 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index bc0091dc6ffa..93e4d25e1a14 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -176,98 +176,6 @@ __global__ void reshape_and_cache_kernel( } } -// Grid: (num_blocks, num_heads). 
-template -__global__ void gather_cached_kv_kernel( - scalar_t* __restrict__ out, // [cu_seqlens_k[-1], 3(QKV), num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ cu_seqlens_k, // aka 'cu_seqlens_k' in '_flash_attn_forward', or 'context_lens' in cacheflow - const int num_seqs, - const int max_num_blocks_per_seq, - const int head_size, - const int block_size) { - // Each CUDA gird is mapped to (num_blocks, num_heads). - const int block_idx = blockIdx.x; - const int num_blocks = gridDim.x; - const int head_idx = blockIdx.y; - const int num_heads = gridDim.y; - // Each CUDA block is responsible for (head_size, block_size). - const int thread_idx = threadIdx.x; - const int num_threads = blockDim.x; - // in the original attention kernel, each thread group fetch x elements at a time. - constexpr int x = 16 / sizeof(scalar_t); - - // the index of the sequence this thread is working on. - int seq_idx; - // the index of the block in the sequence this thread is working on. - int local_block_idx; - // calculate the sequence index and block index in the sequence. - int num_total_blocks = 0; -#pragma unroll - for (int i = 0; i < num_seqs; ++i) { - int context_len = cu_seqlens_k[i + 1] - cu_seqlens_k[i]; - int num_blocks = (context_len + block_size - 1) / block_size; - num_total_blocks += num_blocks; - if (num_total_blocks > block_idx) { - seq_idx = i; - local_block_idx = block_idx - (num_total_blocks - num_blocks); - break; - } - } - // const int context_len = cu_seqlens_k[seq_idx]; - // const int num_blocks = (context_len + block_size - 1) / block_size; - const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - const int physical_block_number = block_table[local_block_idx]; - - // number of chunks handled by a CUDA block. 
- const int n_chunks = (head_size * block_size + (num_threads - 1)) / num_threads; - const int physical_cache_offset = (physical_block_number * num_heads + head_idx) * head_size * block_size; - - // The common output pointer base used by both key and value: - scalar_t* common_out = out + (block_idx * block_size) * 3 * num_heads * head_size - + head_idx * head_size; - // key is the second tensor in QKV, so qkv_offset = 1 - scalar_t* key_out = common_out + 1 * num_heads * head_size; - // value is the third tensor in QKV, so qkv_offset = 2 - scalar_t* value_out = common_out + 2 * num_heads * head_size; - - // process key in chunks -#pragma unroll - for (int chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) { - const int offset = chunk_idx * num_threads + thread_idx; - if (offset >= head_size * block_size) { - break; - } - // calculate offsets in [head_size/x, block_size, x] - const int head_offset = offset / x / block_size; - const int block_offset = offset / x % block_size; - const int x_offset = offset % x; - - const scalar_t* k_ptr = k_cache + physical_cache_offset + offset; - scalar_t* out_ptr = key_out + block_offset * 3 * num_heads * head_size - + head_offset * x + x_offset; - *out_ptr = __ldg(k_ptr); - } - - // process value in chunks -#pragma unroll - for (int chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) { - const int offset = chunk_idx * num_threads + thread_idx; - if (offset >= head_size * block_size) { - break; - } - // calculate offsets in [head_size, block_size] - const int head_offset = offset / block_size; - const int block_offset = offset % block_size; - - const scalar_t* v_ptr = v_cache + physical_cache_offset + offset; - scalar_t* out_ptr = value_out + block_offset * 3 * num_heads * head_size + head_offset; - *out_ptr = __ldg(v_ptr); - } -} - // Grid: (num_blocks, block_size). 
template @@ -391,97 +299,4 @@ void gather_cached_kv( x); }); } - -/* -// same group of threads will be working on the same block -void gather_cached_kv( - torch::Tensor& qkv_out, // [cu_seqlens_k[-1], 3(QKV), num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - torch::Tensor& cu_seqlens_k, // aka 'cu_seqlens_k' in '_flash_attn_forward', or 'context_lens' in cacheflow - torch::Tensor& seqlens_k, // CPU version of 'cu_seqlens_k' - torch::Tensor& block_tables) { // [num_seqs, max_num_blocks_per_seq] - const int num_seqs = cu_seqlens_k.size(0) - 1; - const int num_heads = value_cache.size(1); - const int head_size = value_cache.size(2); - const int block_size = value_cache.size(3); - // const int x = key_cache.size(4); - const int max_num_blocks_per_seq = block_tables.size(1); - const int* context_lens_ptr = cu_seqlens_k.data_ptr(); - const int* cpu_context_lens_ptr = seqlens_k.data_ptr(); - - // calculate the total amount of blocks - int num_total_blocks = 0; - for (int i = 0; i < num_seqs; ++i) { - int context_len = cpu_context_lens_ptr[i + 1] - cpu_context_lens_ptr[i]; - int num_blocks = (context_len + block_size - 1) / block_size; - num_total_blocks += num_blocks; - } - - constexpr int NUM_THREADS = 256; - dim3 grid(num_total_blocks, num_heads); - dim3 block(NUM_THREADS); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - key_cache.scalar_type(), - "gather_cached_kv_kernel", - [&] { - cacheflow::gather_cached_kv_kernel<<>>( - qkv_out.data_ptr(), - key_cache.data_ptr(), - value_cache.data_ptr(), - block_tables.data_ptr(), - cu_seqlens_k.data_ptr(), - num_seqs, - max_num_blocks_per_seq, - head_size, - block_size); - }); -} -*/ - - -// // same group of threads will be working on the same block -// void gather_cached_kv( -// torch::Tensor& qkv_out, // [cu_seqlens_k[-1], 3(QKV), num_heads, head_size] -// torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] -// torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] -// torch::Tensor& cu_seqlens_k, // aka 'cu_seqlens_k' in '_flash_attn_forward', or 'context_lens' in cacheflow -// torch::Tensor& seqlens_k, // CPU version of 'cu_seqlens_k' -// torch::Tensor& block_tables) { // [num_seqs, max_num_blocks_per_seq] -// const int num_seqs = cu_seqlens_k.size(0) - 1; -// const int num_heads = value_cache.size(1); -// const int head_size = value_cache.size(2); -// const int block_size = value_cache.size(3); -// // const int x = key_cache.size(4); -// const int max_num_blocks_per_seq = block_tables.size(1); -// const int* context_lens_ptr = cu_seqlens_k.data_ptr(); -// const int* cpu_context_lens_ptr = seqlens_k.data_ptr(); - -// // calculate the total amount of blocks -// int num_total_blocks = 0; -// for (int i = 0; i < num_seqs; ++i) { -// int context_len = cpu_context_lens_ptr[i + 1] - cpu_context_lens_ptr[i]; -// int num_blocks = (context_len + block_size - 1) / block_size; -// num_total_blocks += num_blocks; -// } - -// dim3 grid(num_total_blocks, block_size); -// dim3 block(std::min(num_heads * head_size, 512)); -// const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); -// AT_DISPATCH_FLOATING_TYPES_AND_HALF( -// key_cache.scalar_type(), -// "gather_cached_kv_kernel_2", -// [&] { -// cacheflow::gather_cached_kv_kernel_2<<>>( -// qkv_out.data_ptr(), -// key_cache.data_ptr(), -// value_cache.data_ptr(), -// 
block_tables.data_ptr(), -// cu_seqlens_k.data_ptr(), -// num_seqs, -// max_num_blocks_per_seq, -// num_heads, -// head_size); -// }); -// } + \ No newline at end of file From 288240c8a3ec7b0b31d1962f084d68e4254a5e3f Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 9 Apr 2023 02:11:00 -0700 Subject: [PATCH 07/16] update --- csrc/cache_kernels.cu | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 93e4d25e1a14..0dc09c065c50 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -176,6 +176,49 @@ __global__ void reshape_and_cache_kernel( } } +// Grid: (num_blocks, block_size). +template +__global__ void gather_cached_kv_kernel( + scalar_t* __restrict__ key, // [num_tokens, [stride], num_heads, head_size] + scalar_t* __restrict__ value, // [num_tokens, [stride], num_heads, head_size] + const scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] + const int* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, + const int value_stride, + const int num_heads, + const int head_size, + const int block_size, + const int x) { + const int token_idx = blockIdx.x; + const int slot_idx = slot_mapping[token_idx]; + const int block_idx = slot_idx / block_size; + const int block_offset = slot_idx % block_size; + + const int num_tokens = num_heads * head_size; + for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) { + const int tgt_key_idx = token_idx * key_stride + i; + const int tgt_value_idx = token_idx * value_stride + i; + + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension + const int x_offset = head_offset % x; + + const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + + x_idx * block_size * x + + block_offset * x + + x_offset; + const int src_value_idx = block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + + head_offset * block_size + + block_offset; + + key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]); + value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]); + } +} // Grid: (num_blocks, block_size). template From 127027d8377a26cc7f2fae75d50ec5e55313c522 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 9 Apr 2023 02:25:23 -0700 Subject: [PATCH 08/16] cleanup --- csrc/cache_kernels.cu | 48 ++----------------------------------------- 1 file changed, 2 insertions(+), 46 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 0dc09c065c50..5a59710c44ff 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -220,50 +220,6 @@ __global__ void gather_cached_kv_kernel( } } -// Grid: (num_blocks, block_size). 
-template -__global__ void gather_cached_kv_kernel_2( - scalar_t* __restrict__ key, // [num_tokens, [stride], num_heads, head_size] - scalar_t* __restrict__ value, // [num_tokens, [stride], num_heads, head_size] - const scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] - const int* __restrict__ slot_mapping, // [num_tokens] - const int key_stride, - const int value_stride, - const int num_heads, - const int head_size, - const int block_size, - const int x) { - const int token_idx = blockIdx.x; - const int slot_idx = slot_mapping[token_idx]; - const int block_idx = slot_idx / block_size; - const int block_offset = slot_idx % block_size; - - const int num_tokens = num_heads * head_size; - for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) { - const int tgt_key_idx = token_idx * key_stride + i; - const int tgt_value_idx = token_idx * value_stride + i; - - const int head_idx = i / head_size; - const int head_offset = i % head_size; - const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension - const int x_offset = head_offset % x; - - const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x - + head_idx * (head_size / x) * block_size * x - + x_idx * block_size * x - + block_offset * x - + x_offset; - const int src_value_idx = block_idx * num_heads * head_size * block_size - + head_idx * head_size * block_size - + head_offset * block_size - + block_offset; - - key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]); - value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]); - } -} - } // namespace cacheflow void reshape_and_cache( @@ -326,9 +282,9 @@ void gather_cached_kv( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( key.scalar_type(), - "gather_cached_kv_kernel_2", + "gather_cached_kv_kernel", [&] { - cacheflow::gather_cached_kv_kernel_2<<>>( + cacheflow::gather_cached_kv_kernel<<>>( key.data_ptr(), value.data_ptr(), key_cache.data_ptr(), From 096c04bc8842a84e762af2ca51780feb585d8f7e Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 9 Apr 2023 02:34:06 -0700 Subject: [PATCH 09/16] add benchmark --- benchmark/benchmark_cache.py | 78 ++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 benchmark/benchmark_cache.py diff --git a/benchmark/benchmark_cache.py b/benchmark/benchmark_cache.py new file mode 100644 index 000000000000..285722bb3a0c --- /dev/null +++ b/benchmark/benchmark_cache.py @@ -0,0 +1,78 @@ +import functools +import random +import time + +import torch + +from cacheflow import cache_ops + + +def benchmark(name, f, num_warmup = 10, num_iters = 100): + for _ in range(num_warmup): + f() + torch.cuda.synchronize() + + start = time.time() + for _ in range(num_iters): + f() + torch.cuda.synchronize() + end = time.time() + print(f'{name}: {(end - start) / num_iters * 1000:.3f} ms') + + +@torch.inference_mode() +def test_gather_cached_kv( + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, +) -> None: + print(f'num_tokens: {num_tokens}, num_heads: {num_heads}, ' + f'head_size: {head_size}, block_size: {block_size}, ' + f'num_blocks: {num_blocks}, dtype: {dtype}') + + num_slots = block_size * num_blocks + slot_mapping = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda') + + qkv = torch.randn( + 
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') + _, key, value = qkv.unbind(dim=1) + + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda') + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_cache = torch.randn( + size=value_cache_shape, dtype=dtype, device='cuda') + + # Run Flash attention. + def run_flash_attn(): + cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping) + + benchmark('gather_cached_kv', run_flash_attn) + + +if __name__ == '__main__': + BLOCK_SIZE = 8 + NUM_BLOCKS = 1024 + DTYPE = torch.half + + # LLaMA-13B and OPT-13B + NUM_HEADS = 40 + HEAD_SIZE = 128 + + run_benchmark = functools.partial( + test_gather_cached_kv, + num_heads=NUM_HEADS, + head_size=HEAD_SIZE, + block_size=BLOCK_SIZE, + num_blocks=NUM_BLOCKS, + dtype=DTYPE, + ) + + for i in range(6, 12): + run_benchmark(num_tokens=i ** 2) From b8ac6492d28c3da95792997ef8e93c513e85b300 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 9 Apr 2023 02:41:38 -0700 Subject: [PATCH 10/16] update --- benchmark/benchmark_cache.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/benchmark/benchmark_cache.py b/benchmark/benchmark_cache.py index 285722bb3a0c..2ce33ecac9e6 100644 --- a/benchmark/benchmark_cache.py +++ b/benchmark/benchmark_cache.py @@ -7,7 +7,7 @@ from cacheflow import cache_ops -def benchmark(name, f, num_warmup = 10, num_iters = 100): +def benchmark(name, f, size: int, num_warmup = 10, num_iters = 100): for _ in range(num_warmup): f() torch.cuda.synchronize() @@ -17,7 +17,9 @@ def benchmark(name, f, num_warmup = 10, num_iters = 100): f() torch.cuda.synchronize() end = time.time() - print(f'{name}: {(end - start) / num_iters * 1000:.3f} ms') + avg_time = (end - start) / num_iters + print(f'[Latency] {name}: {avg_time * 1000:.3f} ms') + print(f'[Throughput] {name}: {size / avg_time / 2 ** 30:.3f} GB/s') @torch.inference_mode() @@ -50,10 +52,11 @@ def test_gather_cached_kv( size=value_cache_shape, dtype=dtype, device='cuda') # Run Flash attention. 
- def run_flash_attn(): + def run(): cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping) - benchmark('gather_cached_kv', run_flash_attn) + benchmark('gather_cached_kv', run, + size=block_size * num_blocks * num_heads * head_size * 2 * dtype.element_size()) if __name__ == '__main__': @@ -75,4 +78,4 @@ def run_flash_attn(): ) for i in range(6, 12): - run_benchmark(num_tokens=i ** 2) + run_benchmark(num_tokens=2 ** i) From 278c4b0e67f1102ec3e7528b8d585b17d6948ca2 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 9 Apr 2023 02:45:20 -0700 Subject: [PATCH 11/16] fix --- benchmark/benchmark_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/benchmark_cache.py b/benchmark/benchmark_cache.py index 2ce33ecac9e6..23712633883e 100644 --- a/benchmark/benchmark_cache.py +++ b/benchmark/benchmark_cache.py @@ -56,7 +56,7 @@ def run(): cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping) benchmark('gather_cached_kv', run, - size=block_size * num_blocks * num_heads * head_size * 2 * dtype.element_size()) + size=block_size * num_blocks * num_heads * head_size * 2 * qkv.element_size()) if __name__ == '__main__': From 0e09cc84598a211050654d8efaa12b14f7b5f46b Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 9 Apr 2023 02:47:49 -0700 Subject: [PATCH 12/16] fix --- benchmark/benchmark_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/benchmark_cache.py b/benchmark/benchmark_cache.py index 23712633883e..6163d4815980 100644 --- a/benchmark/benchmark_cache.py +++ b/benchmark/benchmark_cache.py @@ -56,7 +56,7 @@ def run(): cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping) benchmark('gather_cached_kv', run, - size=block_size * num_blocks * num_heads * head_size * 2 * qkv.element_size()) + size=num_tokens * num_heads * head_size * 2 * qkv.element_size()) if __name__ == '__main__': From 72d70532d7291ec05fb8ea24ee5ca8f333767037 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 9 Apr 2023 02:53:24 -0700 Subject: [PATCH 13/16] optimization --- csrc/cache_kernels.cu | 77 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 75 insertions(+), 2 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 5a59710c44ff..ca01532d8868 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -220,6 +220,79 @@ __global__ void gather_cached_kv_kernel( } } +template +__global__ void gather_cached_kv_kernel_optimized( + scalar_t *__restrict__ key, // [num_tokens, [stride], num_heads, head_size] + scalar_t *__restrict__ value, // [num_tokens, [stride], num_heads, head_size] + const scalar_t *__restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] + const int *__restrict__ slot_mapping, // [num_tokens] + const int key_stride, + const int value_stride, + const int num_heads, + const int head_size, + const int block_size, + const int x) +{ + const int token_idx = blockIdx.x; + const int slot_idx = slot_mapping[token_idx]; + const int block_idx = slot_idx / block_size; + const int block_offset = slot_idx % block_size; + + const int num_tokens = num_heads * head_size; + const int unroll_factor = 4; + const int unrolled_num_tokens = num_tokens / unroll_factor; + + for (int i = threadIdx.x; i < unrolled_num_tokens; i += blockDim.x) + { + int tgt_key_indices[unroll_factor]; + int tgt_value_indices[unroll_factor]; + int src_key_indices[unroll_factor]; + 
int src_value_indices[unroll_factor]; + scalar_t keys_to_store[unroll_factor]; + scalar_t values_to_store[unroll_factor]; + + #pragma unroll + for (int j = 0; j < unroll_factor; ++j) + { + int index = i + j * unrolled_num_tokens; + + const int tgt_key_idx = token_idx * key_stride + index; + const int tgt_value_idx = token_idx * value_stride + index; + + const int head_idx = index / head_size; + const int head_offset = index % head_size; + const int x_idx = head_offset / x; + const int x_offset = head_offset % x; + + const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + + x_idx * block_size * x + + block_offset * x + + x_offset; + const int src_value_idx = block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + + head_offset * block_size + + block_offset; + + tgt_key_indices[j] = tgt_key_idx; + tgt_value_indices[j] = tgt_value_idx; + src_key_indices[j] = src_key_idx; + src_value_indices[j] = src_value_idx; + + keys_to_store[j] = __ldg(&key_cache[src_key_idx]); + values_to_store[j] = __ldg(&value_cache[src_value_idx]); + } + + #pragma unroll + for (int j = 0; j < unroll_factor; ++j) + { + key[tgt_key_indices[j]] = keys_to_store[j]; + value[tgt_value_indices[j]] = values_to_store[j]; + } + } +} + } // namespace cacheflow void reshape_and_cache( @@ -282,9 +355,9 @@ void gather_cached_kv( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( key.scalar_type(), - "gather_cached_kv_kernel", + "gather_cached_kv_kernel_optimized", [&] { - cacheflow::gather_cached_kv_kernel<<>>( + cacheflow::gather_cached_kv_kernel_optimized<<>>( key.data_ptr(), value.data_ptr(), key_cache.data_ptr(), From c3a2e87e2a8e6e5fc4602dd8bdefb4300d513177 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 9 Apr 2023 03:11:02 -0700 Subject: [PATCH 14/16] revert changes to benchmarks --- benchmark/benchmark_attention.py | 50 ++-------- tests/kernels/attention.py | 161 ++++++++++--------------------- 2 files changed, 60 insertions(+), 151 deletions(-) diff --git a/benchmark/benchmark_attention.py b/benchmark/benchmark_attention.py index f75a0f6e58a8..ac43ddb36e54 100644 --- a/benchmark/benchmark_attention.py +++ b/benchmark/benchmark_attention.py @@ -6,7 +6,7 @@ from flash_attn.flash_attn_interface import _flash_attn_forward import torch -from cacheflow import attention_ops, cache_ops +from cacheflow import attention_ops def benchmark(name, f, num_warmup = 10, num_iters = 100): @@ -43,7 +43,7 @@ def benchmark_multi_query_cached_kv_attention( num_total_tokens = cu_query_lens[-1] qkv = torch.randn( num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') - query, key, value = qkv.unbind(dim=1) # NOTE: this will not make a copy. + query, _, _ = qkv.unbind(dim=1) # Create key and value cache. x = 16 // torch.tensor([], dtype=dtype).element_size() @@ -72,53 +72,21 @@ def benchmark_multi_query_cached_kv_attention( scale = float(1.0 / (head_size ** 0.5)) output = torch.empty( num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda') - - num_kv_tokens = sum(context_lens) - cu_context_lens = [0] - for context_len in context_lens: - cu_context_lens.append(cu_context_lens[-1] + context_len) - cpu_context_lens = torch.tensor(cu_context_lens, dtype=torch.int, device='cpu') - cu_context_lens = cpu_context_lens.cuda() - ref_output = torch.empty_like(output) # Run our implementation. 
def run_ours(): - cache_ops.gather_cached_kv( - qkv, + attention_ops.multi_query_cached_kv_attention( + cu_query_lens, + output, + query, key_cache, value_cache, - cu_context_lens, - cpu_context_lens, + scale, block_tables, - ) - - _flash_attn_forward( - query, - key, - value, - ref_output, - cu_query_lens, - cu_context_lens, - max(query_lens), + context_len_tensor, + block_size, max_context_len, - dropout_p=0.0, - softmax_scale=scale, - causal=True, - return_softmax=False, ) - - # attention_ops.multi_query_cached_kv_attention( - # cu_query_lens, - # output, - # query, - # key_cache, - # value_cache, - # scale, - # block_tables, - # context_len_tensor, - # block_size, - # max_context_len, - # ) benchmark('Ours', run_ours) # Upper bound: Flash attention. diff --git a/tests/kernels/attention.py b/tests/kernels/attention.py index e63d7f814be3..a66f2c3daca7 100644 --- a/tests/kernels/attention.py +++ b/tests/kernels/attention.py @@ -4,7 +4,7 @@ from flash_attn.flash_attn_interface import _flash_attn_forward import torch -from cacheflow import attention_ops, cache_ops +from cacheflow import attention_ops MAX_SEQ_LEN = 4096 @@ -65,8 +65,7 @@ def ref_single_query_cached_kv_attention( def ref_multi_query_kv_attention( - cu_query_lens: List[int], - cu_context_lens: List[int], + cu_seq_lens: List[int], query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -75,25 +74,21 @@ def ref_multi_query_kv_attention( head_size = query.shape[-1] scale = 1.0 / (head_size ** 0.5) - num_seqs = len(cu_query_lens) - 1 + num_seqs = len(cu_seq_lens) - 1 ref_outputs = [] for i in range(num_seqs): - query_start_idx = cu_query_lens[i] - query_end_idx = cu_query_lens[i + 1] - query_len = query_end_idx - query_start_idx - - context_start_idx = cu_context_lens[i] - context_end_idx = cu_context_lens[i + 1] - context_len = context_end_idx - context_start_idx + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + seq_len = end_idx - start_idx # Create attention mask - attn_mask = torch.triu(torch.ones(query_len, context_len), diagonal=1) * -1e5 + attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5 attn_mask = attn_mask.to(dtype=dtype, device='cuda') ref_output = ref_masked_attention( - query[query_start_idx:query_end_idx], - key[context_start_idx:context_end_idx], - value[context_start_idx:context_end_idx], + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], scale, attn_mask=attn_mask, ) @@ -232,58 +227,42 @@ def test_multi_query_kv_attention( head_size: int, dtype: torch.dtype, ) -> None: - query_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs) - max_query_len = max(query_lens) - num_query_tokens = sum(query_lens) - cu_query_lens = [0] - for seq_len in query_lens: - cu_query_lens.append(cu_query_lens[-1] + seq_len) - cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda') + seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs) + max_seq_len = max(seq_lens) + num_tokens = sum(seq_lens) - context_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs) - max_context_len = max(context_lens) - num_context_tokens = sum(context_lens) - cu_context_lens = [0] - for seq_len in context_lens: - cu_context_lens.append(cu_context_lens[-1] + seq_len) - cu_context_lens = torch.tensor( - cu_context_lens, dtype=torch.int, device='cuda') + cu_seq_lens = [0] + for seq_len in seq_lens: + cu_seq_lens.append(cu_seq_lens[-1] + seq_len) + cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda') scale = float(1.0 / (head_size ** 0.5)) qkv = 
torch.randn( - num_query_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') - query, _, _ = qkv.unbind(dim=1) - qkv = torch.randn( - num_context_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') - _, key, value = qkv.unbind(dim=1) - + num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') # Adjust the range of the values to reduce precision errors. - query = query / (head_size ** 0.5) - key = key / (head_size ** 0.5) - value = value / (head_size ** 0.5) + qkv = qkv / (head_size ** 0.5) + query, key, value = qkv.unbind(dim=1) output = torch.empty( - num_query_tokens, num_heads, head_size, dtype=dtype, device='cuda') + num_tokens, num_heads, head_size, dtype=dtype, device='cuda') _flash_attn_forward( query, key, value, output, - cu_query_lens, - cu_context_lens, - max_query_len, - max_context_len, + cu_seq_lens, + cu_seq_lens, + max_seq_len, + max_seq_len, dropout_p=0.0, softmax_scale=scale, causal=True, return_softmax=False, ) - cu_query_lens = cu_query_lens.cpu().tolist() - cu_context_lens = cu_context_lens.cpu().tolist() + cu_seq_lens = cu_seq_lens.cpu().tolist() ref_output = ref_multi_query_kv_attention( - cu_query_lens, - cu_context_lens, + cu_seq_lens, query, key, value, @@ -308,7 +287,7 @@ def test_multi_query_cached_kv_attention( qkv = torch.randn( num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') - query, key, value = qkv.unbind(dim=1) + query, _, _ = qkv.unbind(dim=1) x = 16 // torch.tensor([], dtype=dtype).element_size() key_block_shape = (num_heads, head_size // x, block_size, x) key_cache = torch.randn( @@ -323,13 +302,7 @@ def test_multi_query_cached_kv_attention( for query_len in query_lens ] max_context_len = max(context_lens) - cu_context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda') - - cu_seqlens_k = [0] - for seq_len in context_lens: - cu_seqlens_k.append(cu_seqlens_k[-1] + seq_len) - cu_seqlens_k = torch.tensor( - cu_seqlens_k, dtype=torch.int, device='cuda') + context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda') max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size block_tables = [] @@ -345,49 +318,17 @@ def test_multi_query_cached_kv_attention( output = torch.empty( num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda') - # attention_ops.multi_query_cached_kv_attention( - # cu_query_lens, - # output, - # query, - # key_cache, - # value_cache, - # scale, - # block_tables, - # context_lens, - # block_size, - # max_context_len, - # ) - - old_key = key.clone() - old_value = value.clone() - - cache_ops.gather_cached_kv( - qkv, + attention_ops.multi_query_cached_kv_attention( + cu_query_lens, + output, + query, key_cache, value_cache, - cu_seqlens_k, - cu_seqlens_k.cpu(), + scale, block_tables, - ) - - # test if key and value are updated - assert not torch.allclose(key, old_key, atol=1e-3, rtol=1e-5) - assert not torch.allclose(value, old_value, atol=1e-3, rtol=1e-5) - - - _flash_attn_forward( - query, - key, - value, - output, - cu_query_lens, - cu_seqlens_k, - num_total_tokens, + context_lens, + block_size, max_context_len, - dropout_p=0.0, - softmax_scale=scale, - causal=True, - return_softmax=False, ) ref_output = ref_multi_query_cached_kv_attention( @@ -396,7 +337,7 @@ def test_multi_query_cached_kv_attention( key_cache, value_cache, block_tables, - cu_context_lens, + context_lens, dtype, ) assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) @@ -408,20 +349,20 @@ def test_attention(seed: int) -> None: # the test fails due to the 
precision issue. Re-run the test if it fails. torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - # for dtype in [torch.half, torch.float]: - # for block_size in [8, 16]: - # for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: - # print(f'Testing single_query_cached_kv_attention with ' - # f'dtype={dtype}, block_size={block_size}, ' - # f'head_size={head_size}') - # test_single_query_cached_kv_attention( - # num_tokens=37, - # num_heads=3, - # head_size=head_size, - # block_size=block_size, - # num_blocks=1024, - # dtype=dtype, - # ) + for dtype in [torch.half, torch.float]: + for block_size in [8, 16]: + for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: + print(f'Testing single_query_cached_kv_attention with ' + f'dtype={dtype}, block_size={block_size}, ' + f'head_size={head_size}') + test_single_query_cached_kv_attention( + num_tokens=37, + num_heads=3, + head_size=head_size, + block_size=block_size, + num_blocks=1024, + dtype=dtype, + ) # NOTE(siyuan): Same as above. Re-run the test if it fails. Also # note that the test is also more likely to fail due to the much From c5725d4cc8a599fbb63bd64e543bd4deff0b9dff Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 9 Apr 2023 03:13:33 -0700 Subject: [PATCH 15/16] update --- csrc/cache_kernels.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index ca01532d8868..786a928a501b 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -371,4 +371,3 @@ void gather_cached_kv( x); }); } - \ No newline at end of file From 70b51aa57be4e37c7215b96cf7b54aa81d21caa1 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Sun, 9 Apr 2023 14:59:45 -0700 Subject: [PATCH 16/16] rename and assert --- csrc/cache_kernels.cu | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 786a928a501b..5f97af254142 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -239,11 +239,12 @@ __global__ void gather_cached_kv_kernel_optimized( const int block_idx = slot_idx / block_size; const int block_offset = slot_idx % block_size; - const int num_tokens = num_heads * head_size; + const int dim = num_heads * head_size; + assert(dim % 4 == 0); // this is true for known use cases const int unroll_factor = 4; - const int unrolled_num_tokens = num_tokens / unroll_factor; + const int unrolled_dim = dim / unroll_factor; - for (int i = threadIdx.x; i < unrolled_num_tokens; i += blockDim.x) + for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x) { int tgt_key_indices[unroll_factor]; int tgt_value_indices[unroll_factor]; @@ -255,7 +256,7 @@ __global__ void gather_cached_kv_kernel_optimized( #pragma unroll for (int j = 0; j < unroll_factor; ++j) { - int index = i + j * unrolled_num_tokens; + int index = i + j * unrolled_dim; const int tgt_key_idx = token_idx * key_stride + index; const int tgt_value_idx = token_idx * value_stride + index;
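        // (reviewer note, not part of the original patch) `index` walks the flattened
        // num_heads * head_size dimension renamed to `dim` above, so tgt_key_idx and
        // tgt_value_idx address element `index` of this token's row in the contiguous
        // key/value outputs, while src_key_idx / src_value_idx (derived from head_idx
        // and head_offset) map the same element into the blocked cache layouts.
        // The assert(dim % 4 == 0) added above documents that dim must be divisible
        // by unroll_factor; otherwise the trailing elements would be silently skipped.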