diff --git a/benchmark/benchmark_cache.py b/benchmark/benchmark_cache.py new file mode 100644 index 000000000000..6163d4815980 --- /dev/null +++ b/benchmark/benchmark_cache.py @@ -0,0 +1,81 @@ +import functools +import random +import time + +import torch + +from cacheflow import cache_ops + + +def benchmark(name, f, size: int, num_warmup = 10, num_iters = 100): + for _ in range(num_warmup): + f() + torch.cuda.synchronize() + + start = time.time() + for _ in range(num_iters): + f() + torch.cuda.synchronize() + end = time.time() + avg_time = (end - start) / num_iters + print(f'[Latency] {name}: {avg_time * 1000:.3f} ms') + print(f'[Throughput] {name}: {size / avg_time / 2 ** 30:.3f} GB/s') + + +@torch.inference_mode() +def test_gather_cached_kv( + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, +) -> None: + print(f'num_tokens: {num_tokens}, num_heads: {num_heads}, ' + f'head_size: {head_size}, block_size: {block_size}, ' + f'num_blocks: {num_blocks}, dtype: {dtype}') + + num_slots = block_size * num_blocks + slot_mapping = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda') + + qkv = torch.randn( + num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') + _, key, value = qkv.unbind(dim=1) + + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda') + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_cache = torch.randn( + size=value_cache_shape, dtype=dtype, device='cuda') + + # Run Flash attention. + def run(): + cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping) + + benchmark('gather_cached_kv', run, + size=num_tokens * num_heads * head_size * 2 * qkv.element_size()) + + +if __name__ == '__main__': + BLOCK_SIZE = 8 + NUM_BLOCKS = 1024 + DTYPE = torch.half + + # LLaMA-13B and OPT-13B + NUM_HEADS = 40 + HEAD_SIZE = 128 + + run_benchmark = functools.partial( + test_gather_cached_kv, + num_heads=NUM_HEADS, + head_size=HEAD_SIZE, + block_size=BLOCK_SIZE, + num_blocks=NUM_BLOCKS, + dtype=DTYPE, + ) + + for i in range(6, 12): + run_benchmark(num_tokens=2 ** i) diff --git a/csrc/cache.cpp b/csrc/cache.cpp index 907736a981c9..9ae17bb2985c 100644 --- a/csrc/cache.cpp +++ b/csrc/cache.cpp @@ -20,6 +20,13 @@ void reshape_and_cache( torch::Tensor& value_cache, torch::Tensor& slot_mapping); +void gather_cached_kv( + torch::Tensor& key, + torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( "swap_blocks", @@ -33,4 +40,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "reshape_and_cache", &reshape_and_cache, "Reshape the key and value tensors and cache them"); + m.def( + "gather_cached_kv", + &gather_cached_kv, + "Gather key and value from the cache into contiguous QKV tensors"); } diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 3a34ba578980..5f97af254142 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -176,6 +176,124 @@ __global__ void reshape_and_cache_kernel( } } +// Grid: (num_blocks, block_size). +template +__global__ void gather_cached_kv_kernel( + scalar_t* __restrict__ key, // [num_tokens, [stride], num_heads, head_size] + scalar_t* __restrict__ value, // [num_tokens, [stride], num_heads, head_size] + const scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] + const int* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, + const int value_stride, + const int num_heads, + const int head_size, + const int block_size, + const int x) { + const int token_idx = blockIdx.x; + const int slot_idx = slot_mapping[token_idx]; + const int block_idx = slot_idx / block_size; + const int block_offset = slot_idx % block_size; + + const int num_tokens = num_heads * head_size; + for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) { + const int tgt_key_idx = token_idx * key_stride + i; + const int tgt_value_idx = token_idx * value_stride + i; + + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension + const int x_offset = head_offset % x; + + const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + + x_idx * block_size * x + + block_offset * x + + x_offset; + const int src_value_idx = block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + + head_offset * block_size + + block_offset; + + key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]); + value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]); + } +} + +template +__global__ void gather_cached_kv_kernel_optimized( + scalar_t *__restrict__ key, // [num_tokens, [stride], num_heads, head_size] + scalar_t *__restrict__ value, // [num_tokens, [stride], num_heads, head_size] + const scalar_t *__restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] + const int *__restrict__ slot_mapping, // [num_tokens] + const int key_stride, + const int value_stride, + const int num_heads, + const int head_size, + const int block_size, + const int x) +{ + const int token_idx = blockIdx.x; + const int slot_idx = slot_mapping[token_idx]; + const int block_idx = slot_idx / block_size; + const int block_offset = slot_idx % block_size; + + const int dim = num_heads * head_size; + assert(dim % 4 == 0); // this is true for known use cases + const int unroll_factor = 4; + const int unrolled_dim = dim / unroll_factor; + + for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x) + { + int tgt_key_indices[unroll_factor]; + int tgt_value_indices[unroll_factor]; + int src_key_indices[unroll_factor]; + int src_value_indices[unroll_factor]; + scalar_t keys_to_store[unroll_factor]; + scalar_t values_to_store[unroll_factor]; + + #pragma unroll + for (int j = 0; j < unroll_factor; ++j) + { + int index = i + j * unrolled_dim; + + const int tgt_key_idx = token_idx * key_stride + index; + const int tgt_value_idx = token_idx * value_stride + index; + + const int head_idx = index / head_size; + const int head_offset = index % head_size; + const int x_idx = head_offset / x; + const int x_offset = head_offset % x; + + const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + + x_idx * block_size * x + + block_offset * x + + x_offset; + const int src_value_idx = block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + + head_offset * block_size + + block_offset; + + tgt_key_indices[j] = tgt_key_idx; + tgt_value_indices[j] = tgt_value_idx; + src_key_indices[j] = src_key_idx; + src_value_indices[j] = src_value_idx; + + keys_to_store[j] = __ldg(&key_cache[src_key_idx]); + values_to_store[j] = __ldg(&value_cache[src_value_idx]); + } + + #pragma unroll + for (int j = 0; j < unroll_factor; ++j) + { + key[tgt_key_indices[j]] = keys_to_store[j]; + value[tgt_value_indices[j]] = values_to_store[j]; + } + } +} + } // namespace cacheflow void reshape_and_cache( @@ -215,3 +333,42 @@ void reshape_and_cache( x); }); } + + +void gather_cached_kv( + torch::Tensor& key, // [out] [num_tokens, num_heads, head_size] + torch::Tensor& value, // [out] [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [in] [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [in] [num_blocks, num_heads, head_size, block_size] + torch::Tensor& slot_mapping) // [in] [num_tokens] +{ + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + key.scalar_type(), + "gather_cached_kv_kernel_optimized", + [&] { + cacheflow::gather_cached_kv_kernel_optimized<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + slot_mapping.data_ptr(), + key_stride, + value_stride, + num_heads, + head_size, + block_size, + x); + }); +} diff --git a/tests/kernels/cache.py b/tests/kernels/cache.py index 89f14cca82a2..f444ac16a49d 100644 --- a/tests/kernels/cache.py +++ b/tests/kernels/cache.py @@ -99,6 +99,47 @@ def test_reshape_and_cache( assert torch.allclose(value_cache, cloned_value_cache) +def test_gather_cached_kv( + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, +) -> None: + num_slots = block_size * num_blocks + slot_mapping = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda') + + qkv = torch.randn( + num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') + _, key, value = qkv.unbind(dim=1) + + qkv_clone = qkv.clone() + _, cloned_key, cloned_value = qkv_clone.unbind(dim=1) + + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda') + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_cache = torch.randn( + size=value_cache_shape, dtype=dtype, device='cuda') + + cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping) + + # Reference implementation. + for i in range(num_tokens): + reshaped_key = cloned_key.reshape(num_tokens, num_heads, head_size // x, x) + block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor') + block_offset = slot_mapping[i] % block_size + reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :] + cloned_value[i] = value_cache[block_idx, :, :, block_offset] + + assert torch.allclose(key, cloned_key) + assert torch.allclose(value, cloned_value) + + @torch.inference_mode() def test_cache() -> None: test_copy_blocks( @@ -107,6 +148,9 @@ def test_cache() -> None: test_reshape_and_cache( num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2, dtype=torch.half) + test_gather_cached_kv( + num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2, + dtype=torch.half) if __name__ == '__main__':