Skip to content
135 changes: 76 additions & 59 deletions tests/kernels/test_prefix_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
import math
import random
import time
from collections.abc import Callable

import pytest
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode)
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
Expand All @@ -24,6 +27,8 @@
SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]

OPS = [chunked_prefill_paged_decode, context_attention_fwd]


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
Expand All @@ -32,6 +37,7 @@
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention(
num_heads: int,
Expand All @@ -41,6 +47,7 @@ def test_contexted_kv_attention(
dtype: torch.dtype,
kv_cache_dtype: str,
device: str,
op: Callable,
) -> None:

if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
Expand All @@ -65,6 +72,9 @@ def test_contexted_kv_attention(
block_size = 32
max_block_per_request = 64
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
# ensure one sequence in batch is a decode
query_lens[-1] = 1

ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
num_kv_heads = num_heads // num_queries_per_kv
Expand Down Expand Up @@ -144,36 +154,36 @@ def test_contexted_kv_attention(

# Warm up the Triton kernel by calling it once before actually measuring
# generation time
context_attention_fwd(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window)
op(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window)
torch.cuda.synchronize()
start_time = time.time()
context_attention_fwd(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window)
op(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window)
torch.cuda.synchronize()
end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
Expand Down Expand Up @@ -228,7 +238,7 @@ def test_contexted_kv_attention(
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
output_ref = output_ref.reshape(output.shape)
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)


Expand All @@ -238,6 +248,7 @@ def test_contexted_kv_attention(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_alibi(
num_heads: int,
Expand All @@ -246,6 +257,7 @@ def test_contexted_kv_attention_alibi(
dtype: torch.dtype,
kv_cache_dtype: str,
device: str,
op: Callable,
) -> None:

if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
Expand Down Expand Up @@ -375,36 +387,36 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:

# Warm up the Triton kernel by calling it once before actually measuring
# generation time
context_attention_fwd(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes)
op(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes)
torch.cuda.synchronize()
start_time = time.time()
context_attention_fwd(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes)
op(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes)
torch.cuda.synchronize()
end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
Expand Down Expand Up @@ -503,6 +515,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_f32(
num_heads: int,
Expand All @@ -512,9 +525,11 @@ def test_contexted_kv_attention_f32(
dtype: torch.dtype,
kv_cache_dtype: str,
device: str,
op: Callable,
) -> None:
test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size,
sliding_window, dtype, kv_cache_dtype, device)
sliding_window, dtype, kv_cache_dtype, device,
op)


@pytest.mark.optional
Expand All @@ -524,6 +539,7 @@ def test_contexted_kv_attention_f32(
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_alibi_f32(
num_heads: int,
Expand All @@ -532,6 +548,7 @@ def test_contexted_kv_attention_alibi_f32(
dtype: torch.dtype,
kv_cache_dtype: str,
device: str,
op: Callable,
) -> None:
test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size,
dtype, kv_cache_dtype, device)
dtype, kv_cache_dtype, device, op)
Loading