Commit d6426ad

migrate prototype/awq to configs
Summary: As titled

Test Plan:

```
// note: this fails on weights only load, but the failure happens
// after my changes, and already exists on main branch
pytest test/prototype/test_awq.py -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: a6dceb5a8847e331e4844a6b89f40078c2907b4d
ghstack-comment-id: 2706766649
Pull Request resolved: #1853
1 parent 6f11d8f commit d6426ad
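For context, the diff below turns `awq_uintx` from a function returning a module transform into a dataclass config, keeping the old name as an alias for backward compatibility. A minimal sketch of what that means at the call site; importing from `torchao.prototype.awq.api` directly, since this commit does not show the new class being re-exported from the package `__init__`:

```python
import torch

# AWQUIntXConfig is defined in the diffed file; the alias awq_uintx
# points at the same class ("# for bc" in the diff)
from torchao.prototype.awq.api import AWQUIntXConfig, awq_uintx

# old spelling still works: what used to be a function call is now
# a dataclass construction returning a config object
old_style = awq_uintx(quant_dtype=torch.uint4, group_size=64)
new_style = AWQUIntXConfig(quant_dtype=torch.uint4, group_size=64)

# both are plain AWQUIntXConfig instances with field-wise equality
assert isinstance(old_style, AWQUIntXConfig)
assert old_style == new_style
```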

File tree

2 files changed (+81 -64 lines)

awq_model.pth (215 KB): binary file not shown.

torchao/prototype/awq/api.py (+81 -64)

```diff
@@ -1,18 +1,28 @@
+import types
+from dataclasses import dataclass
+
 import torch
 
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import (
     TensorCoreTiledLayout,
     to_affine_quantized_intx,
 )
 from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout
 from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
 from torchao.quantization.granularity import PerGroup
-from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+from torchao.quantization.quant_api import (
+    _linear_extra_repr,
+    _replace_with_custom_fn_if_matches_filter,
+)
 from torchao.quantization.quant_primitives import (
     _DTYPE_TO_QVALUE_BOUNDS,
     MappingType,
     ZeroPointDomain,
 )
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 
 from .core import (
     AWQObservedLinear,
@@ -90,80 +100,87 @@ def _observed_linear_subclass_inserter(constructor):
         constructor: the function which applies quantization to the AWQObservedLinear layer
     """
 
-    def insert_subclass(observed_linear):
-        # creates the new linear layer using constructor
-        linear = torch.nn.Linear(
-            observed_linear.in_features,
-            observed_linear.out_features,
-            observed_linear.bias != None,
-            device=observed_linear.weight.device,
-            dtype=observed_linear.weight.dtype,
-        )
-        linear.weight = torch.nn.Parameter(
-            constructor(observed_linear), requires_grad=False
-        )
-        linear.bias = observed_linear.bias
-        return linear
-
-    return insert_subclass
-
 
-def awq_uintx(
-    quant_dtype: torch.dtype = torch.uint4,
-    group_size: int = 64,
-    use_hqq: bool = False,
-):
+@dataclass
+class AWQUIntXConfig(AOBaseConfig):
     """
-    Quantizes linear layers when passed into quantize_()
+    Configuration for quantizing linear layers when passed into quantize_()
 
     Args:
         quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8
         group_size: Quantization granularity. Use -1 for channel wise quantization
         weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used
     """
+
+    quant_dtype: torch.dtype = torch.uint4
+    group_size: int = 64
+    use_hqq: bool = False
+
+
+# for bc
+awq_uintx = AWQUIntXConfig
+
+
+@register_quantize_module_handler(AWQUIntXConfig)
+def _awq_uintx_transform(
+    module: torch.nn.Module,
+    config: AWQUIntXConfig,
+) -> torch.nn.Module:
+    quant_dtype = config.quant_dtype
+    group_size = config.group_size
+    use_hqq = config.use_hqq
+    observed_linear = module
+
     assert (
         quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8
     ), "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
 
-    def weight_quant_func(observed_linear):
-        equalization_scale = observed_linear.act_obs.calculate_qparams()
-        # AQT config
-        if quant_dtype == torch.uint4:
-            target_dtype = torch.int32
-            eps = 1e-6
-            preserve_zero = False
-            zero_point_dtype = torch.bfloat16
-            zero_point_domain = ZeroPointDomain.FLOAT
-            _layout = TensorCoreTiledLayout(inner_k_tiles=8)
-        else:
-            target_dtype = torch.uint8
-            eps = torch.finfo(torch.float32).eps
-            preserve_zero = True
-            zero_point_dtype = torch.int64
-            zero_point_domain = ZeroPointDomain.INT
-            _layout = UintxLayout(quant_dtype)
-
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, group_size)
-        quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0]
-        quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1]
-        qw = to_affine_quantized_intx(
-            observed_linear.weight * equalization_scale,
-            mapping_type,
-            block_size,
-            target_dtype,
-            quant_min,
-            quant_max,
-            eps,
-            zero_point_dtype=zero_point_dtype,
-            preserve_zero=preserve_zero,
-            zero_point_domain=zero_point_domain,
-            _layout=_layout,
-            use_hqq=use_hqq,
-        )
+    equalization_scale = observed_linear.act_obs.calculate_qparams()
+    # AQT config
+    if quant_dtype == torch.uint4:
+        target_dtype = torch.int32
+        eps = 1e-6
+        preserve_zero = False
+        zero_point_dtype = torch.bfloat16
+        zero_point_domain = ZeroPointDomain.FLOAT
+        _layout = TensorCoreTiledLayout(inner_k_tiles=8)
+    else:
+        target_dtype = torch.uint8
+        eps = torch.finfo(torch.float32).eps
+        preserve_zero = True
+        zero_point_dtype = torch.int64
+        zero_point_domain = ZeroPointDomain.INT
+        _layout = UintxLayout(quant_dtype)
 
-        return to_weight_tensor_with_linear_activation_scale_metadata(
-            qw, equalization_scale
-        )
+    mapping_type = MappingType.ASYMMETRIC
+    block_size = (1, group_size)
+    quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0]
+    quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1]
+    qw = to_affine_quantized_intx(
+        observed_linear.weight * equalization_scale,
+        mapping_type,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        zero_point_dtype=zero_point_dtype,
+        preserve_zero=preserve_zero,
+        zero_point_domain=zero_point_domain,
+        _layout=_layout,
+        use_hqq=use_hqq,
+    )
+
+    qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)
 
-    return _observed_linear_subclass_inserter(weight_quant_func)
+    linear = torch.nn.Linear(
+        observed_linear.in_features,
+        observed_linear.out_features,
+        observed_linear.bias != None,
+        device=observed_linear.weight.device,
+        dtype=observed_linear.weight.dtype,
+    )
+    linear.weight = torch.nn.Parameter(qw, requires_grad=False)
+    linear.extra_repr = types.MethodType(_linear_extra_repr, module)
+    linear.bias = observed_linear.bias
+    return linear
```
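After this change, the intended flow is: insert observers, run calibration data, then pass the config to `quantize_()` with a filter matching the observed linears, so dispatch lands on the registered `_awq_uintx_transform` handler. A rough end-to-end sketch, assuming `insert_awq_observer_` keeps its pre-existing signature; the toy model, sizes, and calibration data are illustrative, not from this commit:

```python
import torch
from torchao.quantization import quantize_
from torchao.prototype.awq import insert_awq_observer_
from torchao.prototype.awq.api import AWQUIntXConfig
from torchao.prototype.awq.core import AWQObservedLinear

# toy model; the uint4 path uses TensorCoreTiledLayout, so bfloat16 + CUDA
model = torch.nn.Sequential(torch.nn.Linear(128, 128)).to(torch.bfloat16).cuda()

# 1. swap eligible nn.Linear modules for AWQObservedLinear
#    (argument names assumed from the existing prototype API)
insert_awq_observer_(
    model,
    n_validation_examples=1,
    validation_sequence_len=128,
    quant_dtype=torch.uint4,
    group_size=64,
)

# 2. run calibration data through so the activation observers see statistics
model(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))

# 3. quantize_ dispatches on the config type to _awq_uintx_transform,
#    replacing each observed linear with a quantized nn.Linear
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
quantize_(
    model,
    AWQUIntXConfig(quant_dtype=torch.uint4, group_size=64),
    is_observed_linear,
)
```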
