From 93d09f10ffb67d45119741f0e0f71afaa412e742 Mon Sep 17 00:00:00 2001
From: luka
Date: Fri, 7 Mar 2025 00:50:39 +0000
Subject: [PATCH 1/3] Refactor apply_fp8_linear and apply_fp8_linear_generic
 into an object

Signed-off-by: luka

---
 vllm/attention/backends/mla/common.py          |   7 +-
 .../schemes/compressed_tensors_w8a8_fp8.py     |  19 +-
 .../layers/quantization/fbgemm_fp8.py          |  22 +-
 .../model_executor/layers/quantization/fp8.py  |  21 +-
 .../layers/quantization/modelopt.py            |  16 +-
 .../layers/quantization/ptpc_fp8.py            |  18 +-
 .../quark/schemes/quark_w8a8_fp8.py            |  18 +-
 .../layers/quantization/utils/fp8_utils.py     |  89 +++---
 .../layers/quantization/utils/w8a8_utils.py    | 269 ++++++++++--------
 vllm/v1/attention/backends/mla/common.py       |   7 +-
 10 files changed, 257 insertions(+), 229 deletions(-)

diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index 109e8496fc31..4f4b70cd8f48 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -226,7 +226,7 @@
     CompressedTensorsW8A8Fp8)
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
+    Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import (
@@ -1057,6 +1057,7 @@ def __init__(
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
         self.triton_fa_func = triton_attention
+        self.fp8_linear_generic = Fp8LinearGenericOp()

         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
@@ -1071,7 +1072,7 @@ def __init__(
     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_UV_O):
-                output_parallel = apply_fp8_linear_generic(
+                output_parallel = self.fp8_linear_generic.apply(
                     x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
                     self.reqaunt_input_group_shape,
                     self.reqaunt_weight_group_shape)
@@ -1091,7 +1092,7 @@ def _v_up_proj_and_o_proj(self, x):
     def _q_proj_and_k_up_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_Q_UK):
-                return apply_fp8_linear_generic(
+                return self.fp8_linear_generic.apply(
                     x, self.W_Q_UK, self.W_Q_UK_scales,
                     self.reqaunt_input_group_shape,
                     self.reqaunt_weight_group_shape).view(
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index 32072e9fa570..aca25c9bfa19 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -9,8 +9,8 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, cutlass_fp8_supported, maybe_create_device_identity,
-    normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
+    Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
+    requantize_with_max_scale)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -24,7 +24,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
     def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)

     @classmethod
     def get_min_capability(cls) -> int:
@@ -140,11 +140,8 @@ def apply_weights(self,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:

-        return apply_fp8_linear(
-            input=x,
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            input_scale=layer.input_scale,
-            bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported,
-            use_per_token_if_dynamic=True)
+        return self.fp8_linear.apply(input=x,
+                                     weight=layer.weight,
+                                     weight_scale=layer.weight_scale,
+                                     input_scale=layer.input_scale,
+                                     bias=bias)
diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
index 20f2c3da600d..110e4ef2e9f1 100644
--- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py
+++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -11,14 +11,12 @@
     UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.layers.quantization.fp8 import cutlass_fp8_supported
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, maybe_create_device_identity,
-    normalize_e4m3fn_to_e4m3fnuz)
+    Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter)
 from vllm.platforms import current_platform
@@ -37,6 +35,7 @@ def __init__(self, ignore_list: List[str], input_scale_ub: float):
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = not current_platform.has_device_capability(89)
+        self.fp8_linear = Fp8LinearOp()

     @classmethod
     def get_name(cls) -> str:
@@ -73,7 +72,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):

     def __init__(self, quant_config: FBGEMMFp8Config):
         self.quant_config = quant_config
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)

     def create_weights(
         self,
@@ -159,12 +158,9 @@ def apply(self,
                 size_k=layer.input_size_per_partition,
                 bias=bias)

-        return apply_fp8_linear(
-            input=x,
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            input_scale=None,
-            input_scale_ub=layer.input_scale_ub,
-            bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported,
-            use_per_token_if_dynamic=True)
+        return self.fp8_linear.apply(input=x,
+                                     weight=layer.weight,
+                                     weight_scale=layer.weight_scale,
+                                     input_scale=None,
+                                     input_scale_ub=layer.input_scale_ub,
+                                     bias=bias)
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index a705f63be4ac..3f8e0a2f9237 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -23,7 +23,7 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    all_close_1d, apply_fp8_linear, convert_to_channelwise,
+    Fp8LinearOp, all_close_1d, convert_to_channelwise,
     cutlass_block_fp8_supported, cutlass_fp8_supported,
     maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
     per_tensor_dequantize, requantize_with_max_scale)
