@@ -1,18 +1,28 @@
+import types
+from dataclasses import dataclass
+
 import torch
 
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import (
     TensorCoreTiledLayout,
     to_affine_quantized_intx,
 )
 from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout
 from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
 from torchao.quantization.granularity import PerGroup
-from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+from torchao.quantization.quant_api import (
+    _linear_extra_repr,
+    _replace_with_custom_fn_if_matches_filter,
+)
 from torchao.quantization.quant_primitives import (
     _DTYPE_TO_QVALUE_BOUNDS,
     MappingType,
     ZeroPointDomain,
 )
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 
 from .core import (
     AWQObservedLinear,
@@ -90,80 +100,87 @@ def _observed_linear_subclass_inserter(constructor):
         constructor: the function which applies quantization to the AWQObservedLinear layer
     """
 
-    def insert_subclass(observed_linear):
-        # creates the new linear layer using constructor
-        linear = torch.nn.Linear(
-            observed_linear.in_features,
-            observed_linear.out_features,
-            observed_linear.bias != None,
-            device=observed_linear.weight.device,
-            dtype=observed_linear.weight.dtype,
-        )
-        linear.weight = torch.nn.Parameter(
-            constructor(observed_linear), requires_grad=False
-        )
-        linear.bias = observed_linear.bias
-        return linear
-
-    return insert_subclass
-
 
-def awq_uintx(
-    quant_dtype: torch.dtype = torch.uint4,
-    group_size: int = 64,
-    use_hqq: bool = False,
-):
+@dataclass
+class AWQUIntXConfig(AOBaseConfig):
     """
-    Quantizes linear layers when passed into quantize_()
+    Configuration for quantizing linear layers when passed into quantize_()
 
     Args:
         quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used, but any dtype from torch.uint1 to torch.uint8 works
         group_size: Quantization granularity. Use -1 for channel-wise quantization
         use_hqq: whether to use HQQ instead of the default affine quantization scheme for the weights
     """
+
+    quant_dtype: torch.dtype = torch.uint4
+    group_size: int = 64
+    use_hqq: bool = False
+
+
+# for bc
+awq_uintx = AWQUIntXConfig
+
+
+@register_quantize_module_handler(AWQUIntXConfig)
+def _awq_uintx_transform(
+    module: torch.nn.Module,
+    config: AWQUIntXConfig,
+) -> torch.nn.Module:
+    quant_dtype = config.quant_dtype
+    group_size = config.group_size
+    use_hqq = config.use_hqq
+    observed_linear = module
+
     assert (
         quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8
     ), "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
 
-    def weight_quant_func(observed_linear):
-        equalization_scale = observed_linear.act_obs.calculate_qparams()
-        # AQT config
-        if quant_dtype == torch.uint4:
-            target_dtype = torch.int32
-            eps = 1e-6
-            preserve_zero = False
-            zero_point_dtype = torch.bfloat16
-            zero_point_domain = ZeroPointDomain.FLOAT
-            _layout = TensorCoreTiledLayout(inner_k_tiles=8)
-        else:
-            target_dtype = torch.uint8
-            eps = torch.finfo(torch.float32).eps
-            preserve_zero = True
-            zero_point_dtype = torch.int64
-            zero_point_domain = ZeroPointDomain.INT
-            _layout = UintxLayout(quant_dtype)
-
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, group_size)
-        quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0]
-        quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1]
-        qw = to_affine_quantized_intx(
-            observed_linear.weight * equalization_scale,
-            mapping_type,
-            block_size,
-            target_dtype,
-            quant_min,
-            quant_max,
-            eps,
-            zero_point_dtype=zero_point_dtype,
-            preserve_zero=preserve_zero,
-            zero_point_domain=zero_point_domain,
-            _layout=_layout,
-            use_hqq=use_hqq,
-        )
+    equalization_scale = observed_linear.act_obs.calculate_qparams()
+    # AQT config
+    if quant_dtype == torch.uint4:
+        target_dtype = torch.int32
+        eps = 1e-6
+        preserve_zero = False
+        zero_point_dtype = torch.bfloat16
+        zero_point_domain = ZeroPointDomain.FLOAT
+        _layout = TensorCoreTiledLayout(inner_k_tiles=8)
+    else:
+        target_dtype = torch.uint8
+        eps = torch.finfo(torch.float32).eps
+        preserve_zero = True
+        zero_point_dtype = torch.int64
+        zero_point_domain = ZeroPointDomain.INT
+        _layout = UintxLayout(quant_dtype)
 
-        return to_weight_tensor_with_linear_activation_scale_metadata(
-            qw, equalization_scale
-        )
+    mapping_type = MappingType.ASYMMETRIC
+    block_size = (1, group_size)
+    quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0]
+    quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1]
+    qw = to_affine_quantized_intx(
+        observed_linear.weight * equalization_scale,
+        mapping_type,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        zero_point_dtype=zero_point_dtype,
+        preserve_zero=preserve_zero,
+        zero_point_domain=zero_point_domain,
+        _layout=_layout,
+        use_hqq=use_hqq,
+    )
+
+    qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)
 
-    return _observed_linear_subclass_inserter(weight_quant_func)
+    linear = torch.nn.Linear(
+        observed_linear.in_features,
+        observed_linear.out_features,
+        observed_linear.bias is not None,
+        device=observed_linear.weight.device,
+        dtype=observed_linear.weight.dtype,
+    )
+    linear.weight = torch.nn.Parameter(qw, requires_grad=False)
+    linear.extra_repr = types.MethodType(_linear_extra_repr, module)
+    linear.bias = observed_linear.bias
+    return linear
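
For context, here is a minimal sketch of how the converted API is exercised end to end. This is hedged: `insert_awq_observer_` and its keyword arguments are assumed from the surrounding `torchao.prototype.awq` package and may differ in detail; `quantize_`, `awq_uintx`, and `AWQObservedLinear` are the names touched by this diff.

```python
import torch
from torchao.quantization import quantize_
from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_

# toy bfloat16 model on CUDA (the uint4 path targets TensorCoreTiledLayout,
# which expects bfloat16 weights on a CUDA device)
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to(torch.bfloat16).cuda()
calibration = [
    torch.randn(1, 8, 64, dtype=torch.bfloat16, device="cuda") for _ in range(10)
]

# 1. swap eligible nn.Linear layers for AWQObservedLinear and record
#    activation statistics (exact signature of insert_awq_observer_ assumed)
insert_awq_observer_(
    model,
    n_validation_examples=10,
    validation_sequence_len=8,
    quant_dtype=torch.uint4,
    group_size=64,
)
for x in calibration:
    model(x)

# 2. quantize_() dispatches on the config type to the registered
#    _awq_uintx_transform handler, which replaces each observed layer
#    with a plain nn.Linear holding the AWQ-quantized weight
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
quantize_(model, awq_uintx(quant_dtype=torch.uint4, group_size=64), is_observed_linear)
```

Because `awq_uintx` is now an alias for the dataclass, the call above constructs an `AWQUIntXConfig` instance instead of invoking a factory function, which is what lets `quantize_()` route the module through `register_quantize_module_handler` while keeping existing call sites working unchanged.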