@@ -10,7 +10,7 @@ def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:


@torch.inference_mode()
def test_silu_and_mul(
def run_silu_and_mul(
num_tokens: int,
d: int,
dtype: torch.dtype,
@@ -22,9 +22,9 @@ def test_silu_and_mul(
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


if __name__ == '__main__':
def test_silu_and_mul() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
for num_tokens in [7, 83, 2048]:
for d in [512, 4096, 13824]:
for d in [512, 4096, 5120, 13824]:
print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
test_silu_and_mul(num_tokens, d, dtype)
run_silu_and_mul(num_tokens, d, dtype)
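
Pieced together from the hunks above, the activation test now reads roughly as follows. This is a hedged sketch of tests/kernels/test_activation.py, not a verbatim copy: the body of run_silu_and_mul and the activation_ops import are not visible in this diff and are assumed here.

import torch
import torch.nn.functional as F

from cacheflow import activation_ops  # assumed; the import sits outside the hunks shown above


def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Reference: SiLU on the first half of the last dimension, multiplied by the second half.
    x1, x2 = x.chunk(chunks=2, dim=1)
    return F.silu(x1) * x2


@torch.inference_mode()
def run_silu_and_mul(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
) -> None:
    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    activation_ops.silu_and_mul(out, x)  # kernel under test; call signature assumed, not shown in the hunks
    ref_out = ref_silu_and_mul(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


def test_silu_and_mul() -> None:
    # Argument-free entry point a test runner can collect; sweeps dtypes and shapes,
    # including the newly added d=5120 case.
    for dtype in [torch.half, torch.bfloat16, torch.float]:
        for num_tokens in [7, 83, 2048]:
            for d in [512, 4096, 5120, 13824]:
                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
                run_silu_and_mul(num_tokens, d, dtype)
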
28 changes: 14 additions & 14 deletions tests/kernels/attention.py → tests/kernels/test_attention.py
@@ -8,6 +8,7 @@
from cacheflow import attention_ops

MAX_SEQ_LEN = 4096
TEST_SEED = 0


def ref_masked_attention(
@@ -155,7 +156,8 @@ def ref_multi_query_cached_kv_attention(
return ref_output


def test_single_query_cached_kv_attention(
@torch.inference_mode()
def run_single_query_cached_kv_attention(
num_tokens: int,
num_heads: int,
head_size: int,
@@ -223,7 +225,8 @@ def test_single_query_cached_kv_attention(
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


def test_multi_query_kv_attention(
@torch.inference_mode()
def run_multi_query_kv_attention(
num_seqs: int,
num_heads: int,
head_size: int,
@@ -264,19 +267,16 @@ def test_multi_query_kv_attention(
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


@torch.inference_mode()
def test_attention(seed: int) -> None:
# NOTE(woosuk): Even when the seed is fixed, there is a chance that
# the test fails due to the precision issue. Re-run the test if it fails.
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
def test_single_query_cached_kv_attention() -> None:
torch.random.manual_seed(TEST_SEED)
torch.cuda.manual_seed(TEST_SEED)
for dtype in [torch.half, torch.bfloat16]:
for block_size in [8, 16, 32, 64]:
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
print(f'Testing single_query_cached_kv_attention with '
f'dtype={dtype}, block_size={block_size}, '
f'head_size={head_size}')
test_single_query_cached_kv_attention(
run_single_query_cached_kv_attention(
num_tokens=37,
num_heads=3,
head_size=head_size,
@@ -285,17 +285,17 @@ def test_attention(seed: int) -> None:
dtype=dtype,
)


def test_multi_query_kv_attention() -> None:
torch.random.manual_seed(TEST_SEED)
torch.cuda.manual_seed(TEST_SEED)
for dtype in [torch.half, torch.bfloat16]:
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
print(f'Testing multi_query_kv_attention with dtype={dtype}, '
f'head_size={head_size}')
test_multi_query_kv_attention(
run_multi_query_kv_attention(
num_seqs=5,
num_heads=3,
head_size=head_size,
dtype=dtype,
)


if __name__ == '__main__':
test_attention(seed=0)
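
With plain for-loops inside the new test_* functions, one failing (dtype, block_size, head_size) combination aborts the whole sweep. An alternative the PR does not take, sketched here only as a design comparison, is to lift the loops into pytest.mark.parametrize so each combination is reported as its own case. TEST_SEED and run_single_query_cached_kv_attention are the names defined in tests/kernels/test_attention.py above.

import pytest
import torch


@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("block_size", [8, 16, 32, 64])
@pytest.mark.parametrize("head_size", [32, 64, 80, 96, 128, 160, 192, 256])
def test_single_query_cached_kv_attention(
        dtype: torch.dtype, block_size: int, head_size: int) -> None:
    # Same fixed seed as the hand-rolled loops, so inputs stay reproducible per case.
    torch.random.manual_seed(TEST_SEED)
    torch.cuda.manual_seed(TEST_SEED)
    run_single_query_cached_kv_attention(
        num_tokens=37,
        num_heads=3,
        head_size=head_size,
        block_size=block_size,
        num_blocks=1024,
        dtype=dtype,
    )
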
30 changes: 18 additions & 12 deletions tests/kernels/cache.py → tests/kernels/test_cache.py
@@ -5,7 +5,8 @@
from cacheflow import cache_ops


def test_copy_blocks(
@torch.inference_mode()
def run_copy_blocks(
num_mappings: int,
num_layers: int,
num_heads: int,
@@ -60,7 +61,8 @@ def test_copy_blocks(
assert torch.allclose(value_cache, cloned_value_cache)


def test_reshape_and_cache(
@torch.inference_mode()
def run_reshape_and_cache(
num_tokens: int,
num_heads: int,
head_size: int,
@@ -99,7 +101,8 @@ def test_reshape_and_cache(
assert torch.allclose(value_cache, cloned_value_cache)


def test_gather_cached_kv(
@torch.inference_mode()
def run_gather_cached_kv(
num_tokens: int,
num_heads: int,
head_size: int,
@@ -140,19 +143,22 @@ def test_gather_cached_kv(
assert torch.allclose(value, cloned_value)


@torch.inference_mode()
def test_cache() -> None:
def test_copy_blocks() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
test_copy_blocks(
run_copy_blocks(
num_mappings=23, num_layers=7, num_heads=17, head_size=16,
block_size=8, num_blocks=1024, dtype=dtype)
test_reshape_and_cache(
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
dtype=dtype)
test_gather_cached_kv(


def test_reshape_and_cache() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
run_reshape_and_cache(
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
dtype=dtype)


if __name__ == '__main__':
test_cache()
def test_gather_cached_kv() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
run_gather_cached_kv(
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
dtype=dtype)
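
Across these files the @torch.inference_mode() decorator moves from the single umbrella test onto each run_* helper, so the kernel checks stay free of autograd bookkeeping no matter which test_* entry point calls them. A minimal, self-contained illustration of what the decorator does (unrelated to the cache kernels themselves):

import torch


@torch.inference_mode()
def double(x: torch.Tensor) -> torch.Tensor:
    # Inside inference_mode no autograd graph is recorded and the result
    # is an inference tensor, which keeps these kernel tests lightweight.
    return x * 2


y = double(torch.ones(3, requires_grad=True))
print(y.requires_grad)        # False
print(torch.is_inference(y))  # True
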
@@ -22,7 +22,7 @@ def forward(self, hidden_states):


@torch.inference_mode()
def test_rms_norm(
def run_rms_norm(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
@@ -41,13 +41,13 @@ def test_rms_norm(
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)


if __name__ == '__main__':
def test_rms_norm() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
for num_tokens in [7, 128, 2048]:
for hidden_size in [13, 64, 1024, 5120]:
print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
f'{num_tokens}, hidden_size={hidden_size}')
test_rms_norm(
run_rms_norm(
num_tokens=num_tokens,
hidden_size=hidden_size,
dtype=dtype,
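
The hunk context above exposes only the forward(self, hidden_states) signature of the reference module that run_rms_norm compares the kernel against. For readers unfamiliar with RMSNorm, a standard formulation is sketched below; the class name and epsilon handling are assumptions, since the file's own definition is not part of this diff.

import torch
import torch.nn as nn


class RefRMSNorm(nn.Module):  # hypothetical name; the actual class in the file is not shown
    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # RMSNorm: rescale by the reciprocal root-mean-square of the features,
        # then apply the learned per-feature gain (no mean subtraction, no bias).
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


x = torch.randn(4, 8)
print(RefRMSNorm(hidden_size=8)(x).shape)  # torch.Size([4, 8])
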
@@ -76,7 +76,7 @@ def forward(


@torch.inference_mode()
def test_rotary_embedding_neox(
def run_rotary_embedding_neox(
num_tokens: int,
num_heads: int,
head_size: int,
@@ -128,15 +128,15 @@ def test_rotary_embedding_neox(
assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)


if __name__ == '__main__':
def test_rotary_embedding_neox() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
print(f'Running tests for head_size={head_size} and dtype={dtype}')
test_rotary_embedding_neox(
run_rotary_embedding_neox(
num_tokens=2145,
num_heads=5,
head_size=head_size,
max_position=8192,
rotary_dim=int(head_size * 0.25),
rotary_dim=head_size,
dtype=dtype,
)
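
The last hunk also changes the sweep itself: rotary_dim goes from int(head_size * 0.25) (a partial rotary dimension) to head_size, so the NeoX-style kernel is now exercised with rotation applied across the entire head. If the kernel still accepts partial rotary dimensions after this PR (the diff alone does not say), a hedged sketch for keeping both configurations covered could look like this; the rotary_fraction axis is hypothetical and not part of the change, and torch plus run_rotary_embedding_neox come from the file above.

def test_rotary_embedding_neox() -> None:
    for dtype in [torch.half, torch.bfloat16, torch.float]:
        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
            for rotary_fraction in [0.25, 1.0]:  # hypothetical extra axis, not in the PR
                rotary_dim = int(head_size * rotary_fraction)
                print(f'Running tests for head_size={head_size}, dtype={dtype}, '
                      f'rotary_dim={rotary_dim}')
                run_rotary_embedding_neox(
                    num_tokens=2145,
                    num_heads=5,
                    head_size=head_size,
                    max_position=8192,
                    rotary_dim=rotary_dim,
                    dtype=dtype,
                )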