Changes from all commits (32 commits)
bb94782  load w8a8 (yiliu30, Aug 11, 2025)
9bef826  refactor (yiliu30, Aug 12, 2025)
b30a126  add ut (yiliu30, Aug 12, 2025)
eaad3a6  remove example (yiliu30, Aug 12, 2025)
c411ca5  fix typo (yiliu30, Aug 12, 2025)
9802313  Merge branch 'main' into wfp8-afp8 (yiliu30, Aug 12, 2025)
6597d5c  Update auto_round/export/export_to_autoround/export_to_fp8_woq.py (yiliu30, Aug 13, 2025)
9b0f32f  Update export_to_fp8_woq.py (yiliu30, Aug 13, 2025)
c32daa6  Merge branch 'main' into wfp8-afp8 (yiliu30, Aug 13, 2025)
c136339  megre main (yiliu30, Aug 24, 2025)
5ebca24  update shape (yiliu30, Aug 24, 2025)
03cb217  refactor (yiliu30, Aug 26, 2025)
e7280f6  Merge branch 'main' into wfp8-afp8 (yiliu30, Aug 26, 2025)
66388e5  tmp add bk (yiliu30, Aug 26, 2025)
17ddd2d  refactor code (yiliu30, Aug 27, 2025)
808449d  refine code (yiliu30, Aug 27, 2025)
f74ed6f  fix device list (yiliu30, Aug 27, 2025)
632cf8a  fix (yiliu30, Aug 27, 2025)
5b8b29d  refactor code (yiliu30, Aug 27, 2025)
57b4c19  fix (yiliu30, Aug 27, 2025)
bdf5f3e  update (yiliu30, Aug 27, 2025)
ce3384f  fix ut (yiliu30, Aug 27, 2025)
7cea90e  Merge branch 'main' into wfp8-afp8 (yiliu30, Aug 28, 2025)
22d11de  correct (yiliu30, Aug 28, 2025)
9082613  clean (yiliu30, Aug 28, 2025)
6503355  Merge branch 'wfp8-afp8' of https://github.com/intel/auto-round into … (yiliu30, Aug 28, 2025)
b687633  Merge branch 'main' into wfp8-afp8 (yiliu30, Aug 28, 2025)
2202856  fix shape (yiliu30, Aug 28, 2025)
10f5753  Merge branch 'wfp8-afp8' of https://github.com/intel/auto-round into … (yiliu30, Aug 29, 2025)
cc42e47  merge with main (yiliu30, Aug 29, 2025)
d0b99a8  fix check (yiliu30, Aug 29, 2025)
31845d0  clean code (yiliu30, Aug 29, 2025)
15 changes: 11 additions & 4 deletions auto_round/autoround.py
@@ -18,6 +18,7 @@
import sys
import time
import traceback
from enum import Enum
from typing import Any, Callable, Union

import accelerate
@@ -87,6 +88,12 @@
from auto_round.wrapper import WrapperLinear, WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block


class AutoRoundFormat(str, Enum):
# Weight: FP8, per-channel, may be extended to per-tensor in future
# Activation: FP8, per-tensor
TORCH_FP8_STATIC = "torch_fp8_static"


class AutoRound(object):
"""Automatic weight rounding (Signed Gradient Descent) for LLM quantization

@@ -678,8 +685,8 @@ def _parse_format_to_list(self, format: str) -> list:
format = "auto_round:auto_awq"
elif is_nv_fp(self.data_type) or is_mx_fp(self.data_type):
format = f"auto_round:{self.data_type}"
elif is_wfp8afp8(self): # staic wfp8afp8
format = "auto_round:fp8"
elif is_static_wfp8afp8(self): # staic wfp8afp8
format = f"auto_round:{AutoRoundFormat.TORCH_FP8_STATIC.value}"
elif self.data_type == "fp" and self.bits == 8 and self.act_bits >= 16: # woq fp8
format = "auto_round:fp8"
elif self.act_bits < 16:
@@ -755,10 +762,10 @@ def _check_supported_format(self, format: str) -> bool:
)
format = "fake"
else:
if not (format == "auto_round" or format == "auto_round:fp8"):
if not (format == "auto_round" or format == f"auto_round:{AutoRoundFormat.TORCH_FP8_STATIC.value}"):
logger.warning(
f"Currently only support to export auto_round or fake format for static W{self.bits}AFP8 model,"
" change format to auto_round"
f" change format {format} to auto_round"
)
format = "auto_round"
if self.act_group_size != 0 and not self.act_dynamic and format == "auto_round:fp8":
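For orientation (not part of this diff), a minimal sketch of how a user might trigger the new static W8A8-FP8 path. The constructor arguments follow the public AutoRound signature; the model name and the specific values are illustrative assumptions:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "facebook/opt-125m"  # any small causal LM works for a smoke test
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)

autoround = AutoRound(
    model,
    tokenizer,
    bits=8,
    group_size=-1,        # per-channel weight scales
    data_type="fp",       # FP8 weights
    act_bits=8,
    act_data_type="fp",   # FP8 activations
    act_dynamic=False,    # static activation scales
)
autoround.quantize()
# With this config, _parse_format_to_list() maps "auto_round" to
# "auto_round:torch_fp8_static" (AutoRoundFormat.TORCH_FP8_STATIC).
autoround.save_quantized("./opt-125m-w8a8-fp8-static", format="auto_round")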
54 changes: 54 additions & 0 deletions auto_round/experimental/qmodules/base.py
@@ -0,0 +1,54 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Optional, Union

import torch

__all__ = ["QModuleBase"]


class QModuleBase(torch.nn.Module):
"""
Base class used to describe the weight creation and forward pass
of different quantization schemes supported by Auto-Round.
The design is inspired by vLLM's CompressedTensorsScheme:
https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py

"""

def __init__(self):
super().__init__()

@classmethod
@abstractmethod
def from_original(cls, config, original_layer: torch.nn.Module):
raise NotImplementedError

@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
"""
Get minimum device capability.
"""
raise NotImplementedError

@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module):
"""
Called after weight loading is complete for any cleanup that
needs to occur.
"""
raise NotImplementedError
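As an illustration of the contract QModuleBase defines, a minimal hypothetical subclass might look like the following; the class name and the no-op bodies are illustrative only:

import torch
from auto_round.experimental.qmodules.base import QModuleBase


class MyQuantLinear(QModuleBase):
    """Toy scheme: stores the original weight untouched."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=False)

    @classmethod
    def from_original(cls, config, original_layer: torch.nn.Module):
        # Build the quantized module from an existing nn.Linear.
        layer = cls(original_layer.in_features, original_layer.out_features)
        layer.linear.weight.data.copy_(original_layer.weight.data)
        return layer

    @classmethod
    def get_min_capability(cls) -> int:
        return 0  # no special device capability required for this toy scheme

    def process_weights_after_loading(self, layer: torch.nn.Module):
        pass  # nothing to clean up in this sketch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)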
119 changes: 119 additions & 0 deletions auto_round/experimental/qmodules/fp8_static.py
@@ -0,0 +1,119 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import Optional, Union

import torch

from auto_round.experimental.qmodules.base import QModuleBase
from auto_round.utils import logger

__all__ = ["WeightFP8ActFP8StaticQuantLinear"]


def _quant_tensor_to_fp8_with_scale(tensor: torch.Tensor, scale: torch.Tensor):
FULL_RANGE = torch.finfo(torch.float8_e4m3fn).max
qtensor = tensor / scale
clipped_qtensor = torch.clamp(qtensor, -FULL_RANGE, FULL_RANGE)
clipped_qtensor_fp8 = clipped_qtensor.to(torch.float8_e4m3fn)
return scale, clipped_qtensor_fp8


class WeightFP8ActFP8StaticQuantLinear(QModuleBase):
hp_dtype = torch.bfloat16
fp8_dtype = torch.float8_e4m3fn

def __init__(
self,
in_features,
out_features,
weight: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
bias: Union[torch.Tensor, bool, None] = None,
input_scale: Optional[torch.Tensor] = None,
dtype=torch.bfloat16,
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
init_weight = torch.zeros((out_features, in_features), dtype=dtype) if weight is None else weight
self.weight = torch.nn.Parameter(init_weight, requires_grad=False)
self.dtype = dtype
if bias is not None:
if isinstance(bias, bool):
bias = torch.zeros((out_features,), dtype=dtype)
self.bias = torch.nn.Parameter(bias, requires_grad=False)
else:
self.register_parameter("bias", None)
init_weight_scale = torch.empty((out_features), dtype=dtype) if weight_scale is None else weight_scale
self.register_buffer("weight_scale", init_weight_scale.to(dtype))

init_input_scale = torch.zeros((1), dtype=dtype) if input_scale is None else input_scale
self.register_buffer("input_scale", init_input_scale.to(dtype))
self.pre_dequantized = False

@classmethod
def get_min_capability(cls) -> int:
"""
Get minimum device capability.
"""
# TODO: correct that config once we add fp8 op support.
logger.warning_once("FP8 ops are not yet supported. Using capability 0.")
return 0

def process_weights_after_loading(self, layer: torch.nn.Module):
pass

@classmethod
def from_original(cls, config, original_layer):
"""
Create an WeightFP8ActFP8StaticQuantLinear layer from an original linear layer.
"""
device = original_layer.weight.device
with torch.device(device):
qdq_linear = cls(
in_features=original_layer.in_features,
out_features=original_layer.out_features,
bias=original_layer.bias,
)
return qdq_linear

def dequant_weight_online(self):
if self.pre_dequantized:
return self.weight
fp8_weight = self.weight
qdq_weight = fp8_weight.to(self.dtype) * self.weight_scale.unsqueeze(1)
return qdq_weight

def pre_dequantize(self):
if self.pre_dequantized:
return
dequant_weight = self.dequant_weight_online()
del self.weight
del self.weight_scale
self.weight = torch.nn.Parameter(dequant_weight, requires_grad=False)
self.pre_dequantized = True

def qdq_input(self, bf16_input: torch.Tensor):
input_scale, input_fp8 = _quant_tensor_to_fp8_with_scale(bf16_input, self.input_scale.data)
qdq_input_bf16 = input_fp8.to(self.dtype) * input_scale
return qdq_input_bf16

@torch.no_grad()
def forward(self, bf16_input: torch.Tensor) -> torch.Tensor:
qdq_input = self.qdq_input(bf16_input)
qdq_weight = self.dequant_weight_online()
out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias)
return out
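A brief usage sketch for the new module. The dummy scale values and the empty config dict are assumptions for illustration; from_original() ignores the config in the code above, and in a real checkpoint the FP8 weight and scales are loaded from disk:

import torch
from auto_round.experimental.qmodules.fp8_static import WeightFP8ActFP8StaticQuantLinear

config = {}  # unused by from_original() in this sketch
orig = torch.nn.Linear(64, 128, bias=True).to(torch.bfloat16)

qlinear = WeightFP8ActFP8StaticQuantLinear.from_original(config, orig)
# Fill the quantized parameters with dummy values for demonstration.
qlinear.weight.data = orig.weight.to(torch.float8_e4m3fn)
qlinear.weight_scale.data.fill_(1.0)   # per-output-channel weight scale
qlinear.input_scale.data.fill_(1.0)    # per-tensor static activation scale

x = torch.randn(2, 64, dtype=torch.bfloat16)
y = qlinear(x)        # QDQ forward: fake-quantize input, dequantize weight, then F.linear
print(y.shape)        # torch.Size([2, 128])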
7 changes: 6 additions & 1 deletion auto_round/export/export_to_autoround/export.py
@@ -273,9 +273,14 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
from auto_round.export.export_to_autoround.export_to_fp8 import save_quantized_as_autoround

return save_quantized_as_autoround(output_dir, inplace=inplace, backend="auto_round", **kwargs)
from auto_round.autoround import AutoRoundFormat

##if using sym, we change to gptq sym kernel to avoid compiling from auto_round source
if (kwargs.get("sym") is None or kwargs.get("sym")) and ("gptq" not in backend and "awq" not in backend):
if (
(kwargs.get("sym") is None or kwargs.get("sym"))
and ("gptq" not in backend and "awq" not in backend)
and (AutoRoundFormat.TORCH_FP8_STATIC.value not in backend)
):
backend = backend.replace("auto_round", "auto_round:auto_gptq")

model = kwargs["model"]
28 changes: 26 additions & 2 deletions auto_round/inference/backend.py
@@ -19,7 +19,7 @@
from transformers.utils.versions import require_version

import auto_round_extension.cuda.gptqmodel_marlin
from auto_round.utils import get_library_version, logger
from auto_round.utils import get_library_version, is_weight_fp8_activation_static_fp8, logger

BackendInfos = {}

@@ -172,6 +172,22 @@ def feature_multiply_checker_group_size(
requirements=["auto-round>=0.5.1"],
)

# FP8 static quant
# Weight: FP8, per-channel, may be extended to per-tensor in future
# Activation: FP8, per-tensor

BackendInfos["auto_round:torch_fp8_static"] = BackendInfo(
device=["xpu", "cuda", "cpu"],
packing_format="",
sym=[True],
dtype=["float32", "float16", "bfloat16"],
bits=[8],
priority=0,
[Review comment, Contributor] Add checkers to differentiate between w8a8-FP8 dynamic, w8a8-int, w8a4, etc.
feature_checks=[],
[Review comment, Contributor] group_size checker is also needed, as mentioned in the comment, only per-channel is supported for now
alias=["auto_round", "torch"],
requirements=["auto-round>=0.6.1.dev0"],
[Review comment, @wenhuach21, Aug 28, 2025] "> 0.6.0"
)

BackendInfos["auto_round:tritonv2_zp"] = BackendInfo(
device=["cuda", "xpu"],
sym=[True], ## asym has accuracys
@@ -413,7 +429,7 @@ def check_compatible(
return True


def dynamic_import_inference_linear(backend, bits, group_size, sym):
def dynamic_import_inference_linear(backend, config):
"""Dynamically imports and returns the appropriate QuantLinear class based on the given backend.

This function dynamically loads the correct `QuantLinear` class based on the backend and quantization
@@ -438,6 +454,13 @@ def dynamic_import_inference_linear(backend, bits, group_size, sym):
ImportError:
If required modules are missing for a backend (e.g., Intel Extension, GPTQ, auto_awq).
"""
bits, group_size, sym = config["bits"], config["group_size"], config["sym"]

if is_weight_fp8_activation_static_fp8(config):
from auto_round.experimental.qmodules.fp8_static import WeightFP8ActFP8StaticQuantLinear

return WeightFP8ActFP8StaticQuantLinear

if "qbits" in backend:
try:
from intel_extension_for_transformers import qbits # pylint: disable=E0401
@@ -825,6 +848,7 @@ def build_pip_commands(gptq_req, other_reqs):

# Instructional messages
install_instructions = []

for cmd in pip_cmds:
if "intel-extension-for-pytorch" in cmd and target_device == "xpu":
install_instructions.append(
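To show the new calling convention, a hedged sketch of invoking dynamic_import_inference_linear with a full layer-config dict; the exact keys inspected by is_weight_fp8_activation_static_fp8() live in auto_round.utils and are not shown in this diff, so treat the dict contents as assumptions:

from auto_round.inference.backend import dynamic_import_inference_linear

# Illustrative per-layer config for the static W8A8-FP8 scheme.
config = {
    "bits": 8,
    "group_size": -1,      # per-channel weight scales
    "sym": True,
    "data_type": "fp",
    "act_bits": 8,
    "act_dynamic": False,  # static activation scales
}

QuantLinear = dynamic_import_inference_linear("auto_round:torch_fp8_static", config)
# If the config satisfies the static-FP8 check, this returns
# WeightFP8ActFP8StaticQuantLinear; otherwise it falls through to the other kernels.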
15 changes: 9 additions & 6 deletions auto_round/inference/convert_model.py
@@ -27,6 +27,7 @@
find_backend,
get_highest_priority_backend,
get_layer_backend,
is_weight_fp8_activation_static_fp8,
process_requirement,
)
from auto_round.utils import (
@@ -61,7 +62,7 @@ def skip_not_convert_modules(model, quantization_config, layer_names, layer_conf
try: # transformers new api
modules_to_not_convert = get_modules_to_not_convert(model, modules_to_not_convert, add_default_skips=True)
except:
modules_to_not_convert = get_modules_to_not_convert(model, modules_to_not_convert)
modules_to_not_convert = _get_modules_to_not_convert(model, modules_to_not_convert)
[Review comment, Contributor] why change this
if modules_to_not_convert:
for layer_name in layer_names:
if any([re.search(re.compile(n), layer_name) for n in modules_to_not_convert]):
@@ -219,6 +220,7 @@ def get_layer_config(model, quantization_config):
- group_size (int): Group size for weight quantization.
- data_type (str, optional): Data type for quantization (default: "int").
- sym (bool): Whether to use symmetric quantization.
- act_dynamic (bool, optional): Whether to use dynamic activation quantization (default: False).
- quant_block_list (list, optional): Predefined list of blocks to quantize.
- to_quant_block_names (list or str, optional): Blocks to quantize (if quant_block_list is None).
- extra_config (dict, optional): Per-layer overrides for quantization settings.
@@ -231,13 +233,14 @@
- "group_size" (int): Group size for quantization.
- "data_type" (str): Data type used for quantization.
- "sym" (bool): Whether symmetric quantization is applied.
- "act_dynamic" (bool): Whether dynamic activation quantization is used.
- "clip" (bool): Whether weight clipping is enabled.
"""
bits = quantization_config.bits
group_size = quantization_config.group_size
data_type = getattr(quantization_config, "data_type", "int") # Default to "int" if not specified
sym = quantization_config.sym

act_dynamic = getattr(quantization_config, "act_dynamic", False)
[Review comment, Contributor] if you add this one, I think you need to introduce act_bits, act_group_size, xxx too
# Determine the quantization block list
quant_block_list = getattr(quantization_config, "quant_block_list", None)
if quant_block_list is None:
@@ -290,11 +293,11 @@
"group_size": extra_config.get(layer_name, {}).get("group_size", group_size),
"data_type": extra_config.get(layer_name, {}).get("data_type", data_type),
"sym": extra_config.get(layer_name, {}).get("sym", sym),
"act_dynamic": extra_config.get(layer_name, {}).get("act_dynamic", act_dynamic),
"clip": extra_config.get(layer_name, {}).get("clip", False),
}
for layer_name in layer_names
}

return layer_configs


@@ -415,7 +418,7 @@ def _import_exllamav2_kernels():

def _create_quant_layer(layer, layer_backend, config, in_features, out_features):
"""Creates a quantized layer using the appropriate class."""
QuantLinear = dynamic_import_inference_linear(layer_backend, config["bits"], config["group_size"], config["sym"])
QuantLinear = dynamic_import_inference_linear(layer_backend, config)
bias = layer.bias is not None

# Special handling for AWQ layers
@@ -437,6 +440,8 @@ def _create_quant_layer(layer, layer_backend, config, in_features, out_features)
out_features=out_features,
bias=bias,
)
elif is_weight_fp8_activation_static_fp8(config):
return QuantLinear.from_original(config, layer)
# Default quantized layer creation
try:
return QuantLinear(
@@ -561,7 +566,6 @@ def convert_hf_model(model: nn.Module, target_device="cpu"):
backend = quantization_config.backend
else:
backend = "auto"

##target_backend could be None
_, backend = parse_target_device_and_backend(backend)

@@ -588,7 +592,6 @@ def convert_hf_model(model: nn.Module, target_device="cpu"):
backend = backend[len("auto_round:") :]

used_backends = _replace_by_quant_layers(model, layer_configs, backend, target_device, orig_backend)

if backend == "auto" or backend == "":
best_backend = get_highest_priority_backend(
quantization_config.bits,
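For reference, a hypothetical per-layer entry as assembled by get_layer_config() after this change; the layer name and values are assumptions, and the keys mirror the comprehension in the hunk above:

# Hypothetical entry from the dict returned by get_layer_config().
layer_configs = {
    "model.layers.0.self_attn.q_proj": {
        "bits": 8,
        "group_size": -1,
        "data_type": "fp",
        "sym": True,
        "act_dynamic": False,  # newly propagated so the backend can pick the static-FP8 module
        "clip": False,
    },
}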