From 5675c6bb16badb3b2624beaf490452c32a587aa2 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Thu, 2 Jan 2025 17:50:28 -0600 Subject: [PATCH 1/3] Change default block size for triton_scaled_mm to 128 for 4-5x speedup Signed-off-by: Randall Smith --- .../quantization/compressed_tensors/triton_scaled_mm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py index 3ff162170f25..2bc90831894b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py @@ -126,9 +126,9 @@ def triton_scaled_mm(input: torch.Tensor, scale_b: torch.Tensor, out_dtype: Type[torch.dtype], bias: Optional[torch.Tensor] = None, - block_size_m: int = 32, - block_size_n: int = 32, - block_size_k: int = 32) -> torch.Tensor: + block_size_m: int = 128, + block_size_n: int = 128, + block_size_k: int = 128) -> torch.Tensor: M, K = input.shape N = weight.shape[1] From a45f569e4208f7c3a05ae2a3313f9a7fc861d4fd Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Fri, 3 Jan 2025 16:41:39 -0600 Subject: [PATCH 2/3] Use heuristic based on cutlass_gemm_sm90_int8_dispatch Signed-off-by: Randall Smith --- .../compressed_tensors/triton_scaled_mm.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py index 2bc90831894b..33b637100d87 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py @@ -126,9 +126,9 @@ def triton_scaled_mm(input: torch.Tensor, scale_b: torch.Tensor, out_dtype: Type[torch.dtype], bias:
Optional[torch.Tensor] = None, - block_size_m: int = 128, - block_size_n: int = 128, - block_size_k: int = 128) -> torch.Tensor: + block_size_m: int = 32, + block_size_n: int = 32, + block_size_k: int = 32) -> torch.Tensor: M, K = input.shape N = weight.shape[1] @@ -152,6 +152,22 @@ def triton_scaled_mm(input: torch.Tensor, has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1 + is_small_N = N < 8192 + next_power_of_2_M = max(32, triton.next_power_of_2(M)) + if next_power_of_2_M <= 32: + if is_small_N: + shape = (64, 64, 256) + else: + shape = (64, 128, 256) + elif next_power_of_2_M <= 64: + shape = (64, 64, 256) + elif next_power_of_2_M <= 128: + shape = (64, 128, 128) + else: + shape = (128, 128, 128) + + block_size_m, block_size_n, block_size_k = shape + block_size_sa = 1 if has_scalar(scale_a) else block_size_m block_size_sb = 1 if has_scalar(scale_b) else block_size_n From eb8126ef5a7c45c9e24d9d9570243c81eb75e1e7 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Tue, 7 Jan 2025 15:45:18 -0600 Subject: [PATCH 3/3] Use heuristic to pick block size for better performance across input/output/batch sizes Signed-off-by: Randall Smith --- .../compressed_tensors/triton_scaled_mm.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py index 33b637100d87..2659afcdc74a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py @@ -128,7 +128,8 @@ def triton_scaled_mm(input: torch.Tensor, bias: Optional[torch.Tensor] = None, block_size_m: int = 32, block_size_n: int = 32, - block_size_k: int = 32) -> torch.Tensor: + block_size_k: int = 32, + use_heuristic=True) -> torch.Tensor: M, K = input.shape N = weight.shape[1] @@ -152,21 +153,19 @@ def 
triton_scaled_mm(input: torch.Tensor, has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1 - is_small_N = N < 8192 - next_power_of_2_M = max(32, triton.next_power_of_2(M)) - if next_power_of_2_M <= 32: - if is_small_N: - shape = (64, 64, 256) + if use_heuristic: + is_small_N = N < 8192 + next_power_of_2_M = max(32, triton.next_power_of_2(M)) + if next_power_of_2_M <= 32: + tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256) + elif next_power_of_2_M <= 64: + tile_shape = (64, 64, 256) + elif next_power_of_2_M <= 128: + tile_shape = (64, 128, 128) else: - shape = (64, 128, 256) - elif next_power_of_2_M <= 64: - shape = (64, 64, 256) - elif next_power_of_2_M <= 128: - shape = (64, 128, 128) - else: - shape = (128, 128, 128) - - block_size_m, block_size_n, block_size_k = shape + tile_shape = (128, 128, 128) + + block_size_m, block_size_n, block_size_k = tile_shape block_size_sa = 1 if has_scalar(scale_a) else block_size_m block_size_sb = 1 if has_scalar(scale_b) else block_size_n