diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
index 3ff162170f25..2659afcdc74a 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
@@ -128,7 +128,8 @@ def triton_scaled_mm(input: torch.Tensor,
                      bias: Optional[torch.Tensor] = None,
                      block_size_m: int = 32,
                      block_size_n: int = 32,
-                     block_size_k: int = 32) -> torch.Tensor:
+                     block_size_k: int = 32,
+                     use_heuristic=True) -> torch.Tensor:
     M, K = input.shape
     N = weight.shape[1]
 
@@ -152,6 +153,20 @@ def triton_scaled_mm(input: torch.Tensor,
 
     has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1
 
+    if use_heuristic:
+        is_small_N = N < 8192
+        next_power_of_2_M = max(32, triton.next_power_of_2(M))
+        if next_power_of_2_M <= 32:
+            tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
+        elif next_power_of_2_M <= 64:
+            tile_shape = (64, 64, 256)
+        elif next_power_of_2_M <= 128:
+            tile_shape = (64, 128, 128)
+        else:
+            tile_shape = (128, 128, 128)
+
+        block_size_m, block_size_n, block_size_k = tile_shape
+
     block_size_sa = 1 if has_scalar(scale_a) else block_size_m
     block_size_sb = 1 if has_scalar(scale_b) else block_size_n
 
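The diff overrides the caller-supplied block sizes with a shape-based tile heuristic: M is rounded up to a power of two (clamped to at least 32), small-M problems get deep-K tiles (BLOCK_SIZE_K = 256, the latency-bound decode case), large-M problems get square 128x128x128 tiles, and for the smallest M bucket the N tile widens from 64 to 128 once N reaches 8192. Below is a minimal standalone sketch of that selection logic so it can be sanity-checked without a GPU, Triton, or vLLM installed; `pick_tile_shape` is a hypothetical helper name introduced here, and `next_power_of_2` is a pure-Python stand-in assumed to match `triton.next_power_of_2`.

```python
# Standalone sketch of the tile-shape heuristic from the diff (illustration only).

def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n; assumed equivalent to triton.next_power_of_2."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def pick_tile_shape(M: int, N: int) -> tuple[int, int, int]:
    """Return (block_size_m, block_size_n, block_size_k), mirroring the diff."""
    is_small_N = N < 8192
    next_power_of_2_M = max(32, next_power_of_2(M))
    if next_power_of_2_M <= 32:
        # Smallest batch bucket: widen the N tile for wide weights.
        return (64, 64, 256) if is_small_N else (64, 128, 256)
    elif next_power_of_2_M <= 64:
        return (64, 64, 256)
    elif next_power_of_2_M <= 128:
        return (64, 128, 128)
    else:
        # Large (prefill-style) batches: square tiles.
        return (128, 128, 128)


if __name__ == "__main__":
    # Decode-style batch, narrow weight: deep-K tile.
    assert pick_tile_shape(M=16, N=4096) == (64, 64, 256)
    # Same batch, wide weight (N >= 8192): wider N tile.
    assert pick_tile_shape(M=16, N=16384) == (64, 128, 256)
    # Large batch: square 128x128x128 tile.
    assert pick_tile_shape(M=1024, N=4096) == (128, 128, 128)
```

Note the heuristic only reads M and N, never K, and `use_heuristic=True` by default silently ignores the `block_size_*` keyword arguments; callers who want to hand-tune tile shapes would need to pass `use_heuristic=False`.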