@@ -137,7 +137,6 @@ class Fp8LinearMethod(LinearMethodBase):

     def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
         self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()

         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
@@ -153,6 +152,10 @@ def __init__(self, quant_config: Fp8Config):
             # Marlin doesn't support block-wise fp8
             self.use_marlin = False

+        self.fp8_linear = Fp8LinearOp(
+            # Default to using per_token quantization if cutlass is supported
+            use_per_token_if_dynamic=cutlass_fp8_supported())
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -381,15 +384,11 @@ def apply(self,
                 cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
             )

-        return apply_fp8_linear(
-            input=x,
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            input_scale=layer.input_scale,
-            bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported,
-            # Default to using per_token quantization if cutlass is supported
-            use_per_token_if_dynamic=self.cutlass_fp8_supported)
+        return self.fp8_linear.apply(input=x,
+                                     weight=layer.weight,
+                                     weight_scale=layer.weight_scale,
+                                     input_scale=layer.input_scale,
+                                     bias=bias)


 class Fp8MoEMethod(FusedMoEMethodBase):
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 36711a7a5098..1f8af8d678cd 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -12,7 +12,7 @@
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale)
+    Fp8LinearOp, requantize_with_max_scale)
 from vllm.model_executor.parameter import (ModelWeightParameter,
                                            PerTensorScaleParameter)

@@ -95,7 +95,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):

     def __init__(self, quant_config: ModelOptFp8Config):
         self.quant_config = quant_config
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.fp8_linear = Fp8LinearOp()

     def create_weights(
         self,
@@ -157,10 +157,8 @@ def apply(
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        return apply_fp8_linear(
-            input=x,
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            input_scale=layer.input_scale,
-            bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported)
+        return self.fp8_linear.apply(input=x,
+                                     weight=layer.weight,
+                                     weight_scale=layer.weight_scale,
+                                     input_scale=layer.input_scale,
+                                     bias=bias)
diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py
index 1ded5389e5f4..592ffc5dad13 100644
--- a/vllm/model_executor/layers/quantization/ptpc_fp8.py
+++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py
@@ -17,7 +17,7 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    Fp8LinearOp)
 from vllm.platforms import current_platform

 ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -93,6 +93,8 @@ def __init__(self, quant_config: PTPCFp8Config):
         super().__init__(quant_config=quant_config)
         # Force weight quantization
         self.quant_config.is_checkpoint_fp8_serialized = False
+        self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=False,
+                                      use_per_token_if_dynamic=True)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.weight = torch.nn.Parameter(layer.weight.data,
@@ -115,11 +117,9 @@ def apply(self,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:

-        return apply_fp8_linear(input=x,
-                                weight=layer.weight,
-                                weight_scale=layer.weight_scale,
-                                input_scale=None,
-                                input_scale_ub=None,
-                                bias=bias,
-                                cutlass_fp8_supported=False,
-                                use_per_token_if_dynamic=True)
+        return self.fp8_linear.apply(input=x,
+                                     weight=layer.weight,
+                                     weight_scale=layer.weight_scale,
+                                     input_scale=None,
+                                     input_scale_ub=None,
+                                     bias=bias)
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
index c885e98a4d66..7676fbddb6b8 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
@@ -7,8 +7,7 @@

 from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
-    requantize_with_max_scale)
+    Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -22,7 +21,7 @@ class QuarkW8A8Fp8(QuarkScheme):
     def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
         self.qscheme = qscheme
         self.is_static_input_scheme = is_static_input_scheme
-        self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)

     @classmethod
     def get_min_capability(cls) -> int:
