diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index cc8a526b..49e3984a 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -18,6 +18,7 @@
 import sys
 import time
 import traceback
+from enum import Enum
 from typing import Any, Callable, Union

 import accelerate
@@ -87,6 +88,12 @@
 from auto_round.wrapper import WrapperLinear, WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block

+class AutoRoundFormat(str, Enum):
+    # Weight: FP8, per-channel; may be extended to per-tensor in the future
+    # Activation: FP8, per-tensor
+    TORCH_FP8_STATIC = "torch_fp8_static"
+
+
 class AutoRound(object):
     """Automatic weight rounding (Signed Gradient Descent) for LLM quantization
@@ -678,8 +685,8 @@ def _parse_format_to_list(self, format: str) -> list:
                 format = "auto_round:auto_awq"
         elif is_nv_fp(self.data_type) or is_mx_fp(self.data_type):
             format = f"auto_round:{self.data_type}"
-        elif is_wfp8afp8(self):  # staic wfp8afp8
-            format = "auto_round:fp8"
+        elif is_static_wfp8afp8(self):  # static wfp8afp8
+            format = f"auto_round:{AutoRoundFormat.TORCH_FP8_STATIC.value}"
         elif self.data_type == "fp" and self.bits == 8 and self.act_bits >= 16:  # woq fp8
             format = "auto_round:fp8"
         elif self.act_bits < 16:
@@ -755,10 +762,10 @@ def _check_supported_format(self, format: str) -> bool:
                 )
                 format = "fake"
             else:
-                if not (format == "auto_round" or format == "auto_round:fp8"):
+                if not (format == "auto_round" or format == f"auto_round:{AutoRoundFormat.TORCH_FP8_STATIC.value}"):
                     logger.warning(
                         f"Currently only support to export auto_round or fake format for static W{self.bits}AFP8 model,"
-                        " change format to auto_round"
+                        f" change format {format} to auto_round"
                     )
                     format = "auto_round"
         if self.act_group_size != 0 and not self.act_dynamic and format == "auto_round:fp8":
diff --git a/auto_round/experimental/qmodules/base.py b/auto_round/experimental/qmodules/base.py
new file mode 100644
index 00000000..8b7a9c13
--- /dev/null
+++ b/auto_round/experimental/qmodules/base.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import Optional, Union
+
+import torch
+
+__all__ = ["QModuleBase"]
+
+
+class QModuleBase(torch.nn.Module):
+    """
+    Base class used to describe the weight creation and forward pass
+    of different quantization schemes supported by Auto-Round.
+    The design is inspired by vLLM's CompressedTensorsScheme:
+    https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
+
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    @classmethod
+    @abstractmethod
+    def from_original(cls, config, original_layer: torch.nn.Module):
+        raise NotImplementedError
+
+    @classmethod
+    @abstractmethod
+    def get_min_capability(cls) -> int:
+        """
+        Get minimum device capability.
+ """ + raise NotImplementedError + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module): + """ + Called after weight loading is complete for any cleanup that + needs to occur. + """ + raise NotImplementedError diff --git a/auto_round/experimental/qmodules/fp8_static.py b/auto_round/experimental/qmodules/fp8_static.py new file mode 100644 index 00000000..a6798f53 --- /dev/null +++ b/auto_round/experimental/qmodules/fp8_static.py @@ -0,0 +1,119 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +from typing import Optional, Union + +import torch + +from auto_round.experimental.qmodules.base import QModuleBase +from auto_round.utils import logger + +__all__ = ["WeightFP8ActFP8StaticQuantLinear"] + + +def _quant_tensor_to_fp8_with_scale(tensor: torch.Tensor, scale: torch.Tensor): + FULL_RANGE = torch.finfo(torch.float8_e4m3fn).max + qtensor = tensor / scale + clipped_qtensor = torch.clamp(qtensor, -FULL_RANGE, FULL_RANGE) + clipped_qtensor_fp8 = clipped_qtensor.to(torch.float8_e4m3fn) + return scale, clipped_qtensor_fp8 + + +class WeightFP8ActFP8StaticQuantLinear(QModuleBase): + hp_dtype = torch.bfloat16 + fp8_dtype = torch.float8_e4m3fn + + def __init__( + self, + in_features, + out_features, + weight: Optional[torch.Tensor] = None, + weight_scale: Optional[torch.Tensor] = None, + bias: Union[torch.Tensor, bool, None] = None, + input_scale: Optional[torch.Tensor] = None, + dtype=torch.bfloat16, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + init_weight = torch.zeros((out_features, in_features), dtype=dtype) if weight is None else weight + self.weight = torch.nn.Parameter(init_weight, requires_grad=False) + self.dtype = dtype + if bias is not None: + if isinstance(bias, bool): + bias = torch.zeros((out_features,), dtype=dtype) + self.bias = torch.nn.Parameter(bias, requires_grad=False) + else: + self.register_parameter("bias", None) + init_weight_scale = torch.empty((out_features), dtype=dtype) if weight_scale is None else weight_scale + self.register_buffer("weight_scale", init_weight_scale.to(dtype)) + + init_input_scale = torch.zeros((1), dtype=dtype) if input_scale is None else input_scale + self.register_buffer("input_scale", init_input_scale.to(dtype)) + self.pre_dequantized = False + + @classmethod + def get_min_capability(cls) -> int: + """ + Get minimum device capability. + """ + # TODO: correct that config once we add fp8 op support. + logger.warning_once("FP8 ops are not yet supported. Using capability 0.") + return 0 + + def process_weights_after_loading(self, layer: torch.nn.Module): + pass + + @classmethod + def from_original(cls, config, original_layer): + """ + Create an WeightFP8ActFP8StaticQuantLinear layer from an original linear layer. 
+ """ + device = original_layer.weight.device + with torch.device(device): + qdq_linear = cls( + in_features=original_layer.in_features, + out_features=original_layer.out_features, + bias=original_layer.bias, + ) + return qdq_linear + + def dequant_weight_online(self): + if self.pre_dequantized: + return self.weight + fp8_weight = self.weight + qdq_weight = fp8_weight.to(self.dtype) * self.weight_scale.unsqueeze(1) + return qdq_weight + + def pre_dequantize(self): + if self.pre_dequantized: + return + dequant_weight = self.dequant_weight_online() + del self.weight + del self.weight_scale + self.weight = torch.nn.Parameter(dequant_weight, requires_grad=False) + self.pre_dequantized = True + + def qdq_input(self, bf16_input: torch.Tensor): + input_scale, input_fp8 = _quant_tensor_to_fp8_with_scale(bf16_input, self.input_scale.data) + qdq_input_bf16 = input_fp8.to(self.dtype) * input_scale + return qdq_input_bf16 + + @torch.no_grad() + def forward(self, bf16_input: torch.Tensor) -> torch.Tensor: + qdq_input = self.qdq_input(bf16_input) + qdq_weight = self.dequant_weight_online() + out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias) + return out diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py index f56d2fbf..ea7149b5 100644 --- a/auto_round/export/export_to_autoround/export.py +++ b/auto_round/export/export_to_autoround/export.py @@ -273,9 +273,14 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex from auto_round.export.export_to_autoround.export_to_fp8 import save_quantized_as_autoround return save_quantized_as_autoround(output_dir, inplace=inplace, backend="auto_round", **kwargs) + from auto_round.autoround import AutoRoundFormat ##if using sym, we change to gptq sym kernel to avoid compiling from auto_round source - if (kwargs.get("sym") is None or kwargs.get("sym")) and ("gptq" not in backend and "awq" not in backend): + if ( + (kwargs.get("sym") is None or kwargs.get("sym")) + and ("gptq" not in backend and "awq" not in backend) + and (AutoRoundFormat.TORCH_FP8_STATIC.value not in backend) + ): backend = backend.replace("auto_round", "auto_round:auto_gptq") model = kwargs["model"] diff --git a/auto_round/inference/backend.py b/auto_round/inference/backend.py index 1569e800..00f0ce64 100644 --- a/auto_round/inference/backend.py +++ b/auto_round/inference/backend.py @@ -19,7 +19,7 @@ from transformers.utils.versions import require_version import auto_round_extension.cuda.gptqmodel_marlin -from auto_round.utils import get_library_version, logger +from auto_round.utils import get_library_version, is_weight_fp8_activation_static_fp8, logger BackendInfos = {} @@ -172,6 +172,22 @@ def feature_multiply_checker_group_size( requirements=["auto-round>=0.5.1"], ) +# FP8 static quant +# Weight: FP8, per-channel, may be extended to per-tensor in future +# Activation: FP8, per-tensor + +BackendInfos["auto_round:torch_fp8_static"] = BackendInfo( + device=["xpu", "cuda", "cpu"], + packing_format="", + sym=[True], + dtype=["float32", "float16", "bfloat16"], + bits=[8], + priority=0, + feature_checks=[], + alias=["auto_round", "torch"], + requirements=["auto-round>=0.6.1.dev0"], +) + BackendInfos["auto_round:tritonv2_zp"] = BackendInfo( device=["cuda", "xpu"], sym=[True], ## asym has accuracys @@ -413,7 +429,7 @@ def check_compatible( return True -def dynamic_import_inference_linear(backend, bits, group_size, sym): +def dynamic_import_inference_linear(backend, config): """Dynamically imports and 
    """Dynamically imports and returns the appropriate QuantLinear class based on the given backend.

    This function dynamically loads the correct `QuantLinear` class based on the backend and quantization
@@ -438,6 +454,13 @@ def dynamic_import_inference_linear(backend, bits, group_size, sym):
        ImportError: If required modules are missing for a backend (e.g., Intel Extension, GPTQ, auto_awq).
    """
+    bits, group_size, sym = config["bits"], config["group_size"], config["sym"]
+
+    if is_weight_fp8_activation_static_fp8(config):
+        from auto_round.experimental.qmodules.fp8_static import WeightFP8ActFP8StaticQuantLinear
+
+        return WeightFP8ActFP8StaticQuantLinear
+
    if "qbits" in backend:
        try:
            from intel_extension_for_transformers import qbits  # pylint: disable=E0401
@@ -825,6 +848,7 @@ def build_pip_commands(gptq_req, other_reqs):
    # Instructional messages
    install_instructions = []

+
    for cmd in pip_cmds:
        if "intel-extension-for-pytorch" in cmd and target_device == "xpu":
            install_instructions.append(
diff --git a/auto_round/inference/convert_model.py b/auto_round/inference/convert_model.py
index bd6dde83..df8b52c0 100644
--- a/auto_round/inference/convert_model.py
+++ b/auto_round/inference/convert_model.py
@@ -27,6 +27,7 @@
    find_backend,
    get_highest_priority_backend,
    get_layer_backend,
+    is_weight_fp8_activation_static_fp8,
    process_requirement,
 )
 from auto_round.utils import (
@@ -61,7 +62,7 @@ def skip_not_convert_modules(model, quantization_config, layer_names, layer_conf
    try:  # transformers new api
        modules_to_not_convert = get_modules_to_not_convert(model, modules_to_not_convert, add_default_skips=True)
    except:
-        modules_to_not_convert = get_modules_to_not_convert(model, modules_to_not_convert)
+        modules_to_not_convert = _get_modules_to_not_convert(model, modules_to_not_convert)
    if modules_to_not_convert:
        for layer_name in layer_names:
            if any([re.search(re.compile(n), layer_name) for n in modules_to_not_convert]):
@@ -219,6 +220,7 @@ def get_layer_config(model, quantization_config):
            - group_size (int): Group size for weight quantization.
            - data_type (str, optional): Data type for quantization (default: "int").
            - sym (bool): Whether to use symmetric quantization.
+            - act_dynamic (bool, optional): Whether to use dynamic activation quantization (default: False).
            - quant_block_list (list, optional): Predefined list of blocks to quantize.
            - to_quant_block_names (list or str, optional): Blocks to quantize (if quant_block_list is None).
            - extra_config (dict, optional): Per-layer overrides for quantization settings.
@@ -231,13 +233,14 @@ def get_layer_config(model, quantization_config):
            - "group_size" (int): Group size for quantization.
            - "data_type" (str): Data type used for quantization.
            - "sym" (bool): Whether symmetric quantization is applied.
+            - "act_dynamic" (bool): Whether dynamic activation quantization is used.
            - "clip" (bool): Whether weight clipping is enabled.
""" bits = quantization_config.bits group_size = quantization_config.group_size data_type = getattr(quantization_config, "data_type", "int") # Default to "int" if not specified sym = quantization_config.sym - + act_dynamic = getattr(quantization_config, "act_dynamic", False) # Determine the quantization block list quant_block_list = getattr(quantization_config, "quant_block_list", None) if quant_block_list is None: @@ -290,11 +293,11 @@ def get_layer_config(model, quantization_config): "group_size": extra_config.get(layer_name, {}).get("group_size", group_size), "data_type": extra_config.get(layer_name, {}).get("data_type", data_type), "sym": extra_config.get(layer_name, {}).get("sym", sym), + "act_dynamic": extra_config.get(layer_name, {}).get("act_dynamic", act_dynamic), "clip": extra_config.get(layer_name, {}).get("clip", False), } for layer_name in layer_names } - return layer_configs @@ -415,7 +418,7 @@ def _import_exllamav2_kernels(): def _create_quant_layer(layer, layer_backend, config, in_features, out_features): """Creates a quantized layer using the appropriate class.""" - QuantLinear = dynamic_import_inference_linear(layer_backend, config["bits"], config["group_size"], config["sym"]) + QuantLinear = dynamic_import_inference_linear(layer_backend, config) bias = layer.bias is not None # Special handling for AWQ layers @@ -437,6 +440,8 @@ def _create_quant_layer(layer, layer_backend, config, in_features, out_features) out_features=out_features, bias=bias, ) + elif is_weight_fp8_activation_static_fp8(config): + return QuantLinear.from_original(config, layer) # Default quantized layer creation try: return QuantLinear( @@ -561,7 +566,6 @@ def convert_hf_model(model: nn.Module, target_device="cpu"): backend = quantization_config.backend else: backend = "auto" - ##target_backend could be None _, backend = parse_target_device_and_backend(backend) @@ -588,7 +592,6 @@ def convert_hf_model(model: nn.Module, target_device="cpu"): backend = backend[len("auto_round:") :] used_backends = _replace_by_quant_layers(model, layer_configs, backend, target_device, orig_backend) - if backend == "auto" or backend == "": best_backend = get_highest_priority_backend( quantization_config.bits, diff --git a/auto_round/utils.py b/auto_round/utils.py index c3606aaf..21363688 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -106,9 +106,17 @@ def infer_bits_by_data_type(data_type: str): return None -@lru_cache(None) -def warning_once(self, msg: str): - self.warning(msg) +@lru_cache(maxsize=None) +def warning_once(self, msg, *args, **kwargs): + """ + Log a warning message only once per unique message/arguments combination. 
+
+    Args:
+        msg: The warning message format string
+        *args: Variable positional arguments for message formatting
+        **kwargs: Variable keyword arguments for message formatting and logging options
+    """
+    self.warning(msg, *args, **kwargs)


 class AutoRoundFormatter(logging.Formatter):
@@ -2519,6 +2527,21 @@ def is_nv_fp(backend):
    return BackendDataType.NV_FP in backend


+def _is_weight_fp8_activation_static_fp8(bit, group_size, sym, data_type, act_dynamic):
+    return bit == 8 and group_size == -1 and sym and data_type == "fp8" and not act_dynamic
+
+
+def is_weight_fp8_activation_static_fp8(config):
+    bits, group_size, sym, data_type, act_dynamic = (
+        config["bits"],
+        config["group_size"],
+        config["sym"],
+        config["data_type"],
+        config["act_dynamic"],
+    )
+    return _is_weight_fp8_activation_static_fp8(bits, group_size, sym, data_type, act_dynamic)
+
+
 def is_wfp8afp8(ar):
    if ("fp8" in ar.act_data_type or ("fp" in ar.act_data_type and ar.act_bits == 8)) and (
        "fp8" in ar.data_type or ("fp" in ar.data_type and ar.bits == 8)
diff --git a/test/test_cpu/test_export.py b/test/test_cpu/test_export.py
index 68942ac7..d648fd72 100644
--- a/test/test_cpu/test_export.py
+++ b/test/test_cpu/test_export.py
@@ -230,6 +230,33 @@ def test_static_afp8_export(self, static_kv_dtype):
            self.assertIn("model.decoder.layers.8.self_attn.k_proj.weight_scale", f.keys())
            self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.input_scale").shape, torch.Size([1]))
            self.assertEqual(f.get_tensor("model.decoder.layers.5.self_attn.v_proj.weight").dtype, torch.float8_e4m3fn)
+        if static_kv_dtype is None:
+            with torch.no_grad():
+                import transformers
+
+                model = transformers.AutoModelForCausalLM.from_pretrained(
+                    quantized_model_path,
+                    torch_dtype="auto",
+                    low_cpu_mem_usage=True,
+                    trust_remote_code=True,
+                )
+                model.eval()
+                assert (
+                    model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__
+                    == "WeightFP8ActFP8StaticQuantLinear"
+                ), f"Expected WeightFP8ActFP8StaticQuantLinear, got {model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__}"
+                tokenizer = transformers.AutoTokenizer.from_pretrained(quantized_model_path)
+                prompt = "AI is "
+                encode = tokenizer.encode(prompt, return_tensors="pt")
+                with torch.no_grad():
+                    output_tokens = model.generate(
+                        encode,
+                        max_length=10,
+                    )
+                output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
+                print(f"Prompt: {prompt}")
+                print(f"Output: {output}")
+                assert output is not None, "Output should not be None"
        if static_kv_dtype == "fp8":
            self.assertIn("model.decoder.layers.8.self_attn.k_scale", f.keys())
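
The sketch below is a minimal, standalone illustration of the quant-dequant math that the new WeightFP8ActFP8StaticQuantLinear module performs at inference time: per-channel FP8 weight dequantization, static per-tensor FP8 activation quant-dequant, then a plain bf16 linear. The shapes, scale values, and variable names are invented for this example and are not part of the patch.

# Illustrative sketch of the QDQ path added by this PR (not part of the diff).
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

out_features, in_features = 8, 16
# Stored FP8 weight plus a per-output-channel scale (as serialized by the export path).
weight_fp8 = torch.randn(out_features, in_features).to(torch.float8_e4m3fn)
weight_scale = torch.rand(out_features, dtype=torch.bfloat16) + 0.5
# Static per-tensor activation scale calibrated ahead of time.
input_scale = torch.tensor([0.02], dtype=torch.bfloat16)

x = torch.randn(2, in_features, dtype=torch.bfloat16)

# Activation: quantize with the static per-tensor scale, clamp to the FP8 range, dequantize.
x_q = torch.clamp(x / input_scale, -FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
x_qdq = x_q.to(torch.bfloat16) * input_scale

# Weight: dequantize online with the per-channel scale, one scale per output row.
w_qdq = weight_fp8.to(torch.bfloat16) * weight_scale.unsqueeze(1)

# Regular bf16 matmul on the dequantized operands, as in the module's forward().
out = torch.nn.functional.linear(x_qdq, w_qdq, bias=None)
print(out.shape)  # torch.Size([2, 8])

Keeping the weight in FP8 and dequantizing on the fly (dequant_weight_online) trades a little extra compute for roughly half the weight memory of pre_dequantize, which materializes the bf16 weight once and reuses it.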