67 changes: 67 additions & 0 deletions tests/kernels/test_cutlass.py
@@ -2,6 +2,7 @@

Run `pytest tests/kernels/test_cutlass.py`.
"""
import time
from typing import Optional, Type

import pytest
@@ -46,6 +47,16 @@ def baseline_scaled_mm(a: torch.Tensor,
    return output


def baseline_torch_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Unquantized torch.mm reference used for the timing comparison below.
    output = torch.mm(a, b)
    if bias is not None:
        output = output + bias

    return output


def cutlass_fp8_gemm_helper(m: int,
                            n: int,
                            k: int,
@@ -77,6 +88,49 @@ def cutlass_fp8_gemm_helper(m: int,
    assert torch.allclose(out, baseline, rtol=1e-2, atol=5e-2)


def compare_cutlass_fp8_torch_mm_helper(
        m: int,
        n: int,
        k: int,
        use_bias: bool,
        out_dtype: Type[torch.dtype] = torch.bfloat16,
        device: str = "cuda",
        display: bool = False,
):
    if use_bias:
        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
    else:
        bias = None
    a_f16 = torch.randn((m, k), dtype=torch.bfloat16, device=device)
    b_f16 = torch.randn((n, k), dtype=torch.bfloat16, device=device)
    torch.cuda.synchronize()

    # Time the per-tensor FP8 quantization of the activation matrix a.
    time_s1 = time.time()
    qa, scale_a = ops.scaled_fp8_quant(a_f16, None)
    torch.cuda.synchronize()
    scale_time = time.time() - time_s1
    qb, scale_b = ops.scaled_fp8_quant(b_f16, None)

    # cutlass_scaled_mm expects its second operand in column-major layout.
    b_f16 = b_f16.t()
    qb = qb.t()
    torch.cuda.synchronize()

    # Time the unquantized torch.mm baseline against the CUTLASS FP8 GEMM.
    time1 = time.time()
    baseline = baseline_torch_mm(a_f16, b_f16, bias)
    torch.cuda.synchronize()
    time2 = time.time()
    out = ops.cutlass_scaled_mm(qa, qb, scale_a, scale_b, out_dtype, bias)
    torch.cuda.synchronize()
    time3 = time.time()

    torch_mm_time = time2 - time1
    cutlass_mm_time = time3 - time2

    if display:
        print(f"{use_bias=}, {torch_mm_time=:.10f}", f"{scale_time=:.10f}",
              f"{cutlass_mm_time=:.10f}")

    assert baseline.shape == out.shape

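
# A minimal sketch (not part of this change) of CUDA-event-based timing, an
# alternative to time.time() plus torch.cuda.synchronize(); the helper name
# _cuda_event_time_ms is illustrative.
def _cuda_event_time_ms(fn) -> float:
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    fn()
    end.record()
    torch.cuda.synchronize()
    # elapsed_time() reports the on-device milliseconds between the two events.
    return start.elapsed_time(end)

# Usage, e.g.: _cuda_event_time_ms(
#     lambda: ops.cutlass_scaled_mm(qa, qb, scale_a, scale_b, out_dtype, bias))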

def cutlass_int8_gemm_helper(m: int,
                             n: int,
                             k: int,
@@ -122,6 +176,19 @@ def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)


@pytest.mark.parametrize("m", [4096])
@pytest.mark.parametrize("n", [4096])
@pytest.mark.parametrize("k", [4096])
@pytest.mark.parametrize("use_bias", [False, True])
@pytest.mark.skipif(capability < 89,
                    reason="FP8 is not supported on this GPU type.")
def test_compare_cutlass_fp8_torch_mm(m: int, n: int, k: int, use_bias: bool):
    for _ in range(20):
        # warm up
        compare_cutlass_fp8_torch_mm_helper(m, n, k, use_bias, display=False)
    compare_cutlass_fp8_torch_mm_helper(m, n, k, use_bias, display=True)

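# The timings printed by display=True are suppressed by pytest's default output
# capture; running with `pytest tests/kernels/test_cutlass.py -k
# test_compare_cutlass_fp8_torch_mm -s` (the -s flag disables capture) should
# make them visible.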

@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 8192, 16384, 256, 1024])
@pytest.mark.parametrize("k", [128, 496, 1024])