@@ -132,11 +131,8 @@ def apply_weights(self,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:

-        return apply_fp8_linear(
-            input=x,
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            input_scale=layer.input_scale,
-            bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported,
-            use_per_token_if_dynamic=True)
+        return self.fp8_linear.apply(input=x,
+                                     weight=layer.weight,
+                                     weight_scale=layer.weight_scale,
+                                     input_scale=layer.input_scale,
+                                     bias=bias)
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 7d91d2cf1c6e..9b43f51852fd 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -15,7 +15,8 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     _normalize_quant_group_shape, scaled_dequantize)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear)
+    CUTLASS_BLOCK_FP8_SUPPORTED, Fp8LinearOp, cutlass_block_fp8_supported,
+    cutlass_fp8_supported)
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op

@@ -32,6 +33,7 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
     return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz


+# TODO fix ROCm->Triton custom path
 def apply_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -104,43 +106,54 @@ def apply_w8a8_block_fp8_linear_fake(
 # Unify the interface between `apply_w8a8_block_fp8_linear` and
 # `apply_fp8_linear`
 # NOTE(lucas): this is quite messy, we should think through this more formally
-def apply_fp8_linear_generic(
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    weight_scale: torch.Tensor,
-    input_group_shape: Tuple[int, int],
-    weight_group_shape: Tuple[int, int],
-    input_scale: Optional[torch.Tensor] = None,  # static scale if one
-    cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
-    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
-) -> torch.Tensor:
-    # View input as 2D matrix for fp8 methods
-    input = input.view(-1, input.shape[-1])
-
-    weight_group_shape = _normalize_quant_group_shape(\
-        weight, weight_group_shape)
-    input_group_shape = _normalize_quant_group_shape(input, input_group_shape)
-
-    def is_dim_blocked(dim, shape, group_shape):
-        return group_shape < shape[dim] and group_shape > 1
-
-    if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
-        and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
-        input_group_shape == (1, weight_group_shape[1]):
-        return apply_w8a8_block_fp8_linear(
-            input,
-            weight,
-            list(weight_group_shape),
-            weight_scale,
-            cutlass_block_fp8_supported=cutlass_block_fp8_supported)
-    else:
-        # Despite having linear in the it doesn't conform to
-        # `torch.nn.functional.linear` which is defined as `input @ weight.T`
-        # so we explicitly transpose the weight matrix here
-        return apply_fp8_linear(input, weight.T, weight_scale.T,
-                                cutlass_fp8_supported=cutlass_fp8_supported,
-                                use_per_token_if_dynamic=\
-                                    (input_group_shape == (1, input.shape[1])))
+# TODO(luka): unify this better
+class Fp8LinearGenericOp:
+
+    def __init__(
+        self,
+        cutlass_fp8_supported: bool = cutlass_fp8_supported(),
+        cutlass_block_fp8_supported: bool = cutlass_block_fp8_supported(),
+    ):
+        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported
+        self.fp8_linear = Fp8LinearOp(
+            cutlass_fp8_supported=cutlass_fp8_supported)
+
+    def apply(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        input_group_shape: Tuple[int, int],
+        weight_group_shape: Tuple[int, int],
+        input_scale: Optional[torch.Tensor] = None,  # static scale if one
+    ) -> torch.Tensor:
+        # View input as 2D matrix for fp8 methods
+        input = input.view(-1, input.shape[-1])
+
+        weight_group_shape = _normalize_quant_group_shape( \
+            weight, weight_group_shape)
+        input_group_shape = _normalize_quant_group_shape(
+            input, input_group_shape)
+
+        def is_dim_blocked(dim, shape, group_shape):
+            return group_shape < shape[dim] and group_shape > 1
+
+        if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
+            and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
+            input_group_shape == (1, weight_group_shape[1]):
+            return apply_w8a8_block_fp8_linear(
+                input,
+                weight,
+                list(weight_group_shape),
+                weight_scale,
+                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported)
+        else:
+            # Despite having linear in the name it doesn't conform to
+            # `torch.nn.functional.linear` which is defined as
+            # `input @ weight.T` so we explicitly transpose the weight matrix
+            return self.fp8_linear.apply(input, weight.T, weight_scale.T,
+                                         use_per_token_if_dynamic=\
+                                             (input_group_shape == (1, input.shape[1])))


 def input_to_float8(
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index 8072f307763d..02ffa890ad2a 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -121,134 +121,161 @@ def maybe_create_device_identity():
         TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)


-def apply_fp8_linear(
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    weight_scale: torch.Tensor,
-    input_scale: Optional[torch.Tensor] = None,
-    input_scale_ub: Optional[torch.Tensor] = None,
-    bias: Optional[torch.Tensor] = None,
-    cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED,
-    use_per_token_if_dynamic: bool = False,
-) -> torch.Tensor:
-    # ops.scaled_fp8_quant supports both dynamic and static quant.
-    # If dynamic, layer.input_scale is None and x_scale computed from x.
-    # If static, layer.input_scale is scalar and x_scale is input_scale.
-
-    # View input as 2D matrix for fp8 methods
-    input_2d = input.view(-1, input.shape[-1])
-    output_shape = [*input.shape[:-1], weight.shape[1]]
-
-    # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
-    if cutlass_fp8_supported:
-        qinput, x_scale = ops.scaled_fp8_quant(
-            input_2d,
-            input_scale,
-            scale_ub=input_scale_ub,
-            use_per_token_if_dynamic=use_per_token_if_dynamic)
-
-        # Fused GEMM_DQ
-        output = ops.cutlass_scaled_mm(qinput,
-                                       weight,
-                                       out_dtype=input.dtype,
-                                       scale_a=x_scale,
-                                       scale_b=weight_scale,
-                                       bias=bias)
-        return output.view(*output_shape)
-
-    # torch.scaled_mm supports per tensor weights + activations only
-    # so fallback to naive if per channel or per token
-    else:
+# TODO(luka): follow similar pattern for marlin and block-fp8-linear
+class Fp8LinearOp:
+    """
+    This class executes a FP8 linear layer using cutlass if supported and
+    torch.scaled_mm otherwise.
+    It needs to be a class instead of a method so that config can be read
+    in the __init__ method, as reading config is not allowed inside forward.
+    """
+
+    def __init__(self,
+                 cutlass_fp8_supported: bool = cutlass_fp8_supported(),
+                 use_per_token_if_dynamic: bool = False,
+                 pad_output: Optional[bool] = None):
+        self.cutlass_fp8_supported = cutlass_fp8_supported
+        self.use_per_token_if_dynamic = use_per_token_if_dynamic
+
         # Note: we pad the input because torch._scaled_mm is more performant
         # for matrices with batch dimension > 16.
         # This could change in the future.
         # We also don't pad when using torch.compile,
         # as it breaks with dynamic shapes.
-        config = get_current_vllm_config().compilation_config
-        do_pad = config.level < CompilationLevel.PIECEWISE
-        qinput, x_scale = ops.scaled_fp8_quant(
-            input_2d,
-            input_scale,
-            num_token_padding=17 if do_pad else None,
-            use_per_token_if_dynamic=use_per_token_if_dynamic)
-
-        per_tensor_weights = (weight_scale.numel() == 1)
-        per_tensor_activations = (x_scale.numel() == 1)
-
-        if per_tensor_weights and per_tensor_activations:
-            # Fused GEMM_DQ
-            output = torch._scaled_mm(qinput,
-                                      weight,
-                                      out_dtype=input.dtype,
-                                      scale_a=x_scale,
-                                      scale_b=weight_scale,
-                                      bias=bias)
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-
-            return torch.narrow(output, 0, 0,
-                                input_2d.shape[0]).view(*output_shape)
-
-        elif (use_per_token_if_dynamic and not per_tensor_weights
-              and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM):
-            # For now validated on ROCm platform
-            # fp8 rowwise scaling in torch._scaled_mm is introduced in
-            # https://github.com/pytorch/pytorch/pull/144432 using
-            # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
-            # For CUDA platform please validate if the
-            # torch._scaled_mm support rowwise scaled GEMM
-            # Fused GEMM_DQ Rowwise GEMM
-            output = torch._scaled_mm(qinput,
-                                      weight,
-                                      out_dtype=input.dtype,
-                                      scale_a=x_scale,
-                                      scale_b=weight_scale.t(),
-                                      bias=bias)
-
-            output = torch.narrow(output, 0, 0, input_2d.shape[0])
-            output = output.view(*output_shape)
-            return output
+        if pad_output is None:
+            config = get_current_vllm_config().compilation_config
+            pad_output = config.level < CompilationLevel.PIECEWISE
+        self.output_padding = 17 if pad_output else None
+
+    def apply(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        input_scale: Optional[torch.Tensor] = None,
+        input_scale_ub: Optional[torch.Tensor] = None,
+        bias: Optional[torch.Tensor] = None,
+        # TODO(luka) remove this parameter in favor of __init__
+        use_per_token_if_dynamic: Optional[bool] = None
+    ) -> torch.Tensor:
+        # ops.scaled_fp8_quant supports both dynamic and static quant.
+        # If dynamic, layer.input_scale is None and x_scale computed from x.
+        # If static, layer.input_scale is scalar and x_scale is input_scale.
+
+        # View input as 2D matrix for fp8 methods
+        input_2d = input.view(-1, input.shape[-1])
+        output_shape = [*input.shape[:-1], weight.shape[1]]
+
+        # TODO(luka) this is here because currently MLA only decides this
+        # during the forward method instead of in __init__.
+        if use_per_token_if_dynamic is None:
+            use_per_token_if_dynamic = self.use_per_token_if_dynamic
+
+        # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
+        if self.cutlass_fp8_supported:
+            qinput, x_scale = ops.scaled_fp8_quant(
+                input_2d,
+                input_scale,
+                scale_ub=input_scale_ub,
+                use_per_token_if_dynamic=use_per_token_if_dynamic)
+            # Fused GEMM_DQ
+            output = ops.cutlass_scaled_mm(qinput,
+                                           weight,
+                                           out_dtype=input.dtype,
+                                           scale_a=x_scale,
+                                           scale_b=weight_scale,
+                                           bias=bias)
+            return output.view(*output_shape)
+
+        # torch.scaled_mm supports per tensor weights + activations only
+        # so fallback to naive if per channel or per token
         else:
-            # Fallback for channelwise case, where we use unfused DQ
-            # due to limitations with scaled_mm
-
-            # Symmetric quantized GEMM by definition computes the following:
-            #   C = (s_x * X) (s_w * W) + bias
-            # This is equivalent to dequantizing the weights and activations
-            # before applying a GEMM.
-            #
-            # In order to compute quantized operands, a quantized kernel
-            # will rewrite the above like so:
-            #   C = s_w * s_x * (X * W) + bias
-            #
-            # For the scaled_mm fallback case, we break this down, since it
-            # does not support s_w being a vector.
-
-            # GEMM
-            # This computes C = (X * W).
-            # Output in fp32 to allow subsequent ops to happen in-place
-            output = torch._scaled_mm(qinput,
-                                      weight,
-                                      scale_a=TORCH_DEVICE_IDENTITY,
-                                      scale_b=TORCH_DEVICE_IDENTITY,
-                                      out_dtype=torch.float32)
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-            # Unpad (undo num_token_padding)
-            output = torch.narrow(output, 0, 0, input_2d.shape[0])
-            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
-
-            # DQ
-            # C = sw * sx * (X * W) + bias
-            output = output * x_scale * weight_scale.t()
-            if bias is not None:
-                output = output + bias
-            return output.to(dtype=input.dtype).view(*output_shape)
+            # Maybe apply padding to output, see comment in __init__
+            qinput, x_scale = ops.scaled_fp8_quant(
+                input_2d,
+                input_scale,
+                num_token_padding=self.output_padding,
+                use_per_token_if_dynamic=use_per_token_if_dynamic)
+
+            per_tensor_weights = (weight_scale.numel() == 1)
+            per_tensor_activations = (x_scale.numel() == 1)
+
+            if per_tensor_weights and per_tensor_activations:
+                # Fused GEMM_DQ
+                output = torch._scaled_mm(qinput,
+                                          weight,
+                                          out_dtype=input.dtype,
+                                          scale_a=x_scale,
+                                          scale_b=weight_scale,
+                                          bias=bias)
+                # A fix for discrepancy in scaled_mm which returns tuple
+                # for torch < 2.5 and a single value in torch >= 2.5
+                if type(output) is tuple and len(output) == 2:
+                    output = output[0]
+
+                return torch.narrow(output, 0, 0,
+                                    input_2d.shape[0]).view(*output_shape)
+
+            elif (use_per_token_if_dynamic and not per_tensor_weights
+                  and not per_tensor_activations
+                  and USE_ROWWISE_TORCH_SCALED_MM):
+                # For now validated on ROCm platform
+                # fp8 rowwise scaling in torch._scaled_mm is introduced in
+                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
+                # and ROCm 6.3, which only exists in torch 2.7 and above.
+                # For CUDA platform please validate if the
+                # torch._scaled_mm support rowwise scaled GEMM
+                # Fused GEMM_DQ Rowwise GEMM
+                output = torch._scaled_mm(qinput,
+                                          weight,
+                                          out_dtype=input.dtype,
+                                          scale_a=x_scale,
+                                          scale_b=weight_scale.t(),
+                                          bias=bias)
+
+                output = torch.narrow(output, 0, 0, input_2d.shape[0])
+                output = output.view(*output_shape)
+                return output
+
+            else:
+                # Fallback for channelwise case, where we use unfused DQ
+                # due to limitations with scaled_mm
+
+                # Symmetric quantized GEMM by definition computes the following:
+                #   C = (s_x * X) (s_w * W) + bias
+                # This is equivalent to dequantizing the weights and activations
+                # before applying a GEMM.
+                #
+                # In order to compute quantized operands, a quantized kernel
+                # will rewrite the above like so:
+                #   C = s_w * s_x * (X * W) + bias
+                #
+                # For the scaled_mm fallback case, we break this down, since it
+                # does not support s_w being a vector.
+
+                # GEMM
+                # This computes C = (X * W).
+                # Output in fp32 to allow subsequent ops to happen in-place
+                output = torch._scaled_mm(qinput,
+                                          weight,
+                                          scale_a=TORCH_DEVICE_IDENTITY,
+                                          scale_b=TORCH_DEVICE_IDENTITY,
+                                          out_dtype=torch.float32)
+                # A fix for discrepancy in scaled_mm which returns tuple
+                # for torch < 2.5 and a single value in torch >= 2.5
+                if type(output) is tuple and len(output) == 2:
+                    output = output[0]
+                # Unpad (undo num_token_padding)
+                output = torch.narrow(output, 0, 0, input_2d.shape[0])
+                x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
+
+                # DQ
+                # C = sw * sx * (X * W) + bias
+                output = output * x_scale * weight_scale.t()
+                if bias is not None:
+                    output = output + bias
+                return output.to(dtype=input.dtype).view(*output_shape)


 def normalize_e4m3fn_to_e4m3fnuz(
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 0b55854de94a..886295ee895c 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -219,7 +219,7 @@
     CompressedTensorsW8A8Fp8)
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
+    Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import (
@@ -633,6 +633,7 @@ def __init__(
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
         self.vllm_flash_attn_version = get_flash_attn_version()
+        self.fp8_linear_generic = Fp8LinearGenericOp()

         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
@@ -646,7 +647,7 @@ def __init__(
     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_UV_O):
-                output_parallel = apply_fp8_linear_generic(
+                output_parallel = self.fp8_linear_generic.apply(
                     x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
                     self.reqaunt_input_group_shape,
                     self.reqaunt_weight_group_shape)
@@ -666,7 +667,7 @@ def _v_up_proj_and_o_proj(self, x):
     def _q_proj_and_k_up_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
             if is_fp8(self.W_Q_UK):
-                return apply_fp8_linear_generic(
+                return self.fp8_linear_generic.apply(
                     x, self.W_Q_UK, self.W_Q_UK_scales,
                     self.reqaunt_input_group_shape,
                     self.reqaunt_weight_group_shape).view(

From cfb086c01e3eb6fad80de673c1e51e74986fc7c2 Mon Sep 17 00:00:00 2001
From: luka
Date: Fri, 7 Mar 2025 01:48:58 +0000
Subject: [PATCH 2/3] Fix test_fusion test

Signed-off-by: luka

---
 tests/compile/test_fusion.py | 20 +++++++------------
 1 file changed, 7 insertions(+), 13 deletions(-)

diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py
index 89abc001764b..aaf027781090 100644
--- a/tests/compile/test_fusion.py
+++ b/tests/compile/test_fusion.py
@@ -13,7 +13,7 @@
 from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    CUTLASS_FP8_SUPPORTED, apply_fp8_linear, maybe_create_device_identity)
+    CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)

 from .backend import TestBackend

@@ -34,26 +34,20 @@ def __init__(self, hidden_size: int, eps: float, static: bool,
             torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
             for _ in range(2)
         ]
+        self.fp8_linear = Fp8LinearOp(
+            cutlass_fp8_supported=cutlass_fp8_enabled,
+            use_per_token_if_dynamic=True)

     def forward(self, x):
         resid = torch.sqrt(x)
         y = self.norm[0](x)
-        x2 = apply_fp8_linear(y,
-                              self.w[0],
-                              self.wscale[0],
-                              self.scale[0],
-                              use_per_token_if_dynamic=True,
-                              cutlass_fp8_supported=self.cutlass_fp8_enabled)
+        x2 = self.fp8_linear.apply(y, self.w[0], self.wscale[0], self.scale[0])
         # make sure resid is used for replacement to work
         y2, resid = self.norm[1](x2, resid)
-        x3 = apply_fp8_linear(y2,
-                              self.w[1],
-                              self.wscale[1],
-                              self.scale[1],
-                              use_per_token_if_dynamic=True,
-                              cutlass_fp8_supported=self.cutlass_fp8_enabled)
+        x3 = self.fp8_linear.apply(y2, self.w[1], self.wscale[1],
+                                   self.scale[1])
         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3

From ede98ce16240d0f9ab151497b61d4fc5aa222f85 Mon Sep 17 00:00:00 2001
From: luka
Date: Fri, 7 Mar 2025 02:27:59 +0000
Subject: [PATCH 3/3] Add TODOs with issue links

Signed-off-by: luka

---
 vllm/model_executor/layers/quantization/utils/fp8_utils.py  | 5 ++++-
 vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 1 +
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 9b43f51852fd..62569185ef47 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -33,7 +33,8 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
     return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz


-# TODO fix ROCm->Triton custom path
+# TODO fix ROCm->Triton custom path:
+#  https://github.com/vllm-project/vllm/issues/14397
 def apply_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -51,6 +52,7 @@ def apply_w8a8_block_fp8_linear(
         shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
                                       and weight.shape[1] % 128 == 0)
         if current_platform.is_rocm():
+            # TODO this is never used, as cutlass_block_fp8_supported is False
             scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
                              input_2d.shape[:-1])[::-1]
             scale_b_shape = (weight_scale.view(-1, 1)
@@ -107,6 +109,7 @@ def apply_w8a8_block_fp8_linear_fake(
 # `apply_fp8_linear`
 # NOTE(lucas): this is quite messy, we should think through this more formally
 # TODO(luka): unify this better
+#  https://github.com/vllm-project/vllm/issues/14397
 class Fp8LinearGenericOp:

     def __init__(
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index 02ffa890ad2a..9de8e453354c 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -122,6 +122,7 @@ def maybe_create_device_identity():


 # TODO(luka): follow similar pattern for marlin and block-fp8-linear
+#  https://github.com/vllm-project/vllm/issues/14397
 class Fp8LinearOp:
     """
     This class executes a FP8 linear layer using cutlass if supported and