From b81da8699f2ca660b7265a259507fc045996b637 Mon Sep 17 00:00:00 2001
From: "Allen.Dou"
Date: Tue, 13 Aug 2024 14:08:35 +0800
Subject: [PATCH 1/3] To compare the matmul performance of bfloat16 and fp8.

---
 tests/kernels/test_cutlass.py | 64 +++++++++++++++++++++++++++++++++++
 1 file changed, 64 insertions(+)

diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py
index 33c4ea8dd92e..90e974e997c3 100644
--- a/tests/kernels/test_cutlass.py
+++ b/tests/kernels/test_cutlass.py
@@ -2,6 +2,7 @@
 
 Run `pytest tests/kernels/test_cutlass.py`.
 """
+import time
 from typing import Optional, Type
 
 import pytest
@@ -46,6 +47,16 @@ def baseline_scaled_mm(a: torch.Tensor,
     return output
 
 
+def baseline_torch_mm(a: torch.Tensor,
+                      b: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    output = torch.mm(a, b)
+    if bias is not None:
+        output = output + bias
+
+    return output
+
+
 def cutlass_fp8_gemm_helper(m: int,
                             n: int,
                             k: int,
@@ -77,6 +88,46 @@ def cutlass_fp8_gemm_helper(m: int,
     assert torch.allclose(out, baseline, rtol=1e-2, atol=5e-2)
 
 
+def compare_cutlass_fp8_torch_mm_helper(
+    m: int,
+    n: int,
+    k: int,
+    use_bias: bool,
+    out_dtype: Type[torch.dtype] = torch.bfloat16,
+    device: str = "cuda",
+    display: bool = False,
+):
+    if use_bias:
+        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
+    else:
+        bias = None
+    a_f16 = torch.randn((m, k), dtype=torch.bfloat16, device=device)
+    b_f16 = torch.randn((n, k), dtype=torch.bfloat16, device=device)
+
+    time_s1 = time.time()
+    qa, scale_a = ops.scaled_fp8_quant(a_f16, None)
+    scale_time = time.time() - time_s1
+    qb, scale_b = ops.scaled_fp8_quant(b_f16, None)
+
+    b_f16 = b_f16.t()
+    qb = qb.t()
+
+    time1 = time.time()
+    baseline = baseline_torch_mm(a_f16, b_f16, bias)
+    time2 = time.time()
+    out = ops.cutlass_scaled_mm(qa, qb, scale_a, scale_b, out_dtype, bias)
+    time3 = time.time()
+
+    torch_mm_time = time2 - time1
+    cutlass_mm_time = time3 - time2
+
+    if display:
+        print(f"{torch_mm_time=:.10f}", f"{scale_time=:.10f}",
+              f"{cutlass_mm_time=:.10f}")
+
+    assert baseline.shape == out.shape
+
+
 def cutlass_int8_gemm_helper(m: int,
                              n: int,
                              k: int,
@@ -122,6 +173,19 @@ def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
     cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
 
 
+@pytest.mark.parametrize("m", [4096])
+@pytest.mark.parametrize("n", [4096])
+@pytest.mark.parametrize("k", [4096])
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.skipif(capability < 89,
+                    reason="FP8 is not supported on this GPU type.")
+def test_compare_cutlass_fp8_torch_mm(m: int, n: int, k: int, use_bias: bool):
+    for _ in range(20):
+        # warm up
+        compare_cutlass_fp8_torch_mm_helper(m, n, k, use_bias, display=False)
+    compare_cutlass_fp8_torch_mm_helper(m, n, k, use_bias, display=True)
+
+
 @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 33, 1])
 @pytest.mark.parametrize("n", [2048, 8192, 16384, 256, 1024])
 @pytest.mark.parametrize("k", [128, 496, 1024])

From df75edc9a2ad1d902fa92df1b0b4825a9221f120 Mon Sep 17 00:00:00 2001
From: "Allen.Dou"
Date: Tue, 13 Aug 2024 15:08:48 +0800
Subject: [PATCH 2/3] update.

---
 tests/kernels/test_cutlass.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py
index 90e974e997c3..517ec6ce02d0 100644
--- a/tests/kernels/test_cutlass.py
+++ b/tests/kernels/test_cutlass.py
@@ -103,6 +103,7 @@ def compare_cutlass_fp8_torch_mm_helper(
         bias = None
     a_f16 = torch.randn((m, k), dtype=torch.bfloat16, device=device)
     b_f16 = torch.randn((n, k), dtype=torch.bfloat16, device=device)
+    torch.cuda.synchronize()
 
     time_s1 = time.time()
     qa, scale_a = ops.scaled_fp8_quant(a_f16, None)
@@ -114,8 +115,10 @@ def compare_cutlass_fp8_torch_mm_helper(
 
     time1 = time.time()
     baseline = baseline_torch_mm(a_f16, b_f16, bias)
+    torch.cuda.synchronize()
     time2 = time.time()
     out = ops.cutlass_scaled_mm(qa, qb, scale_a, scale_b, out_dtype, bias)
+    torch.cuda.synchronize()
     time3 = time.time()
 
     torch_mm_time = time2 - time1
@@ -176,7 +179,7 @@ def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
 @pytest.mark.parametrize("m", [4096])
 @pytest.mark.parametrize("n", [4096])
 @pytest.mark.parametrize("k", [4096])
-@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.parametrize("use_bias", [False, True])
 @pytest.mark.skipif(capability < 89,
                     reason="FP8 is not supported on this GPU type.")
 def test_compare_cutlass_fp8_torch_mm(m: int, n: int, k: int, use_bias: bool):

From ae400937dd6a3ec836d7560fa66bffa556fe2579 Mon Sep 17 00:00:00 2001
From: "Allen.Dou"
Date: Tue, 13 Aug 2024 21:35:05 +0800
Subject: [PATCH 3/3] update.

---
 tests/kernels/test_cutlass.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py
index 517ec6ce02d0..e1c972ea476c 100644
--- a/tests/kernels/test_cutlass.py
+++ b/tests/kernels/test_cutlass.py
@@ -125,7 +125,7 @@ def compare_cutlass_fp8_torch_mm_helper(
     cutlass_mm_time = time3 - time2
 
     if display:
-        print(f"{torch_mm_time=:.10f}", f"{scale_time=:.10f}",
+        print(f"{use_bias=}, {torch_mm_time=:.10f}", f"{scale_time=:.10f}",
               f"{cutlass_mm_time=:.10f}")
 
     assert baseline.shape == out.shape
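
For reference, the standalone sketch below reproduces the same bfloat16-vs-FP8 matmul comparison outside of pytest, but timed with CUDA events rather than time.time() so that only device execution time is measured. It assumes the ops.scaled_fp8_quant and ops.cutlass_scaled_mm call signatures used in the patch above; the time_cuda helper, the fixed 4096x4096x4096 shapes, and the bias-free call are illustrative choices, not part of the patch.

# Standalone timing sketch (illustrative, not part of the patch series).
import torch

from vllm import _custom_ops as ops


def time_cuda(fn, iters: int = 20) -> float:
    """Average time of fn() in milliseconds, measured with CUDA events."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn()  # warm-up run so lazy initialization is excluded from the timing
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


def main(m: int = 4096, n: int = 4096, k: int = 4096) -> None:
    a = torch.randn((m, k), dtype=torch.bfloat16, device="cuda")
    b = torch.randn((n, k), dtype=torch.bfloat16, device="cuda")

    # Quantize both operands to fp8 with per-tensor scales.
    qa, scale_a = ops.scaled_fp8_quant(a, None)
    qb, scale_b = ops.scaled_fp8_quant(b, None)

    # bfloat16 baseline: plain torch.mm on the unquantized operands.
    bf16_ms = time_cuda(lambda: torch.mm(a, b.t()))
    # fp8 path: CUTLASS scaled matmul, output cast back to bfloat16, no bias.
    fp8_ms = time_cuda(lambda: ops.cutlass_scaled_mm(
        qa, qb.t(), scale_a, scale_b, torch.bfloat16, None))

    print(f"bf16 torch.mm: {bf16_ms:.4f} ms, "
          f"fp8 cutlass_scaled_mm: {fp8_ms:.4f} ms")


if __name__ == "__main__":
    main()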