diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index f2c7f2c809e..50eaa92f59b 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -3,6 +3,7 @@ import math import random import time +from collections.abc import Callable import pytest import torch @@ -10,6 +11,8 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask from vllm.attention.backends.xformers import _make_alibi_bias +from vllm.attention.ops.chunked_prefill_paged_decode import ( + chunked_prefill_paged_decode) from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE @@ -24,6 +27,8 @@ SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048] KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"] +OPS = [chunked_prefill_paged_decode, context_attention_fwd] + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) @@ -32,6 +37,7 @@ @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("sliding_window", SLIDING_WINDOW) +@pytest.mark.parametrize("op", OPS) @torch.inference_mode() def test_contexted_kv_attention( num_heads: int, @@ -41,6 +47,7 @@ def test_contexted_kv_attention( dtype: torch.dtype, kv_cache_dtype: str, device: str, + op: Callable, ) -> None: if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability( @@ -65,6 +72,9 @@ def test_contexted_kv_attention( block_size = 32 max_block_per_request = 64 query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] + # ensure one sequence in batch is a decode + query_lens[-1] = 1 + ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] num_kv_heads = num_heads // num_queries_per_kv @@ -144,36 +154,36 @@ def test_contexted_kv_attention( # Warm up the Triton kernel by calling it once before actually measuring # generation time - context_attention_fwd(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - max_input_len, - k_scale, - v_scale, - sliding_window=sliding_window) + op(query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + max_input_len, + k_scale, + v_scale, + sliding_window=sliding_window) torch.cuda.synchronize() start_time = time.time() - context_attention_fwd(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - max_input_len, - k_scale, - v_scale, - sliding_window=sliding_window) + op(query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + max_input_len, + k_scale, + v_scale, + sliding_window=sliding_window) torch.cuda.synchronize() end_time = time.time() print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") @@ -228,7 +238,7 @@ def test_contexted_kv_attention( end_time = time.time() print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") output_ref = output_ref.reshape(output.shape) - atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6 + atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4 torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) @@ -238,6 +248,7 @@ def test_contexted_kv_attention( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("op", OPS) @torch.inference_mode() def test_contexted_kv_attention_alibi( num_heads: int, @@ -246,6 +257,7 @@ def test_contexted_kv_attention_alibi( dtype: torch.dtype, kv_cache_dtype: str, device: str, + op: Callable, ) -> None: if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability( @@ -375,36 +387,36 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: # Warm up the Triton kernel by calling it once before actually measuring # generation time - context_attention_fwd(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - max_input_len, - k_scale, - v_scale, - alibi_slopes=alibi_slopes) + op(query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + max_input_len, + k_scale, + v_scale, + alibi_slopes=alibi_slopes) torch.cuda.synchronize() start_time = time.time() - context_attention_fwd(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - max_input_len, - k_scale, - v_scale, - alibi_slopes=alibi_slopes) + op(query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + max_input_len, + k_scale, + v_scale, + alibi_slopes=alibi_slopes) torch.cuda.synchronize() end_time = time.time() print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") @@ -503,6 +515,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("sliding_window", SLIDING_WINDOW) +@pytest.mark.parametrize("op", OPS) @torch.inference_mode() def test_contexted_kv_attention_f32( num_heads: int, @@ -512,9 +525,11 @@ def test_contexted_kv_attention_f32( dtype: torch.dtype, kv_cache_dtype: str, device: str, + op: Callable, ) -> None: test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size, - sliding_window, dtype, kv_cache_dtype, device) + sliding_window, dtype, kv_cache_dtype, device, + op) @pytest.mark.optional @@ -524,6 +539,7 @@ def test_contexted_kv_attention_f32( @pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("op", OPS) @torch.inference_mode() def test_contexted_kv_attention_alibi_f32( num_heads: int, @@ -532,6 +548,7 @@ def test_contexted_kv_attention_alibi_f32( dtype: torch.dtype, kv_cache_dtype: str, device: str, + op: Callable, ) -> None: test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size, - dtype, kv_cache_dtype, device) + dtype, kv_cache_dtype, device, op) diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py new file mode 100644 index 00000000000..807a270b43d --- /dev/null +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -0,0 +1,289 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + +from .prefix_prefill import context_attention_fwd + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def kernel_paged_attention_2d( + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.constexpr, # int + query_stride_0: tl.constexpr, # int + query_stride_1: tl.constexpr, # int, should be equal to head_size + output_stride_0: tl.constexpr, # int + output_stride_1: tl.constexpr, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + x: tl.constexpr, # int + stride_k_cache_0: tl.constexpr, # int + stride_k_cache_1: tl.constexpr, # int + stride_k_cache_2: tl.constexpr, # int + stride_k_cache_3: tl.constexpr, # int + stride_k_cache_4: tl.constexpr, # int + stride_v_cache_0: tl.constexpr, # int + stride_v_cache_1: tl.constexpr, # int + stride_v_cache_2: tl.constexpr, # int + stride_v_cache_3: tl.constexpr, # int + filter_by_query_len: tl.constexpr, # bool + query_start_len_ptr, # [num_seqs+1] +): + seq_idx = tl.program_id(0) + query_head_idx = tl.program_id(1) + kv_head_idx = query_head_idx // num_queries_per_kv + + if filter_by_query_len: + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + + 1) + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index + if cur_batch_query_len > 1: + return + else: + cur_batch_in_all_start_index = seq_idx + + query_offset = (cur_batch_in_all_start_index * query_stride_0 + + query_head_idx * query_stride_1) + + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, + 0).to(tl.int1) + + # Q : (HEAD_SIZE,) + Q = tl.load( + query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED), + mask=dim_mask, + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([1], float("-inf"), dtype=tl.float32) + L = tl.full([1], 1.0, dtype=tl.float32) + acc = tl.zeros([HEAD_SIZE_PADDED], dtype=tl.float32) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx) + + num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) + + # iterate through tiles + for j in range(0, num_blocks): + + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + + offs_n = tl.arange(0, BLOCK_SIZE) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + + v_offset = (physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_1 + + offs_d[:, None] * stride_v_cache_2 + + offs_n[None, :] * stride_v_cache_3) + + k_offset = (physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_1 + + (offs_d[:, None] // x) * stride_k_cache_2 + + offs_n[None, :] * stride_k_cache_3 + + (offs_d[:, None] % x) * stride_k_cache_4) + + # K : (HEAD_SIZE, BLOCK_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None], + other=0.0) + + if K_load.dtype.is_fp8(): + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (HEAD_SIZE, BLOCK_SIZE) + V_load = tl.load(value_cache_ptr + v_offset, + mask=dim_mask[:, None], + other=0.0) + + if V_load.dtype.is_fp8(): + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + tmp = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32) + mask_new = tmp < boundary + # S : (BLOCK_SIZE,) + S = tl.where(mask_new, 0.0, float("-inf")).to(tl.float32) + S += scale * tl.sum(K * Q[:, None], axis=0) + + if SLIDING_WINDOW > 0: + S = tl.where((seq_len - 1 - tmp) < SLIDING_WINDOW, S, -10000) + + if USE_ALIBI_SLOPES: + S += alibi_slope * (tmp - seq_len + 1) + + # compute running maximum + # m_j : (1,) + m_j = tl.maximum(M, tl.max(S, axis=0)) + + # P : (BLOCK_SIZE,) + P = tl.exp(S - m_j) + + # l_j : (1,) + l_j = tl.sum(P, axis=0) + + # alpha : (1, ) + alpha = tl.exp(M - m_j) + + # acc : (BLOCK_SIZE,) + acc = acc * alpha + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (BLOCK_SIZE,) + acc += tl.sum(V * P[None, :], axis=1) + + # epilogue + acc = acc / L + + output_offset = (cur_batch_in_all_start_index * output_stride_0 + + query_head_idx * output_stride_1) + + tl.store(output_ptr + output_offset + tl.arange(0, HEAD_SIZE_PADDED), + acc, + mask=dim_mask) + + +def chunked_prefill_paged_decode( + query, + key, + value, + output, + kv_cache_dtype, + key_cache, + value_cache, + block_table, + query_start_loc, + seq_lens, + max_query_len, + k_scale, + v_scale, + alibi_slopes=None, + sliding_window=None, + sm_scale=None, +): + + if sm_scale is None: + sm_scale = 1.0 / (query.shape[1]**0.5) + + use_alibi_slopes = alibi_slopes is not None + + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + + if max_query_len > 1: + context_attention_fwd( + q=query, + k=key, + v=value, + o=output, + kv_cache_dtype=kv_cache_dtype, + k_cache=key_cache, + v_cache=value_cache, + b_loc=block_table, + b_start_loc=query_start_loc, + b_seq_len=seq_lens, + max_input_len=max_query_len, + k_scale=k_scale, + v_scale=v_scale, + alibi_slopes=alibi_slopes, + sliding_window=sliding_window, + sm_scale=sm_scale, + skip_decode=True, + ) + + block_size = value_cache.shape[3] + num_seqs = len(seq_lens) + num_query_heads = query.shape[1] + num_queries_per_kv = query.shape[1] // key.shape[1] + head_size = query.shape[2] + + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert key_cache.dtype == torch.uint8 + assert value_cache.dtype == torch.uint8 + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + key_cache = key_cache.view(target_dtype) + value_cache = value_cache.view(target_dtype) + + kernel_paged_attention_2d[( + num_seqs, + num_query_heads, + )]( + output_ptr=output, + query_ptr=query, + key_cache_ptr=key_cache, + value_cache_ptr=value_cache, + block_tables_ptr=block_table, + seq_lens_ptr=seq_lens, + alibi_slopes_ptr=alibi_slopes, + scale=sm_scale, + k_scale=k_scale, + v_scale=v_scale, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=query.stride(0), + query_stride_1=query.stride(1), + output_stride_0=output.stride(0), + output_stride_1=output.stride(1), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + SLIDING_WINDOW=sliding_window, + x=key_cache.shape[4], + stride_k_cache_0=key_cache.stride(0), + stride_k_cache_1=key_cache.stride(1), + stride_k_cache_2=key_cache.stride(2), + stride_k_cache_3=key_cache.stride(3), + stride_k_cache_4=key_cache.stride(4), + stride_v_cache_0=value_cache.stride(0), + stride_v_cache_1=value_cache.stride(1), + stride_v_cache_2=value_cache.stride(2), + stride_v_cache_3=value_cache.stride(3), + filter_by_query_len=True, + query_start_len_ptr=query_start_loc, + ) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 103c408ebbf..e85ec605ad2 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -64,7 +64,9 @@ def _fwd_kernel( BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 BLOCK_N: tl.constexpr, SLIDING_WINDOW: tl.constexpr, + SKIP_DECODE: tl.constexpr, ): + cur_batch = tl.program_id(0) cur_head = tl.program_id(1) start_m = tl.program_id(2) @@ -78,6 +80,9 @@ def _fwd_kernel( cur_batch_in_all_start_index) cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + if SKIP_DECODE and cur_batch_query_len == 1: + return + # start position inside of the query # generally, N goes over kv, while M goes over query_len block_start_loc = BLOCK_M * start_m @@ -500,6 +505,7 @@ def _fwd_kernel_alibi( BLOCK_DMODEL: tl.constexpr, # head size BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 BLOCK_N: tl.constexpr, + SKIP_DECODE: tl.constexpr, ): # attn_bias[] cur_batch = tl.program_id(0) @@ -518,6 +524,9 @@ def _fwd_kernel_alibi( cur_batch_in_all_start_index) cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + if SKIP_DECODE and cur_batch_query_len == 1: + return + block_start_loc = BLOCK_M * start_m # initialize offsets @@ -721,7 +730,8 @@ def context_attention_fwd(q, v_scale: torch.Tensor, alibi_slopes=None, sliding_window=None, - sm_scale=None): + sm_scale=None, + skip_decode=False): q_dtype_is_f32 = q.dtype is torch.float32 # need to reduce num. blocks when using fp32 @@ -823,6 +833,7 @@ def context_attention_fwd(q, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, + SKIP_DECODE=skip_decode, num_warps=NUM_WARPS, num_stages=1, ) @@ -875,6 +886,7 @@ def context_attention_fwd(q, BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, SLIDING_WINDOW=sliding_window, + SKIP_DECODE=skip_decode, num_warps=NUM_WARPS, num_stages=1, ) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index a625d99f4a1..640c3b3d4fb 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -6,8 +6,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.ops.chunked_prefill_paged_decode import ( + chunked_prefill_paged_decode) from vllm.attention.ops.paged_attn import PagedAttention -from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import ( FlashAttentionMetadata, FlashAttentionMetadataBuilder) @@ -156,20 +157,22 @@ def forward( ) # Compute attention and update output up to `num_actual_tokens`. - context_attention_fwd(q=query[:num_actual_tokens], - k=key[:num_actual_tokens], - v=value[:num_actual_tokens], - o=output[:num_actual_tokens], - kv_cache_dtype=self.kv_cache_dtype, - k_cache=key_cache, - v_cache=value_cache, - b_loc=attn_metadata.block_table, - b_start_loc=attn_metadata.query_start_loc, - b_seq_len=attn_metadata.seq_lens, - max_input_len=attn_metadata.max_query_len, - k_scale=layer._k_scale, - v_scale=layer._v_scale, - alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window[0], - sm_scale=self.scale) + chunked_prefill_paged_decode( + query=query[:num_actual_tokens], + key=key[:num_actual_tokens], + value=value[:num_actual_tokens], + output=output[:num_actual_tokens], + kv_cache_dtype=self.kv_cache_dtype, + key_cache=key_cache, + value_cache=value_cache, + block_table=attn_metadata.block_table, + query_start_loc=attn_metadata.query_start_loc, + seq_lens=attn_metadata.seq_lens, + max_query_len=attn_metadata.max_query_len, + k_scale=layer._k_scale, + v_scale=layer._v_scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window[0], + sm_scale=self.scale) + return output