Support loading for static quant weight fp8 act fp8 #730
auto_round/experimental/qmodules/base.py (new file)

@@ -0,0 +1,54 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Optional, Union

import torch

__all__ = ["QModuleBase"]


class QModuleBase(torch.nn.Module):
    """
    Base class used to describe the weight creation and forward pass
    of different quantization schemes supported by Auto-Round.
    The design is inspired by vLLM's CompressedTensorsScheme:
    https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
    """

    def __init__(self):
        super().__init__()

    @classmethod
    @abstractmethod
    def from_original(cls, config, original_layer: torch.nn.Module):
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """
        Get minimum device capability.
        """
        raise NotImplementedError

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module):
        """
        Called after weight loading is complete for any cleanup that
        needs to occur.
        """
        raise NotImplementedError
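Not part of the diff: a minimal sketch of how a concrete scheme subclasses QModuleBase. The class name and method bodies here are illustrative only; the real subclass added by this PR is WeightFP8ActFP8StaticQuantLinear in the next file.

import torch

from auto_round.experimental.qmodules.base import QModuleBase


class MySchemeLinear(QModuleBase):
    """Illustrative subclass; a real scheme would also register weights and scales."""

    @classmethod
    def from_original(cls, config, original_layer: torch.nn.Module):
        # Build the quantized module from an existing torch.nn.Linear.
        return cls()

    @classmethod
    def get_min_capability(cls) -> int:
        # Minimum device capability required by this scheme's kernels.
        return 0

    def process_weights_after_loading(self, layer: torch.nn.Module):
        # Post-load cleanup (repacking, dtype casts); nothing to do in this sketch.
        pass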
auto_round/experimental/qmodules/fp8_static.py (new file)

@@ -0,0 +1,119 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import Optional, Union

import torch

from auto_round.experimental.qmodules.base import QModuleBase
from auto_round.utils import logger

__all__ = ["WeightFP8ActFP8StaticQuantLinear"]


def _quant_tensor_to_fp8_with_scale(tensor: torch.Tensor, scale: torch.Tensor):
    FULL_RANGE = torch.finfo(torch.float8_e4m3fn).max
    qtensor = tensor / scale
    clipped_qtensor = torch.clamp(qtensor, -FULL_RANGE, FULL_RANGE)
    clipped_qtensor_fp8 = clipped_qtensor.to(torch.float8_e4m3fn)
    return scale, clipped_qtensor_fp8


class WeightFP8ActFP8StaticQuantLinear(QModuleBase):
    hp_dtype = torch.bfloat16
    fp8_dtype = torch.float8_e4m3fn

    def __init__(
        self,
        in_features,
        out_features,
        weight: Optional[torch.Tensor] = None,
        weight_scale: Optional[torch.Tensor] = None,
        bias: Union[torch.Tensor, bool, None] = None,
        input_scale: Optional[torch.Tensor] = None,
        dtype=torch.bfloat16,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        init_weight = torch.zeros((out_features, in_features), dtype=dtype) if weight is None else weight
        self.weight = torch.nn.Parameter(init_weight, requires_grad=False)
        self.dtype = dtype
        if bias is not None:
            if isinstance(bias, bool):
                bias = torch.zeros((out_features,), dtype=dtype)
            self.bias = torch.nn.Parameter(bias, requires_grad=False)
        else:
            self.register_parameter("bias", None)
        init_weight_scale = torch.empty((out_features), dtype=dtype) if weight_scale is None else weight_scale
        self.register_buffer("weight_scale", init_weight_scale.to(dtype))

        init_input_scale = torch.zeros((1), dtype=dtype) if input_scale is None else input_scale
        self.register_buffer("input_scale", init_input_scale.to(dtype))
        self.pre_dequantized = False

    @classmethod
    def get_min_capability(cls) -> int:
        """
        Get minimum device capability.
        """
        # TODO: correct that config once we add fp8 op support.
        logger.warning_once("FP8 ops are not yet supported. Using capability 0.")
        return 0

    def process_weights_after_loading(self, layer: torch.nn.Module):
        pass

    @classmethod
    def from_original(cls, config, original_layer):
        """
        Create a WeightFP8ActFP8StaticQuantLinear layer from an original linear layer.
        """
        device = original_layer.weight.device
        with torch.device(device):
            qdq_linear = cls(
                in_features=original_layer.in_features,
                out_features=original_layer.out_features,
                bias=original_layer.bias,
            )
        return qdq_linear

    def dequant_weight_online(self):
        if self.pre_dequantized:
            return self.weight
        fp8_weight = self.weight
        qdq_weight = fp8_weight.to(self.dtype) * self.weight_scale.unsqueeze(1)
        return qdq_weight

    def pre_dequantize(self):
        if self.pre_dequantized:
            return
        dequant_weight = self.dequant_weight_online()
        del self.weight
        del self.weight_scale
        self.weight = torch.nn.Parameter(dequant_weight, requires_grad=False)
        self.pre_dequantized = True

    def qdq_input(self, bf16_input: torch.Tensor):
        input_scale, input_fp8 = _quant_tensor_to_fp8_with_scale(bf16_input, self.input_scale.data)
        qdq_input_bf16 = input_fp8.to(self.dtype) * input_scale
        return qdq_input_bf16

    @torch.no_grad()
    def forward(self, bf16_input: torch.Tensor) -> torch.Tensor:
        qdq_input = self.qdq_input(bf16_input)
        qdq_weight = self.dequant_weight_online()
        out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias)
        return out
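Not part of the diff: a minimal usage sketch of the new module. It assumes the checkpoint supplies per-channel weight scales and a per-tensor input scale; here the FP8 weight and the scale values are faked for illustration, and from_original currently ignores its config argument.

import torch

from auto_round.experimental.qmodules.fp8_static import WeightFP8ActFP8StaticQuantLinear

# Replace one bfloat16 nn.Linear with the QDQ module.
linear = torch.nn.Linear(128, 256, bias=True).to(torch.bfloat16)
qlinear = WeightFP8ActFP8StaticQuantLinear.from_original(config=None, original_layer=linear)

# Checkpoint loading would normally populate these tensors.
qlinear.weight.data = linear.weight.to(torch.float8_e4m3fn)        # (out_features, in_features), FP8
qlinear.weight_scale.data = torch.ones(256, dtype=torch.bfloat16)  # per-channel weight scales
qlinear.input_scale.data = torch.ones(1, dtype=torch.bfloat16)     # per-tensor activation scale

x = torch.randn(4, 128, dtype=torch.bfloat16)
y = qlinear(x)  # QDQ path: fake-quantize the input, dequantize the weight, then F.linear
print(y.shape)  # torch.Size([4, 256])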
Changes to the inference backend module (BackendInfos registry and backend dispatch):

@@ -19,7 +19,7 @@
from transformers.utils.versions import require_version

import auto_round_extension.cuda.gptqmodel_marlin
from auto_round.utils import get_library_version, logger
from auto_round.utils import get_library_version, is_weight_fp8_activation_static_fp8, logger

BackendInfos = {}


@@ -172,6 +172,22 @@ def feature_multiply_checker_group_size(
    requirements=["auto-round>=0.5.1"],
)

# FP8 static quant
# Weight: FP8, per-channel, may be extended to per-tensor in future
# Activation: FP8, per-tensor

BackendInfos["auto_round:torch_fp8_static"] = BackendInfo(
    device=["xpu", "cuda", "cpu"],
    packing_format="",
    sym=[True],
    dtype=["float32", "float16", "bfloat16"],
    bits=[8],
    priority=0,
    feature_checks=[],
Review comment: group_size checker is also needed; as mentioned in the comment, only per-channel is supported for now.
    alias=["auto_round", "torch"],
    requirements=["auto-round>=0.6.1.dev0"],
Review comment (on the requirements line): "> 0.6.0"
)

BackendInfos["auto_round:tritonv2_zp"] = BackendInfo(
    device=["cuda", "xpu"],
    sym=[True],  ## asym has accuracys

@@ -413,7 +429,7 @@ def check_compatible(
    return True


def dynamic_import_inference_linear(backend, bits, group_size, sym):
def dynamic_import_inference_linear(backend, config):
    """Dynamically imports and returns the appropriate QuantLinear class based on the given backend.

    This function dynamically loads the correct `QuantLinear` class based on the backend and quantization

@@ -438,6 +454,13 @@ def dynamic_import_inference_linear(backend, bits, group_size, sym):
        ImportError:
            If required modules are missing for a backend (e.g., Intel Extension, GPTQ, auto_awq).
    """
    bits, group_size, sym = config["bits"], config["group_size"], config["sym"]

    if is_weight_fp8_activation_static_fp8(config):
        from auto_round.experimental.qmodules.fp8_static import WeightFP8ActFP8StaticQuantLinear

        return WeightFP8ActFP8StaticQuantLinear

    if "qbits" in backend:
        try:
            from intel_extension_for_transformers import qbits  # pylint: disable=E0401

@@ -825,6 +848,7 @@ def build_pip_commands(gptq_req, other_reqs):

    # Instructional messages
    install_instructions = []

    for cmd in pip_cmds:
        if "intel-extension-for-pytorch" in cmd and target_device == "xpu":
            install_instructions.append(
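Not part of the diff: a small numeric sketch of the scheme described by the comment block above ("Weight: FP8, per-channel; Activation: FP8, per-tensor"), mirroring the clamp-and-cast logic in fp8_static.py. The scale derivation here is illustrative only; real checkpoints ship calibrated scales.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

out_features, in_features = 4, 8
weight = torch.randn(out_features, in_features, dtype=torch.bfloat16)
x = torch.randn(2, in_features, dtype=torch.bfloat16)

# Per-channel weight scales: one scale per output channel, shape (out_features,).
weight_scale = weight.abs().amax(dim=1).float() / FP8_MAX
weight_fp8 = (weight / weight_scale.unsqueeze(1)).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

# Per-tensor activation scale: a single calibrated scalar, shape (1,).
input_scale = x.abs().amax().reshape(1).float() / FP8_MAX
x_fp8 = (x / input_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

# QDQ reference matmul, matching WeightFP8ActFP8StaticQuantLinear.forward.
y = torch.nn.functional.linear(
    x_fp8.to(torch.bfloat16) * input_scale,
    weight_fp8.to(torch.bfloat16) * weight_scale.unsqueeze(1),
)
print(y.shape)  # torch.Size([2, 4])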
Changes to the HF model conversion module (get_layer_config, _create_quant_layer, convert_hf_model):

@@ -27,6 +27,7 @@
    find_backend,
    get_highest_priority_backend,
    get_layer_backend,
    is_weight_fp8_activation_static_fp8,
    process_requirement,
)
from auto_round.utils import (

@@ -61,7 +62,7 @@ def skip_not_convert_modules(model, quantization_config, layer_names, layer_conf
    try:  # transformers new api
        modules_to_not_convert = get_modules_to_not_convert(model, modules_to_not_convert, add_default_skips=True)
    except:
        modules_to_not_convert = get_modules_to_not_convert(model, modules_to_not_convert)
        modules_to_not_convert = _get_modules_to_not_convert(model, modules_to_not_convert)
Review comment: why change this?
    if modules_to_not_convert:
        for layer_name in layer_names:
            if any([re.search(re.compile(n), layer_name) for n in modules_to_not_convert]):

@@ -219,6 +220,7 @@ def get_layer_config(model, quantization_config):
        - group_size (int): Group size for weight quantization.
        - data_type (str, optional): Data type for quantization (default: "int").
        - sym (bool): Whether to use symmetric quantization.
        - act_dynamic (bool, optional): Whether to use dynamic activation quantization (default: False).
        - quant_block_list (list, optional): Predefined list of blocks to quantize.
        - to_quant_block_names (list or str, optional): Blocks to quantize (if quant_block_list is None).
        - extra_config (dict, optional): Per-layer overrides for quantization settings.

@@ -231,13 +233,14 @@
        - "group_size" (int): Group size for quantization.
        - "data_type" (str): Data type used for quantization.
        - "sym" (bool): Whether symmetric quantization is applied.
        - "act_dynamic" (bool): Whether dynamic activation quantization is used.
        - "clip" (bool): Whether weight clipping is enabled.
    """
    bits = quantization_config.bits
    group_size = quantization_config.group_size
    data_type = getattr(quantization_config, "data_type", "int")  # Default to "int" if not specified
    sym = quantization_config.sym

    act_dynamic = getattr(quantization_config, "act_dynamic", False)
Review comment: if you add this one, I think you need to introduce act_bits, act_group_size, xxx too.
    # Determine the quantization block list
    quant_block_list = getattr(quantization_config, "quant_block_list", None)
    if quant_block_list is None:

@@ -290,11 +293,11 @@
            "group_size": extra_config.get(layer_name, {}).get("group_size", group_size),
            "data_type": extra_config.get(layer_name, {}).get("data_type", data_type),
            "sym": extra_config.get(layer_name, {}).get("sym", sym),
            "act_dynamic": extra_config.get(layer_name, {}).get("act_dynamic", act_dynamic),
            "clip": extra_config.get(layer_name, {}).get("clip", False),
        }
        for layer_name in layer_names
    }

    return layer_configs


@@ -415,7 +418,7 @@

def _create_quant_layer(layer, layer_backend, config, in_features, out_features):
    """Creates a quantized layer using the appropriate class."""
    QuantLinear = dynamic_import_inference_linear(layer_backend, config["bits"], config["group_size"], config["sym"])
    QuantLinear = dynamic_import_inference_linear(layer_backend, config)
    bias = layer.bias is not None

    # Special handling for AWQ layers

@@ -437,6 +440,8 @@ def _create_quant_layer(layer, layer_backend, config, in_features, out_features)
            out_features=out_features,
            bias=bias,
        )
    elif is_weight_fp8_activation_static_fp8(config):
        return QuantLinear.from_original(config, layer)
    # Default quantized layer creation
    try:
        return QuantLinear(

@@ -561,7 +566,6 @@ def convert_hf_model(model: nn.Module, target_device="cpu"):
        backend = quantization_config.backend
    else:
        backend = "auto"

    ##target_backend could be None
    _, backend = parse_target_device_and_backend(backend)

@@ -588,7 +592,6 @@ def convert_hf_model(model: nn.Module, target_device="cpu"):
        backend = backend[len("auto_round:") :]

    used_backends = _replace_by_quant_layers(model, layer_configs, backend, target_device, orig_backend)

    if backend == "auto" or backend == "":
        best_backend = get_highest_priority_backend(
            quantization_config.bits,
Review comment: Add checkers to differentiate between w8a8-FP8 dynamic, w8a8-int, w8a4, etc.
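Not part of the diff: a hedged sketch of the kind of scheme checkers the last review comment asks for. The field names (bits, act_bits, data_type, act_dynamic) follow the layer-config dictionaries built in get_layer_config and the reviewer's suggestion to add act_bits; the exact keys, and the real implementation of is_weight_fp8_activation_static_fp8 in auto_round.utils, may differ.

def _is_fp8_weight(config: dict) -> bool:
    # Weight tensor stored as FP8 (e4m3) rather than integer codes.
    return config.get("data_type", "int").startswith("fp8") and config.get("bits", 16) == 8


def is_weight_fp8_activation_static_fp8(config: dict) -> bool:
    # w8a8-FP8 static: FP8 weights plus FP8 activations with a calibrated (static) scale.
    return _is_fp8_weight(config) and not config.get("act_dynamic", True)


def is_weight_fp8_activation_dynamic_fp8(config: dict) -> bool:
    # w8a8-FP8 dynamic: FP8 weights, activation scales computed at runtime.
    return _is_fp8_weight(config) and config.get("act_dynamic", True)


def is_w8a8_int(config: dict) -> bool:
    # w8a8-int: 8-bit integer weights and 8-bit activations.
    return (
        config.get("data_type", "int") == "int"
        and config.get("bits", 16) == 8
        and config.get("act_bits", 16) == 8
    )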