From 00f3924e5405ff8ec8c7facab2e7cf2a3e0d87fc Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 17 Jun 2025 14:48:00 -0400 Subject: [PATCH 01/13] Michael changes Signed-off-by: Luka Govedic --- .../kernels/bench_per_token_quant_fp8.py | 84 ++++++++++++++++ .../model_executor/layers/fp8_quantization.py | 96 +++++++++++++++++++ 2 files changed, 180 insertions(+) create mode 100644 benchmarks/kernels/bench_per_token_quant_fp8.py create mode 100644 vllm/model_executor/layers/fp8_quantization.py diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py new file mode 100644 index 000000000000..d978a685fca1 --- /dev/null +++ b/benchmarks/kernels/bench_per_token_quant_fp8.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +import itertools + +import torch + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.triton_utils import triton + + +def torch_per_token_quant_fp8( + input: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + return ops.dynamic_per_token_quant_fp8(input) + + +def cuda_per_token_quant_fp8( + input: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + scale = torch.empty((input.shape[0], 1), device=input.device, dtype=torch.float32) + # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz + out_dtype: torch.dtype = current_platform.fp8_dtype() + output = torch.empty(input.shape, device=input.device, dtype=out_dtype) + scale_ub = None + torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + return output, scale + + +def calculate_diff(batch_size: int, seq_len: int): + """Calculate difference between Triton and CUDA implementations.""" + device = torch.device("cuda") + x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device) + + torch_out, torch_scale = torch_per_token_quant_fp8(x) + cuda_out, cuda_scale = cuda_per_token_quant_fp8(x) + + if torch.allclose( + cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5 + ) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [1, 16, 32, 64, 128] +seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] + +configs = list(itertools.product(batch_size_range, seq_len_range)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len"], + x_vals=configs, + line_arg="provider", + line_vals=["torch", "cuda"], + line_names=["Torch", "CUDA"], + styles=[("blue", "-"), ("green", "-")], + ylabel="us", + plot_name="per-token-dynamic-quant-fp8-performance", + args={}, + ) +) +def benchmark_quantization(batch_size, seq_len, provider): + dtype = torch.float16 + device = torch.device("cuda") + + x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch": + fn = lambda: torch_per_token_quant_fp8(x.clone()) + elif provider == "cuda": + fn = lambda: cuda_per_token_quant_fp8(x.clone()) + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + calculate_diff(batch_size=4, seq_len=4096) + benchmark_quantization.run(print_data=True) diff --git a/vllm/model_executor/layers/fp8_quantization.py b/vllm/model_executor/layers/fp8_quantization.py new file mode 100644 index 000000000000..0f320875c894 --- /dev/null +++ 
b/vllm/model_executor/layers/fp8_quantization.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.custom_op import CustomOp +from vllm.platforms import current_platform + +# Using the default value (240.0) from pytorch will cause accuracy +# issue on dynamic quantization models. Here use 224.0 for rocm. +_FP8_DTYPE = current_platform.fp8_dtype() +_FP8_MAX = 224.0 if current_platform.is_rocm() else torch.finfo(_FP8_DTYPE).max +_FP8_MIN = -224.0 if current_platform.is_rocm() else torch.finfo( + _FP8_DTYPE).min +_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0) + + +@CustomOp.register("quant_fp8_per_token") +class QuantFP8PerToken(CustomOp): + """ + Quantize input tensor to dynamic per-token FP8 and return quantized + tensor and scale. + + Args: + x: The input tensor to be quantized to FP8 + scale_ub: Optional upper bound for scaling factor + + Returns: + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + + def forward_native( + self, + x: torch.Tensor, + scale_ub: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: + x_token_max, _ = x.abs().max(dim=-1) + x_token_max = x_token_max.to(torch.float32) + if scale_ub is not None: + x_token_max = x_token_max.clamp(max=scale_ub) + scales = (x_token_max / _FP8_MAX).unsqueeze(-1) + scales = scales.clamp(min=_FP8_MIN_SCALING_FACTOR) + + out = x.to(torch.float32) / scales + out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) + return out, scales + + def forward_cuda( + self, + x: torch.Tensor, + scale_ub: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: + return ops.scaled_fp8_quant(x, + scale_ub=scale_ub, + use_per_token_if_dynamic=True) + + +@CustomOp.register("quant_fp8_per_tensor") +class QuantFP8PerTensor(CustomOp): + """ + Quantize input tensor to per-tensor FP8 and return quantized tensor and + scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + + Returns: + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. 
+ """ + + def forward_native(self, + x: torch.Tensor, + scale: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: + if scale is None: + scale = torch.zeros(1, device=x.device, dtype=torch.float32) + x_max = x.abs().max().to(torch.float32) + scale = x_max / _FP8_MAX + + out = (x.to(torch.float32) * scale.reciprocal()).clamp( + _FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) + return out, scale.view((1, )) + + def forward_cuda(self, + x: torch.Tensor, + scale: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: + return ops.scaled_fp8_quant(x, scale=scale) From fcdfe98208a92c6a1031ffb97a48c92fb928f6ab Mon Sep 17 00:00:00 2001 From: Luka Govedic Date: Tue, 17 Jun 2025 15:42:40 -0400 Subject: [PATCH 02/13] Cleanup constants/utils Signed-off-by: Luka Govedic --- vllm/model_executor/layers/fp8_quantization.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fp8_quantization.py b/vllm/model_executor/layers/fp8_quantization.py index 0f320875c894..7279e40425de 100644 --- a/vllm/model_executor/layers/fp8_quantization.py +++ b/vllm/model_executor/layers/fp8_quantization.py @@ -8,11 +8,11 @@ from vllm.platforms import current_platform # Using the default value (240.0) from pytorch will cause accuracy -# issue on dynamic quantization models. Here use 224.0 for rocm. +# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm. _FP8_DTYPE = current_platform.fp8_dtype() -_FP8_MAX = 224.0 if current_platform.is_rocm() else torch.finfo(_FP8_DTYPE).max -_FP8_MIN = -224.0 if current_platform.is_rocm() else torch.finfo( - _FP8_DTYPE).min +_FP8_FINFO = torch.finfo(_FP8_DTYPE) +_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max +_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0) From b79fb8452c3489aadf123fd6466dba3b1975fa72 Mon Sep 17 00:00:00 2001 From: Luka Govedic Date: Tue, 8 Jul 2025 12:36:46 -0400 Subject: [PATCH 03/13] Add QuantFP8 (CustomOp subclass) and use it for FP8 GEMMs. Also fix fusion tests. 
Signed-off-by: Luka Govedic --- .../kernels/bench_per_token_quant_fp8.py | 37 +++++--- tests/compile/test_fusion.py | 4 +- tests/compile/test_fusion_attn.py | 2 + tests/compile/test_silu_mul_quant_fusion.py | 36 ++++++-- .../model_executor/layers/fp8_quantization.py | 85 +++++++++++++++++++ vllm/model_executor/layers/fused_moe/utils.py | 1 + .../schemes/compressed_tensors_24.py | 23 +++-- .../schemes/compressed_tensors_w8a8_fp8.py | 7 +- .../layers/quantization/fbgemm_fp8.py | 5 +- .../model_executor/layers/quantization/fp8.py | 13 ++- .../layers/quantization/modelopt.py | 2 +- .../layers/quantization/ptpc_fp8.py | 7 +- .../quark/schemes/quark_w8a8_fp8.py | 15 ++-- .../layers/quantization/utils/w8a8_utils.py | 65 +++++++------- 14 files changed, 226 insertions(+), 76 deletions(-) diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py index d978a685fca1..a883281dfc37 100644 --- a/benchmarks/kernels/bench_per_token_quant_fp8.py +++ b/benchmarks/kernels/bench_per_token_quant_fp8.py @@ -1,29 +1,42 @@ # SPDX-License-Identifier: Apache-2.0 import itertools +from typing import Callable import torch from vllm import _custom_ops as ops -from vllm.platforms import current_platform +from vllm.compilation.fusion import GroupShape +from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fp8_quantization import QuantFP8 from vllm.triton_utils import triton -def torch_per_token_quant_fp8( - input: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - return ops.dynamic_per_token_quant_fp8(input) +# TODO(luka): use standalone_compile utility +def with_dyn_arg(callable: Callable, arg_index: int, dim_index: int): + def inner(*args): + torch._dynamo.mark_dynamic(args[arg_index], dim_index) + return callable(*args) + + return inner + + +torch._dynamo.config.recompile_limit = 8888 +compilation_config = CompilationConfig(custom_ops=["none"]) +with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)): + torch_per_token_quant_fp8 = torch.compile( + QuantFP8(False, GroupShape.PER_TOKEN), + fullgraph=True, + dynamic=False, # recompile for different shapes + ) + + # First dim is explicitly dynamic to simulate vLLM usage + torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0) def cuda_per_token_quant_fp8( input: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - scale = torch.empty((input.shape[0], 1), device=input.device, dtype=torch.float32) - # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz - out_dtype: torch.dtype = current_platform.fp8_dtype() - output = torch.empty(input.shape, device=input.device, dtype=out_dtype) - scale_ub = None - torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) - return output, scale + return ops.scaled_fp8_quant(input) def calculate_diff(batch_size: int, seq_len: int): diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 040fd176fec1..7f9b94cfb3e3 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -44,7 +44,9 @@ def __init__(self, hidden_size: int, eps: float, static: bool, ] self.fp8_linear = Fp8LinearOp( cutlass_fp8_supported=cutlass_fp8_enabled, - use_per_token_if_dynamic=True) + act_quant_static=static, + act_quant_group_shape=group_shape, + ) def forward(self, x): resid = torch.sqrt(x) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 37ec753bbc9e..70750eb9ac4e 100644 --- 
a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -50,6 +50,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, # DYNAMO_ONCE does not properly propagate shapes. level=CompilationLevel.DYNAMO_AS_IS, backend="tests.compile.test_fusion_attn.backend_unfused", + custom_ops=["+quant_fp8"], ) vllm_config = VllmConfig(compilation_config=compile_config) backend_unfused = TestBackend(NoOpEliminationPass(vllm_config)) @@ -73,6 +74,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, # DYNAMO_ONCE does not properly propagate shapes. level=CompilationLevel.DYNAMO_AS_IS, backend="tests.compile.test_fusion_attn.backend", + custom_ops=["+quant_fp8"], ) vllm_config = VllmConfig(compilation_config=compile_config) diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index df36b86abdbe..c81ef8271b8f 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -4,33 +4,55 @@ import torch import vllm.envs as envs -from vllm._custom_ops import scaled_fp8_quant from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass +from vllm.compilation.fusion import GroupShape from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe +from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + CUTLASS_FP8_SUPPORTED, Fp8LinearOp) +from vllm.platforms import current_platform from .backend import TestBackend class TestModel(torch.nn.Module): - def __init__(self, *args, **kwargs): + def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args, + **kwargs): super().__init__(*args, **kwargs) self.silu_and_mul = SiluAndMul() + self.wscale = torch.rand(1, dtype=torch.float32) self.scale = torch.rand(1, dtype=torch.float32) + self.w = (torch.rand( + hidden_size, + hidden_size).to(dtype=current_platform.fp8_dtype()).t()) + + self.fp8_linear = Fp8LinearOp( + cutlass_fp8_supported=cutlass_fp8_enabled, + act_quant_static=True, + act_quant_group_shape=GroupShape.PER_TENSOR, + ) + def forward(self, x): y = self.silu_and_mul(x) - x2 = scaled_fp8_quant(y, self.scale) + x2 = self.fp8_linear.apply(y, + self.w, + self.wscale, + input_scale=self.wscale) return x2 @pytest.mark.parametrize("num_tokens", [256]) @pytest.mark.parametrize("hidden_size", [64]) +@pytest.mark.parametrize("cutlass_fp8_enabled", + [True, False] if CUTLASS_FP8_SUPPORTED else [False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm") -def test_fusion_silu_and_mul_quant(num_tokens, hidden_size): +def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, + cutlass_fp8_enabled): torch.set_default_device("cuda") torch.set_default_dtype(torch.float16) @@ -40,11 +62,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size): pass_config=PassConfig(enable_fusion=True, enable_noop=True)) fusion_pass = ActivationQuantFusionPass(config) - backend = TestBackend(fusion_pass) - model = TestModel() + backend = TestBackend(NoOpEliminationPass(config), fusion_pass) + model = TestModel(hidden_size, cutlass_fp8_enabled) # First dimension dynamic - x = torch.rand(num_tokens, hidden_size) + x = torch.rand(num_tokens, hidden_size * 2) torch._dynamo.mark_dynamic(x, 0) result = model(x) diff --git 
a/vllm/model_executor/layers/fp8_quantization.py b/vllm/model_executor/layers/fp8_quantization.py index 7279e40425de..d740bd20400b 100644 --- a/vllm/model_executor/layers/fp8_quantization.py +++ b/vllm/model_executor/layers/fp8_quantization.py @@ -2,8 +2,10 @@ from typing import Optional import torch +import torch.nn.functional as F from vllm import _custom_ops as ops +from vllm.compilation.fusion import GroupShape from vllm.model_executor.custom_op import CustomOp from vllm.platforms import current_platform @@ -94,3 +96,86 @@ def forward_cuda(self, scale: Optional[torch.Tensor] = None ) -> tuple[torch.Tensor, torch.Tensor]: return ops.scaled_fp8_quant(x, scale=scale) + + +@CustomOp.register("quant_fp8") +class QuantFP8(CustomOp): + """ + Quantize input tensor to per-tensor or per-token FP8. + This CustomOp supports both static and dynamic quantization. + """ + + def __init__(self, + static: bool, + group_shape: GroupShape, + num_token_padding: Optional[int] = None): + """ + + :param static: static or dynamic quantization + :param group_shape: quantization group shape (PER_TOKEN or PER_TENSOR) + :param num_token_padding: Pad the token dimension of output to this size + """ + super().__init__() + self.num_token_padding = num_token_padding + assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR} + assert not static or group_shape == GroupShape.PER_TENSOR, \ + "Only per-tensor scales supported for static quantization." + self.static = static + self.group_shape = group_shape + self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN + + def forward_cuda( + self, + x: torch.Tensor, + scale: Optional[torch.Tensor] = None, + scale_ub: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + assert (scale is not None) == self.static + assert scale_ub is None or (not self.static and self.group_shape + == GroupShape.PER_TOKEN + and scale_ub.size() == (1, )) + + return ops.scaled_fp8_quant( + x, + scale, + num_token_padding=self.num_token_padding, + scale_ub=scale_ub, + use_per_token_if_dynamic=self.use_per_token_if_dynamic) + + def forward_native( + self, + x: torch.Tensor, + scale: Optional[torch.Tensor] = None, + scale_ub: Optional[torch.Tensor] = None, + ): + assert (scale is not None) == self.static + assert scale_ub is None or (not self.static and self.group_shape + == GroupShape.PER_TOKEN + and scale_ub.size() == (1, )) + + if scale is None: + if self.group_shape == GroupShape.PER_TOKEN: + x_max, _ = x.abs().max(dim=-1) + x_max = x_max.unsqueeze(-1).to(torch.float32) + if scale_ub is not None: + x_max = x_max.clamp(max=scale_ub) + else: + x_max = x.abs().max().to(torch.float32) + + scale = x_max / _FP8_MAX + scale = scale.clamp(min=_FP8_MIN_SCALING_FACTOR) + + # Even for dynamic per-token scales, + # reciprocal performs slightly better than division + out = x.to(torch.float32) * scale.reciprocal() + out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) + + # This currently generates an extra Triton kernel in compilation. + # Fortunately, we don't use padding if compiling. + # TODO(luka): benchmark torch._scaled_mm to hopefully remove padding + # in general. 
+ if self.num_token_padding is not None: + padding = max(self.num_token_padding - out.size(0), 0) + out = F.pad(out, (0, 0, 0, padding), "constant", 0.0) + + return out, scale diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 1eb949790060..25ccc820acb9 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -37,6 +37,7 @@ def _fp8_quantize( is provided, the output will be blocked. """ if block_shape is None: + # TODO(luka): use QuantFP8 custom op A, A_scale = ops.scaled_fp8_quant( A, A_scale, use_per_token_if_dynamic=per_act_token) else: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 30ed55aee04f..d65ea6db8f19 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -11,6 +11,8 @@ from compressed_tensors.utils import combine_shards from vllm import _custom_ops as ops +from vllm.compilation.fusion import GroupShape +from vllm.model_executor.layers.fp8_quantization import QuantFP8 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( @@ -24,6 +26,8 @@ __all__ = ["CompressedTensors24"] +from vllm.platforms import current_platform + class CompressedTensors24(CompressedTensorsScheme): @@ -45,6 +49,12 @@ def __init__( and self.model_compressor.sparsity_config.format == CompressionFormat.sparse_24_bitmask.value) + FP8_DTYPE = current_platform.fp8_dtype() + if quantized and self._get_quant_dtype() == FP8_DTYPE: + static = not input_quant.dynamic + g_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN + self.quant_fp8 = QuantFP8(static, g_shape) + @classmethod def get_min_capability(cls) -> int: # Only cutlass 3.x kernels are implemented so far @@ -232,9 +242,7 @@ def apply_weights( :return: The output tensor of the layer """ if self.quantized: - scale = None - if hasattr(layer, "input_scale"): - scale = layer.input_scale + scale = getattr(layer, 'input_scale', None) if self.weights_dtype == torch.int8: ops_output = ops.scaled_int8_quant(x, scale=scale) @@ -242,11 +250,7 @@ def apply_weights( input_scale = ops_output[1] else: assert self.weights_dtype == torch.float8_e4m3fn - if scale is not None: - q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale) - else: - q_input, input_scale = ops.scaled_fp8_quant( - x, use_per_token_if_dynamic=True) + q_input, input_scale = self.quant_fp8(x, scale=scale) else: # Not quantized, nothing to do with the input_scales, use as is @@ -269,7 +273,10 @@ def apply_weights( def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype: if not self.quantized: return params_dtype + return self._get_quant_dtype() + def _get_quant_dtype(self) -> torch.dtype: + assert self.quantized assert self.weight_quant is not None assert self.input_quant is not None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 1e61e058cb84..e98084834f67 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -7,6 +7,7 @@ from compressed_tensors.quantization import QuantizationStrategy from torch.nn import Parameter +from vllm.compilation.fusion import GroupShape from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -26,7 +27,11 @@ def __init__(self, strategy: str, is_static_input_scheme: bool): self.strategy = strategy self.out_dtype = torch.get_default_dtype() self.is_static_input_scheme = is_static_input_scheme - self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True) + self.act_q_group_shape = GroupShape.PER_TENSOR \ + if is_static_input_scheme else GroupShape.PER_TOKEN + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.is_static_input_scheme, + act_quant_group_shape=self.act_q_group_shape) @classmethod def get_min_capability(cls) -> int: diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 3e465ee2cdd2..8d8b9beb888d 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -7,6 +7,7 @@ from torch.nn import Module from torch.nn.parameter import Parameter +from vllm.compilation.fusion import GroupShape from vllm.logger import init_logger from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) @@ -37,7 +38,6 @@ def __init__(self, ignore_list: list[str], input_scale_ub: float): # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization self.use_marlin = not current_platform.has_device_capability(89) - self.fp8_linear = Fp8LinearOp() @classmethod def get_name(cls) -> QuantizationMethods: @@ -76,7 +76,8 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: FBGEMMFp8Config): self.quant_config = quant_config - self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True) + self.fp8_linear = Fp8LinearOp( + act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN) self.out_dtype = torch.get_default_dtype() def create_weights( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5a1a427d7d72..5b7311b893b7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -11,6 +11,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm.compilation.fusion import GroupShape from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( @@ -199,9 +200,17 @@ def __init__(self, quant_config: Fp8Config): and current_platform.is_fp8_fnuz()) self.block_quant = self.quant_config.weight_block_size is not None + self.act_q_static = self.quant_config.activation_scheme == "static" + # Use per-token quantization for better perf if dynamic and cutlass + if not self.act_q_static and cutlass_fp8_supported(): + self.act_q_group_shape = GroupShape.PER_TOKEN + else: + self.act_q_group_shape = GroupShape.PER_TENSOR + self.fp8_linear = Fp8LinearOp( - # Default to using per_token quantization if cutlass is supported - use_per_token_if_dynamic=cutlass_fp8_supported()) + act_quant_static=self.act_q_static, + act_quant_group_shape=self.act_q_group_shape, + 
cutlass_fp8_supported=cutlass_fp8_supported()) def create_weights( self, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 2295c0e5fe9f..ab3e44e8f113 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -102,7 +102,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptFp8Config): self.quant_config = quant_config - self.fp8_linear = Fp8LinearOp() + self.fp8_linear = Fp8LinearOp(act_quant_static=True) def create_weights( self, diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index 32ba1055f9c8..caf872d1ec57 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops +from vllm.compilation.fusion import GroupShape from vllm.logger import init_logger from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) @@ -95,8 +96,10 @@ def __init__(self, quant_config: PTPCFp8Config): super().__init__(quant_config=quant_config) # Force weight quantization self.quant_config.is_checkpoint_fp8_serialized = False - self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=False, - use_per_token_if_dynamic=True) + self.fp8_linear = Fp8LinearOp( + act_quant_static=False, + cutlass_fp8_supported=False, + act_quant_group_shape=GroupShape.PER_TOKEN) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = torch.nn.Parameter(layer.weight.data, diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index c7bc98184d0e..2a24f05d715e 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -6,6 +6,7 @@ import torch from torch.nn import Parameter +from vllm.compilation.fusion import GroupShape from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale) @@ -28,10 +29,14 @@ def __init__(self, weight_config: dict[str, Any], self.is_static_input_scheme = not cast( bool, input_config.get("is_dynamic")) self.input_qscheme = cast(str, input_config.get("qscheme")) - self.use_per_token_if_dynamic = (not self.is_static_input_scheme \ - and self.input_qscheme == "per_channel") + + per_token = (not self.is_static_input_scheme + and self.input_qscheme == "per_channel") + self.act_quant_group_shape = GroupShape.PER_TOKEN \ + if per_token else GroupShape.PER_TENSOR self.fp8_linear = Fp8LinearOp( - use_per_token_if_dynamic=self.use_per_token_if_dynamic) + act_quant_static=self.is_static_input_scheme, + act_quant_group_shape=self.act_quant_group_shape) self.out_dtype = torch.get_default_dtype() @classmethod @@ -44,7 +49,7 @@ def process_weights_after_loading(self, layer) -> None: # tensor scales (thus N scales being passed to the kernel), # requantize so we can always run per tensor if self.weight_qscheme == "per_tensor": - if current_platform.is_rocm(): + if current_platform.is_fp8_fnuz(): input_scale = getattr(layer, 'input_scale', None) weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( 
weight=layer.weight, @@ -82,7 +87,7 @@ def process_weights_after_loading(self, layer) -> None: requires_grad=False) else: weight_scale = layer.weight_scale.data - if self.use_per_token_if_dynamic: + if self.act_quant_group_shape == GroupShape.PER_TOKEN: weight_scale = weight_scale.view(-1, 1) layer.weight = Parameter(weight.t(), requires_grad=False) # required by torch.compile to be torch.nn.Parameter diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index adc67aa64952..cac7fce875c2 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -7,7 +7,9 @@ from vllm import _custom_ops as ops from vllm import envs +from vllm.compilation.fusion import GroupShape from vllm.config import CompilationLevel, get_current_vllm_config +from vllm.model_executor.layers.fp8_quantization import QuantFP8 from vllm.platforms import current_platform # Input scaling factors are no longer optional in _scaled_mm starting @@ -271,20 +273,21 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, def dispatch_w8a8_scaled_mm( cutlass_fp8_supported: bool, per_tensor_weights: bool, - per_tensor_activations: bool, use_per_token_if_dynamic: Optional[bool] -) -> Callable[..., torch.Tensor]: + per_tensor_activations: bool) -> Callable[..., torch.Tensor]: + # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A if cutlass_fp8_supported: return cutlass_w8a8_scaled_mm if per_tensor_weights and per_tensor_activations: if current_platform.is_rocm(): return rocm_per_tensor_w8a8_scaled_mm return torch_per_tensor_w8a8_scaled_mm - # torch.scaled_mm supports per tensor weights + activations only - # so fallback to naive if per channel or per token - if (use_per_token_if_dynamic and not per_tensor_weights - and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM): + # If torch.scaled_mm supports per-channel (weights) per-token (inputs) + if not per_tensor_weights and not per_tensor_activations \ + and USE_ROWWISE_TORCH_SCALED_MM: return torch_per_token_w8a8_scaled_mm + # Normally, torch.scaled_mm supports per tensor weights + activations only + # so fallback to naive if per channel or per token return torch_channelwise_w8a8_scaled_mm @@ -299,11 +302,11 @@ class Fp8LinearOp: """ def __init__(self, + act_quant_static: bool, cutlass_fp8_supported: bool = cutlass_fp8_supported(), - use_per_token_if_dynamic: bool = False, + act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, pad_output: Optional[bool] = None): self.cutlass_fp8_supported = cutlass_fp8_supported - self.use_per_token_if_dynamic = use_per_token_if_dynamic # Note: we pad the input because torch._scaled_mm is more performant # for matrices with batch dimension > 16. @@ -312,9 +315,16 @@ def __init__(self, # as it breaks with dynamic shapes. 
if pad_output is None: config = get_current_vllm_config().compilation_config - pad_output = config.level < CompilationLevel.PIECEWISE - self.output_padding = 17 if ( - pad_output and not current_platform.is_rocm()) else None + pad_output = config.level < CompilationLevel.PIECEWISE and \ + not cutlass_fp8_supported and \ + not current_platform.is_rocm() + + self.output_padding = 17 if pad_output else None + self.act_quant_static = act_quant_static + self.act_quant_group_shape = act_quant_group_shape + self.quant_fp8 = QuantFP8(static=act_quant_static, + group_shape=act_quant_group_shape, + num_token_padding=self.output_padding) def apply( self, @@ -325,8 +335,6 @@ def apply( input_scale: Optional[torch.Tensor] = None, input_scale_ub: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, - # TODO(luka) remove this parameter in favor of __init__ - use_per_token_if_dynamic: Optional[bool] = None ) -> torch.Tensor: # ops.scaled_fp8_quant supports both dynamic and static quant. # If dynamic, layer.input_scale is None and x_scale computed from x. @@ -336,40 +344,27 @@ def apply( input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[1]] - # TODO(luka) this is here because currently MLA only decides this - # during the forward method instead of in __init__. - if use_per_token_if_dynamic is None: - use_per_token_if_dynamic = self.use_per_token_if_dynamic - if out_dtype is None: out_dtype = input.dtype - # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A - if self.cutlass_fp8_supported: - assert input.dtype != current_platform.fp8_dtype( - ), "FP8 input to cutlass is not currently implemented" - qinput, x_scale = ops.scaled_fp8_quant( + # If input not quantized + # TODO(luka) remove this path if not used anymore + if input.dtype != current_platform.fp8_dtype(): + qinput, x_scale = self.quant_fp8( input_2d, input_scale, - scale_ub=input_scale_ub, - use_per_token_if_dynamic=use_per_token_if_dynamic) + input_scale_ub, + ) else: - if input.dtype != current_platform.fp8_dtype(): - # Maybe apply padding to output, see comment in __init__ - qinput, x_scale = ops.scaled_fp8_quant( - input_2d, - input_scale, - num_token_padding=self.output_padding, - use_per_token_if_dynamic=use_per_token_if_dynamic) - else: - qinput, x_scale = input_2d, input_scale + qinput, x_scale = input_2d, input_scale per_tensor_weights = (weight_scale.numel() == 1) per_tensor_activations = (x_scale.numel() == 1) + # TODO(luka) do this dispatch during init (after ScaledMM refactor) w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( self.cutlass_fp8_supported, per_tensor_weights, - per_tensor_activations, use_per_token_if_dynamic) + per_tensor_activations) return w8a8_scaled_mm_func(qinput=qinput, weight=weight, From 4d4e31ae95efc436a6ef74b45a0566d22db7b401 Mon Sep 17 00:00:00 2001 From: Luka Govedic Date: Tue, 8 Jul 2025 14:21:34 -0400 Subject: [PATCH 04/13] Move GroupShape to quant_utils.py Signed-off-by: Luka Govedic --- .../kernels/bench_per_token_quant_fp8.py | 2 +- tests/compile/test_silu_mul_quant_fusion.py | 3 +- vllm/compilation/fusion.py | 25 ++----------- .../model_executor/layers/fp8_quantization.py | 3 +- .../schemes/compressed_tensors_24.py | 3 +- .../schemes/compressed_tensors_w8a8_fp8.py | 3 +- .../layers/quantization/fbgemm_fp8.py | 3 +- .../model_executor/layers/quantization/fp8.py | 3 +- .../layers/quantization/ptpc_fp8.py | 3 +- .../quark/schemes/quark_w8a8_fp8.py | 3 +- .../layers/quantization/utils/quant_utils.py | 35 +++++++++++++++---- 
.../layers/quantization/utils/w8a8_utils.py | 3 +- 12 files changed, 47 insertions(+), 42 deletions(-) diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py index a883281dfc37..6b0c1463664a 100644 --- a/benchmarks/kernels/bench_per_token_quant_fp8.py +++ b/benchmarks/kernels/bench_per_token_quant_fp8.py @@ -5,9 +5,9 @@ import torch from vllm import _custom_ops as ops -from vllm.compilation.fusion import GroupShape from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fp8_quantization import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.triton_utils import triton diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index c81ef8271b8f..5351a3cf35ba 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -5,11 +5,12 @@ import vllm.envs as envs from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass -from vllm.compilation.fusion import GroupShape from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_FP8_SUPPORTED, Fp8LinearOp) from vllm.platforms import current_platform diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 951a2861e3a4..3dec939c2835 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, ClassVar, NamedTuple, Optional +from typing import Callable, NamedTuple, Optional import torch import torch._inductor.pattern_matcher as pm @@ -11,6 +11,8 @@ from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.platforms import current_platform from .fx_utils import find_getitem_maybe @@ -33,27 +35,6 @@ def empty_fp32(*args, **kwargs): RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default -# Use proxy as NamedTuple direct subclasses cannot have static members -class _GroupShape(NamedTuple): - row: int - col: int - - -class GroupShape(_GroupShape): - """ - This class describes the quantization group shape. - It includes static members for common shapes (per-tensor, per-token). - """ - - # Aliases for common quantization group shapes - PER_TENSOR: ClassVar['GroupShape'] - PER_TOKEN: ClassVar['GroupShape'] - - -GroupShape.PER_TENSOR = GroupShape(-1, -1) -GroupShape.PER_TOKEN = GroupShape(1, -1) - - class QuantKey(NamedTuple): """ Named tuple for identifying the type of quantization. 
diff --git a/vllm/model_executor/layers/fp8_quantization.py b/vllm/model_executor/layers/fp8_quantization.py index d740bd20400b..3e54f175f4b1 100644 --- a/vllm/model_executor/layers/fp8_quantization.py +++ b/vllm/model_executor/layers/fp8_quantization.py @@ -5,8 +5,9 @@ import torch.nn.functional as F from vllm import _custom_ops as ops -from vllm.compilation.fusion import GroupShape from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.platforms import current_platform # Using the default value (240.0) from pytorch will cause accuracy diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index d65ea6db8f19..69d22ff915a4 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -11,12 +11,13 @@ from compressed_tensors.utils import combine_shards from vllm import _custom_ops as ops -from vllm.compilation.fusion import GroupShape from vllm.model_executor.layers.fp8_quantization import QuantFP8 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( convert_to_channelwise, sparse_cutlass_supported) from vllm.model_executor.parameter import (BasevLLMParameter, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index e98084834f67..d984e89d9e02 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -7,9 +7,10 @@ from compressed_tensors.quantization import QuantizationStrategy from torch.nn import Parameter -from vllm.compilation.fusion import GroupShape from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale) diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 8d8b9beb888d..b2cab7d4614a 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -7,7 +7,6 @@ from torch.nn import Module from torch.nn.parameter import Parameter -from vllm.compilation.fusion import GroupShape from vllm.logger import init_logger from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) @@ -17,7 +16,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped) + 
GroupShape, is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5b7311b893b7..33dcecd252e0 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -11,7 +11,6 @@ import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.compilation.fusion import GroupShape from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( @@ -28,7 +27,7 @@ apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, prepare_moe_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped) + GroupShape, is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported, cutlass_fp8_supported, maybe_create_device_identity, diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index caf872d1ec57..d11cba2caba8 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -7,7 +7,6 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops -from vllm.compilation.fusion import GroupShape from vllm.logger import init_logger from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) @@ -18,7 +17,7 @@ Fp8KVCacheMethod, Fp8LinearMethod) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped) + GroupShape, is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp) from vllm.platforms import current_platform diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 2a24f05d715e..2cb35249f49e 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -6,8 +6,9 @@ import torch from torch.nn import Parameter -from vllm.compilation.fusion import GroupShape from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index d6b96774b4e8..54361a2323c2 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -3,7 +3,7 @@ """This file is used for /tests and /benchmarks""" from collections.abc import Mapping from types import MappingProxyType -from typing import Optional +from typing import ClassVar, NamedTuple, Optional import numpy import torch @@ -12,13 +12,30 @@ MARLIN_QQQ_SUPPORTED_NUM_BITS) from vllm.scalar_type import ScalarType, scalar_types 
-SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# Use proxy as NamedTuple direct subclasses cannot have static members +class _GroupShape(NamedTuple): + row: int + col: int + + +class GroupShape(_GroupShape): + """ + This class describes the quantization group shape. + It includes static members for common shapes (per-tensor, per-token). + """ + + # Aliases for common quantization group shapes + PER_TENSOR: ClassVar['GroupShape'] + PER_TOKEN: ClassVar['GroupShape'] + + +GroupShape.PER_TENSOR = GroupShape(-1, -1) +GroupShape.PER_TOKEN = GroupShape(1, -1) # Normalize the group_shape to the full extent for any dims that are -1 -def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int, - int]): +def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): # -1 means full extent return (group_shape[0] if group_shape[0] > 0 else x.shape[-2], group_shape[1] if group_shape[1] > 0 else x.shape[-1]) @@ -58,7 +75,7 @@ def group_broadcast(t, shape): # (i.e. per-token-per-group) def scaled_quantize( x: torch.Tensor, - group_shape: tuple[int, int], + group_shape: GroupShape, quant_dtype: torch.dtype, ) -> tuple[torch.Tensor, torch.Tensor]: group_shape = _normalize_quant_group_shape(x, group_shape) @@ -99,7 +116,7 @@ def scaled_quantize( def scaled_dequantize( x_q: torch.Tensor, x_s: torch.Tensor, - group_shape: Optional[tuple[int, int]] = None, + group_shape: Optional[GroupShape] = None, out_dtype: torch.dtype = torch.float32, ) -> tuple[torch.Tensor, torch.Tensor]: if group_shape is not None: @@ -332,6 +349,10 @@ def reshape_w(w): ) +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType, group_size: int, diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index cac7fce875c2..c7f57b4bd580 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -7,9 +7,10 @@ from vllm import _custom_ops as ops from vllm import envs -from vllm.compilation.fusion import GroupShape from vllm.config import CompilationLevel, get_current_vllm_config from vllm.model_executor.layers.fp8_quantization import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.platforms import current_platform # Input scaling factors are no longer optional in _scaled_mm starting From 2f32e8f53b63865a4aee386eb1e2a06cd7192a9d Mon Sep 17 00:00:00 2001 From: Luka Govedic Date: Tue, 8 Jul 2025 17:50:55 -0400 Subject: [PATCH 05/13] Remove old CustomOp classes, fix pre-commit Signed-off-by: Luka Govedic --- .../model_executor/layers/fp8_quantization.py | 81 +------------------ 1 file changed, 1 insertion(+), 80 deletions(-) diff --git a/vllm/model_executor/layers/fp8_quantization.py b/vllm/model_executor/layers/fp8_quantization.py index 3e54f175f4b1..8063f04aac1d 100644 --- a/vllm/model_executor/layers/fp8_quantization.py +++ b/vllm/model_executor/layers/fp8_quantization.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Optional import torch @@ -19,86 +20,6 @@ _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0) -@CustomOp.register("quant_fp8_per_token") -class QuantFP8PerToken(CustomOp): - """ - Quantize 
input tensor to dynamic per-token FP8 and return quantized - tensor and scale. - - Args: - x: The input tensor to be quantized to FP8 - scale_ub: Optional upper bound for scaling factor - - Returns: - tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. - """ - - def forward_native( - self, - x: torch.Tensor, - scale_ub: Optional[torch.Tensor] = None - ) -> tuple[torch.Tensor, torch.Tensor]: - x_token_max, _ = x.abs().max(dim=-1) - x_token_max = x_token_max.to(torch.float32) - if scale_ub is not None: - x_token_max = x_token_max.clamp(max=scale_ub) - scales = (x_token_max / _FP8_MAX).unsqueeze(-1) - scales = scales.clamp(min=_FP8_MIN_SCALING_FACTOR) - - out = x.to(torch.float32) / scales - out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) - return out, scales - - def forward_cuda( - self, - x: torch.Tensor, - scale_ub: Optional[torch.Tensor] = None - ) -> tuple[torch.Tensor, torch.Tensor]: - return ops.scaled_fp8_quant(x, - scale_ub=scale_ub, - use_per_token_if_dynamic=True) - - -@CustomOp.register("quant_fp8_per_tensor") -class QuantFP8PerTensor(CustomOp): - """ - Quantize input tensor to per-tensor FP8 and return quantized tensor and - scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - - Returns: - tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. - """ - - def forward_native(self, - x: torch.Tensor, - scale: Optional[torch.Tensor] = None - ) -> tuple[torch.Tensor, torch.Tensor]: - if scale is None: - scale = torch.zeros(1, device=x.device, dtype=torch.float32) - x_max = x.abs().max().to(torch.float32) - scale = x_max / _FP8_MAX - - out = (x.to(torch.float32) * scale.reciprocal()).clamp( - _FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) - return out, scale.view((1, )) - - def forward_cuda(self, - x: torch.Tensor, - scale: Optional[torch.Tensor] = None - ) -> tuple[torch.Tensor, torch.Tensor]: - return ops.scaled_fp8_quant(x, scale=scale) - - @CustomOp.register("quant_fp8") class QuantFP8(CustomOp): """ From 80405f56a612e23600afd535a365a27ae22304e9 Mon Sep 17 00:00:00 2001 From: Luka Govedic Date: Tue, 8 Jul 2025 17:59:53 -0400 Subject: [PATCH 06/13] fix pre-commit try 2 Signed-off-by: Luka Govedic --- benchmarks/kernels/bench_per_token_quant_fp8.py | 1 + .../compressed_tensors/schemes/compressed_tensors_24.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py index 6b0c1463664a..42cc914a5eb0 100644 --- a/benchmarks/kernels/bench_per_token_quant_fp8.py +++ b/benchmarks/kernels/bench_per_token_quant_fp8.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools from typing import Callable diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 69d22ff915a4..59ad63176b80 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -50,8 +50,8 @@ def __init__( and 
self.model_compressor.sparsity_config.format == CompressionFormat.sparse_24_bitmask.value) - FP8_DTYPE = current_platform.fp8_dtype() - if quantized and self._get_quant_dtype() == FP8_DTYPE: + if quantized and input_quant is not None and \ + self._get_quant_dtype() == current_platform.fp8_dtype(): static = not input_quant.dynamic g_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN self.quant_fp8 = QuantFP8(static, g_shape) From 5650acbbd3149a2fb283bdbf5eb50f887160965d Mon Sep 17 00:00:00 2001 From: Luka Govedic Date: Wed, 9 Jul 2025 17:07:06 -0400 Subject: [PATCH 07/13] gemini feedback Signed-off-by: Luka Govedic --- benchmarks/kernels/bench_per_token_quant_fp8.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py index 42cc914a5eb0..00273c5aec26 100644 --- a/benchmarks/kernels/bench_per_token_quant_fp8.py +++ b/benchmarks/kernels/bench_per_token_quant_fp8.py @@ -13,10 +13,10 @@ # TODO(luka): use standalone_compile utility -def with_dyn_arg(callable: Callable, arg_index: int, dim_index: int): +def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int): def inner(*args): torch._dynamo.mark_dynamic(args[arg_index], dim_index) - return callable(*args) + return fn(*args) return inner From 6abccfe29fc49750ae663a086d3cf404f72373d4 Mon Sep 17 00:00:00 2001 From: Luka Govedic Date: Wed, 9 Jul 2025 17:43:52 -0400 Subject: [PATCH 08/13] add issue for MoE Signed-off-by: Luka Govedic --- vllm/model_executor/layers/fused_moe/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 25ccc820acb9..a3612afed0e7 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -38,6 +38,7 @@ def _fp8_quantize( """ if block_shape is None: # TODO(luka): use QuantFP8 custom op + # https://github.com/vllm-project/vllm/issues/20711 A, A_scale = ops.scaled_fp8_quant( A, A_scale, use_per_token_if_dynamic=per_act_token) else: From 49283d4cf9f2ba182673aeb9faceadb077c6a7d8 Mon Sep 17 00:00:00 2001 From: Luka Govedic Date: Wed, 9 Jul 2025 17:49:40 -0400 Subject: [PATCH 09/13] move file Signed-off-by: Luka Govedic --- benchmarks/kernels/bench_per_token_quant_fp8.py | 2 +- .../compressed_tensors/schemes/compressed_tensors_24.py | 2 +- .../{fp8_quantization.py => quantization/input_quant_fp8.py} | 0 vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename vllm/model_executor/layers/{fp8_quantization.py => quantization/input_quant_fp8.py} (100%) diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py index 00273c5aec26..923d678f1f2d 100644 --- a/benchmarks/kernels/bench_per_token_quant_fp8.py +++ b/benchmarks/kernels/bench_per_token_quant_fp8.py @@ -7,7 +7,7 @@ from vllm import _custom_ops as ops from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fp8_quantization import QuantFP8 +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.triton_utils import triton diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 59ad63176b80..168b221a9cfe 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -11,11 +11,11 @@ from compressed_tensors.utils import combine_shards from vllm import _custom_ops as ops -from vllm.model_executor.layers.fp8_quantization import QuantFP8 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( diff --git a/vllm/model_executor/layers/fp8_quantization.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py similarity index 100% rename from vllm/model_executor/layers/fp8_quantization.py rename to vllm/model_executor/layers/quantization/input_quant_fp8.py diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index c7f57b4bd580..47bb45793281 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -8,7 +8,7 @@ from vllm import _custom_ops as ops from vllm import envs from vllm.config import CompilationLevel, get_current_vllm_config -from vllm.model_executor.layers.fp8_quantization import QuantFP8 +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) from vllm.platforms import current_platform From dc82a0bbbee4106b81733f9de1537973de4fd84b Mon Sep 17 00:00:00 2001 From: Luka Govedic Date: Wed, 9 Jul 2025 17:54:21 -0400 Subject: [PATCH 10/13] PR feedback Signed-off-by: Luka Govedic --- vllm/model_executor/layers/quantization/modelopt.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index ab3e44e8f113..0a4e36f19bf8 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -22,7 +22,7 @@ apply_fp4_marlin_linear, is_fp4_marlin_supported, prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped) + GroupShape, is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, requantize_with_max_scale) from vllm.model_executor.parameter import (ModelWeightParameter, @@ -102,7 +102,8 @@ class ModelOptFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptFp8Config): self.quant_config = quant_config - self.fp8_linear = Fp8LinearOp(act_quant_static=True) + self.fp8_linear = Fp8LinearOp( + act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR) def create_weights( self, From dc97cd021868f969338f364e0b065da77bade54a Mon Sep 17 00:00:00 2001 From: Luka Govedic Date: Wed, 9 Jul 2025 18:11:03 -0400 Subject: [PATCH 11/13] Use GroupShape in fusion Signed-off-by: Luka Govedic --- vllm/attention/backends/abstract.py | 6 ++++-- 
From dc97cd021868f969338f364e0b065da77bade54a Mon Sep 17 00:00:00 2001
From: Luka Govedic
Date: Wed, 9 Jul 2025 18:11:03 -0400
Subject: [PATCH 11/13] Use GroupShape in fusion

Signed-off-by: Luka Govedic
---
 vllm/attention/backends/abstract.py        | 6 ++++--
 vllm/attention/backends/rocm_flash_attn.py | 6 ++++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 990ea054f338..05c098a58a0d 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -9,6 +9,8 @@

 import torch

+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
 from vllm.multimodal import MultiModalPlaceholderMap

 if TYPE_CHECKING:
@@ -289,7 +291,7 @@ def forward(
         raise NotImplementedError

     def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
-                                     group_shape: tuple[int, int]):
+                                     group_shape: GroupShape):
         """
         Does this attention implementation support fused output quantization.
         This is used by the AttnFusionPass to only fuse output quantization
@@ -298,7 +300,7 @@ def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
         TODO(luka) merge parameters into QuantDescriptor
         :param dtype: quantized dtype
         :param static: static or dynamic quantization
-        :param group_shape: quant group shape. (-1, -1) for per-tensor.
+        :param group_shape: quant group shape.
         :return: is fusion supported for this type of quantization
         """
         return False
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 1e2c21f4e69d..0b7783758dda 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -19,6 +19,8 @@
     PagedAttentionMetadata)
 from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
 from vllm.platforms import current_platform
 from vllm.platforms.rocm import use_rocm_custom_paged_attention

@@ -598,10 +600,10 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
                                          head_dim))

     def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
-                                     group_shape: tuple[int, int]):
+                                     group_shape: GroupShape):
         if self.use_triton_flash_attn:
             return dtype == current_platform.fp8_dtype(
-            ) and static and group_shape == (-1, -1)  # per-tensor
+            ) and static and group_shape == GroupShape.PER_TENSOR

         # Only supported in the Triton backend
         return False
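
With GroupShape in the interface, the caller side (the AttnFusionPass deciding whether to fuse output quantization into the attention op) can gate on support without the old (-1, -1) tuple convention. A rough sketch of such a check, assumed for illustration rather than taken from the patch:

    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
    from vllm.platforms import current_platform

    def can_fuse_output_quant(attn_impl, static: bool) -> bool:
        # Per the ROCm implementation above, only static per-tensor FP8 output
        # quantization is currently reported as fusable.
        return attn_impl.fused_output_quant_supported(
            current_platform.fp8_dtype(), static, GroupShape.PER_TENSOR)
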
From 69cc866b5b0ac9141a01a9812e188a9b7b4899ca Mon Sep 17 00:00:00 2001
From: Luka Govedic
Date: Thu, 10 Jul 2025 14:37:50 -0400
Subject: [PATCH 12/13] Add quant_fp8 to fix fusion test

Signed-off-by: Luka Govedic
---
 tests/compile/test_fusion.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py
index 7f9b94cfb3e3..4a3820e20fd8 100644
--- a/tests/compile/test_fusion.py
+++ b/tests/compile/test_fusion.py
@@ -93,9 +93,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
     maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

     vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
-    vllm_config.compilation_config.pass_config = \
-        PassConfig(enable_fusion=True, enable_noop=True)
+        level=CompilationLevel.PIECEWISE,
+        custom_ops=["+rms_norm", "+quant_fp8"],
+        pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+    ))
     with vllm.config.set_current_vllm_config(vllm_config):
         # Reshape pass is needed for the fusion pass to work
         noop_pass = NoOpEliminationPass(vllm_config)

From 404a4b2f74d8eaf1aa25428faadd9f109a42bbe8 Mon Sep 17 00:00:00 2001
From: Luka Govedic
Date: Thu, 10 Jul 2025 16:22:28 -0400
Subject: [PATCH 13/13] Compilation fixes

Signed-off-by: Luka Govedic
---
 vllm/model_executor/layers/quantization/input_quant_fp8.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py
index 8063f04aac1d..e1a9bdde9334 100644
--- a/vllm/model_executor/layers/quantization/input_quant_fp8.py
+++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py
@@ -55,7 +55,7 @@ def forward_cuda(
         assert (scale is not None) == self.static
         assert scale_ub is None or (not self.static
                                     and self.group_shape == GroupShape.PER_TOKEN
-                                    and scale_ub.size() == (1, ))
+                                    and scale_ub.numel() == 1)

         return ops.scaled_fp8_quant(
             x,
@@ -73,7 +73,7 @@ def forward_native(
         assert (scale is not None) == self.static
         assert scale_ub is None or (not self.static
                                     and self.group_shape == GroupShape.PER_TOKEN
-                                    and scale_ub.size() == (1, ))
+                                    and scale_ub.numel() == 1)

         if scale is None:
             if self.group_shape == GroupShape.PER_TOKEN:
@@ -82,7 +82,7 @@ def forward_native(
                 if scale_ub is not None:
                     x_max = x_max.clamp(max=scale_ub)
             else:
-                x_max = x.abs().max().to(torch.float32)
+                x_max = x.abs().max().unsqueeze(-1).to(torch.float32)

             scale = x_max / _FP8_MAX
             scale = scale.clamp(min=_FP8_MIN_SCALING_FACTOR)
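
For reference, a self-contained PyTorch sketch of the dynamic scaling that the fixed forward_native computes. This is an illustration rather than the production path; the constants assume the e4m3fn dtype, and the point of the .unsqueeze(-1) change is that both the per-token and per-tensor scales stay broadcastable against the 2-D input.

    import torch

    _FP8_DTYPE = torch.float8_e4m3fn
    _FP8_MAX = torch.finfo(_FP8_DTYPE).max
    _FP8_MIN = torch.finfo(_FP8_DTYPE).min
    _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)

    def dynamic_fp8_quant_ref(x: torch.Tensor, per_token: bool):
        if per_token:
            # One scale per row: shape [num_tokens, 1]
            x_max = x.abs().max(dim=-1).values.to(torch.float32).unsqueeze(-1)
        else:
            # One scale for the whole tensor, kept 1-D so it broadcasts: shape [1]
            x_max = x.abs().max().unsqueeze(-1).to(torch.float32)
        scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
        out = (x.to(torch.float32) / scale).clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
        return out, scale

    x = torch.randn(4, 8, dtype=torch.float16)
    out_pt, scale_pt = dynamic_fp8_quant_ref(x, per_token=True)   # scale_pt.shape == (4, 1)
    out_ps, scale_ps = dynamic_fp8_quant_ref(x, per_token=False)  # scale_ps.shape == (1,)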