
Commit c31a537

config migration: smoothquant
Summary: Migrates smoothquant to direct configs

Test Plan:

```
pytest test/prototype/test_smoothquant.py -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 2d881f3
ghstack-comment-id: 2706636586
Pull Request resolved: #1851
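For readers skimming the diff below, a minimal sketch of what the migration means for callers. This block is illustrative and not part of the commit; the tensor shapes are placeholders and the import path assumes the module changed in this diff.

```python
import torch

from torchao.prototype.smoothquant.api import SmoothQuantConfig, smooth_quant

# New style: build a plain config object and hand it to quantize_().
# Leaving the fields as None tells the transform to read the smoothing factor
# and scales from each layer's observer.
cfg = SmoothQuantConfig()

# Scales can also be passed explicitly (shapes here are illustrative only).
cfg_explicit = SmoothQuantConfig(
    smoothing_factor=torch.ones(16),
    act_scales=torch.ones(1),
    wei_scales=torch.ones(16),
)

# The old name is kept as a backwards-compatible alias for the config class.
assert smooth_quant is SmoothQuantConfig
```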
1 parent ffb4350 commit c31a537

File tree

1 file changed: +86, -67 lines

  • torchao/prototype/smoothquant


torchao/prototype/smoothquant/api.py

Lines changed: 86 additions & 67 deletions
@@ -1,20 +1,30 @@
+import types
+from dataclasses import dataclass
 from typing import Dict, Optional
 
 import torch
 
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import to_affine_quantized_intx, to_affine_quantized_intx_static
 from torchao.prototype.smoothquant.core import (
     SmoothQuantObservedLinear,
     SmoothQuantObserver,
 )
+from torchao.quantization import quantize_
 from torchao.quantization.linear_activation_quantized_tensor import (
     to_linear_activation_quantized,
 )
 from torchao.quantization.linear_activation_scale import (
     to_weight_tensor_with_linear_activation_scale_metadata,
 )
-from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+from torchao.quantization.quant_api import (
+    _linear_extra_repr,
+    _replace_with_custom_fn_if_matches_filter,
+)
 from torchao.quantization.quant_primitives import MappingType
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 from torchao.quantization.utils import _get_per_token_block_size
 from torchao.quantization.weight_tensor_linear_activation_quantization import (
     to_weight_tensor_with_linear_activation_quantization_metadata,
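The imports added in the hunk above (AOBaseConfig, quantize_, _linear_extra_repr, register_quantize_module_handler) are what the direct-config mechanism is built on. As a rough sketch of that pattern, with a hypothetical MyConfig and _my_transform standing in for the SmoothQuantConfig / _smooth_quant_transform pair defined later in this diff:

```python
from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization.transform_module import register_quantize_module_handler


@dataclass
class MyConfig(AOBaseConfig):
    # hypothetical knob, for illustration only
    some_scale: float = 1.0


@register_quantize_module_handler(MyConfig)
def _my_transform(module: torch.nn.Module, config: MyConfig) -> torch.nn.Module:
    # quantize_(model, MyConfig(), filter_fn) routes every matching module here;
    # a real handler returns a quantized replacement for `module`
    return module
```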
@@ -53,32 +63,6 @@ def replace_with_observer(layer):
     _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)
 
 
-def _observed_linear_subclass_inserter(constructor):
-    """
-    Replaces unquantized observed linear instances with quantized linear instances.
-
-    Args:
-        constructor: the function which applies quantization to the observed linear layer
-    """
-
-    def insert_subclass(observed_linear):
-        # creates the new linear layer using constructor
-        linear = torch.nn.Linear(
-            observed_linear.in_features,
-            observed_linear.out_features,
-            observed_linear.bias is not None,
-            device=observed_linear.weight.device,
-            dtype=observed_linear.weight.dtype,
-        )
-        linear.weight = torch.nn.Parameter(
-            constructor(observed_linear), requires_grad=False
-        )
-        linear.bias = observed_linear.bias
-        return linear
-
-    return insert_subclass
-
-
 def save_smooth_quant_recipe(
     model: torch.nn.Module, save_path: str
 ) -> Dict[str, torch.Tensor]:
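The deleted _observed_linear_subclass_inserter helper is not replaced one-for-one; its build-a-fresh-nn.Linear logic now lives in the _smooth_quant_transform handler added further down in this diff. For the save_smooth_quant_recipe function whose signature appears above, a hypothetical usage based only on that signature; the observer-insertion helper name and the exact keys of the returned dict are assumptions:

```python
import torch

from torchao.prototype.smoothquant.api import (
    insert_smooth_quant_observer_,  # assumed helper name for observer insertion
    save_smooth_quant_recipe,
)

model = torch.nn.Sequential(torch.nn.Linear(32, 32)).eval()
insert_smooth_quant_observer_(model)
model(torch.randn(2, 32))  # one calibration batch so the observers have statistics

# Per the signature above, the recipe is returned as a name -> tensor dict and
# also written to save_path; the key naming is an assumption.
recipe = save_smooth_quant_recipe(model, "smoothquant_recipe.pt")
print(sorted(recipe.keys()))
```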
@@ -121,7 +105,14 @@ def recurse(module: torch.nn.Module, name: str = ""):
             # act_scales is None for dynamic quantization
             if any(x is None for x in (smoothing_factor, wei_scales)):
                 return module
-            return smooth_quant(smoothing_factor, act_scales, wei_scales)(module)
+            is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
+            wrapper = torch.nn.Sequential(module)
+            quantize_(
+                wrapper,
+                smooth_quant(smoothing_factor, act_scales, wei_scales),
+                is_observed_linear,
+            )
+            return wrapper[0]
 
         mod_new = module
 
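The new recipe-loading path above leans on a small trick: quantize_ rewrites modules inside a container in place, so a single observed linear is wrapped in nn.Sequential, quantized, and unwrapped. A standalone sketch of that pattern (quantize_single and its argument names are hypothetical; in real use the scale tensors come from a saved recipe or the layer's observer):

```python
import torch

from torchao.prototype.smoothquant.api import SmoothQuantConfig
from torchao.prototype.smoothquant.core import SmoothQuantObservedLinear
from torchao.quantization import quantize_


def quantize_single(observed_linear, smoothing_factor, act_scales, wei_scales):
    # wrap the lone module so quantize_ has a parent container to mutate
    wrapper = torch.nn.Sequential(observed_linear)
    quantize_(
        wrapper,
        SmoothQuantConfig(smoothing_factor, act_scales, wei_scales),
        lambda m, fqn: isinstance(m, SmoothQuantObservedLinear),
    )
    # the observed linear has been swapped for a quantized nn.Linear
    return wrapper[0]
```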
@@ -158,54 +149,82 @@ def static_quantize(self, input, scale, zero_point):
         )
 
 
-def smooth_quant(
-    smoothing_factor: Optional[torch.Tensor] = None,
-    act_scales: Optional[torch.Tensor] = None,
-    wei_scales: Optional[torch.Tensor] = None,
-):
+@dataclass
+class SmoothQuantConfig(AOBaseConfig):
     """
-    Quantizes linear layers when passed into quantize_()
+    Configuration for quantizing linear layers when passed into quantize_()
 
     Args:
         smoothing_factor: The smoothing factor for the layer. Acquired from the layer's observer if None.
         act_scales: The activation scales for the layer. Acquired from the layer's observer if None.
         wei_scales: The weight scales for the layer. Acquired from the layer's observer if None.
     """
 
-    def quantize_weight(observed_linear):
-        target_dtype = torch.int8
-        # act_scales is None for dynamic quantization thus not checked
-        if any(x is None for x in (smoothing_factor, wei_scales)):
-            factor, x_scale, w_scales = observed_linear.obs.calculate_qparams()
-            weight = observed_linear.obs.weight * factor
-        else:
-            factor, x_scale, w_scales = smoothing_factor, act_scales, wei_scales
-            weight = observed_linear.weight * factor
-        weight = weight.to(observed_linear.weight.dtype)
-        block_size = (1, weight.size(1))
-        wei_zero_points = torch.zeros_like(w_scales, dtype=torch.int64)
-        qw = to_affine_quantized_intx_static(
-            weight,
-            w_scales,
-            wei_zero_points,
-            block_size,
-            target_dtype,
-        )
+    smoothing_factor: Optional[torch.Tensor] = None
+    act_scales: Optional[torch.Tensor] = None
+    wei_scales: Optional[torch.Tensor] = None
+
+
+# for bc
+smooth_quant = SmoothQuantConfig
 
-        if x_scale is None:
-            # dynamic quant
-            qw = to_linear_activation_quantized(
-                qw, _ActQuantizer(target_dtype).dynamic_quantize
-            )
-        else:
-            # static quant
-            x_zero_point = torch.zeros_like(x_scale, dtype=torch.int64)
-            qw = to_weight_tensor_with_linear_activation_quantization_metadata(
-                qw, _ActQuantizer(target_dtype).static_quantize, x_scale, x_zero_point
-            )
 
-        return to_weight_tensor_with_linear_activation_scale_metadata(
-            qw, factor.to(qw.dtype)
+@register_quantize_module_handler(SmoothQuantConfig)
+def _smooth_quant_transform(
+    module: torch.nn.Module,
+    config: SmoothQuantConfig,
+):
+    smoothing_factor = config.smoothing_factor
+    act_scales = config.act_scales
+    wei_scales = config.wei_scales
+    # weight = module.weight
+    observed_linear = module
+
+    linear = torch.nn.Linear(
+        observed_linear.in_features,
+        observed_linear.out_features,
+        observed_linear.bias is not None,
+        device=observed_linear.weight.device,
+        dtype=observed_linear.weight.dtype,
+    )
+    # linear.weight = torch.nn.Parameter(
+    #     constructor(observed_linear), requires_grad=False
+    # )
+    linear.bias = observed_linear.bias
+    # return linear
+
+    target_dtype = torch.int8
+    # act_scales is None for dynamic quantization thus not checked
+    if any(x is None for x in (smoothing_factor, wei_scales)):
+        factor, x_scale, w_scales = observed_linear.obs.calculate_qparams()
+        weight = observed_linear.obs.weight * factor
+    else:
+        factor, x_scale, w_scales = smoothing_factor, act_scales, wei_scales
+        weight = observed_linear.weight * factor
+    weight = weight.to(observed_linear.weight.dtype)
+    block_size = (1, weight.size(1))
+    wei_zero_points = torch.zeros_like(w_scales, dtype=torch.int64)
+    qw = to_affine_quantized_intx_static(
+        weight,
+        w_scales,
+        wei_zero_points,
+        block_size,
+        target_dtype,
+    )
+
+    if x_scale is None:
+        # dynamic quant
+        qw = to_linear_activation_quantized(
+            qw, _ActQuantizer(target_dtype).dynamic_quantize
+        )
+    else:
+        # static quant
+        x_zero_point = torch.zeros_like(x_scale, dtype=torch.int64)
+        qw = to_weight_tensor_with_linear_activation_quantization_metadata(
+            qw, _ActQuantizer(target_dtype).static_quantize, x_scale, x_zero_point
         )
 
-    return _observed_linear_subclass_inserter(quantize_weight)
+    qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, factor.to(qw.dtype))
+    linear.weight = torch.nn.Parameter(qw, requires_grad=False)
+    linear.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return linear
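Taken together, the new config plus registered handler gives the following end-to-end flow. This is a hedged sketch for illustration only: the toy model and calibration data are made up, and insert_smooth_quant_observer_ is assumed to be the observer-insertion helper in this module (only the observer-wrapped layers match the filter, so other modules are left untouched).

```python
import torch

from torchao.prototype.smoothquant.api import (
    SmoothQuantConfig,
    insert_smooth_quant_observer_,  # assumed helper name
)
from torchao.prototype.smoothquant.core import SmoothQuantObservedLinear
from torchao.quantization import quantize_

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8)).eval()

# 1. swap nn.Linear for SmoothQuantObservedLinear so activations and weights are observed
insert_smooth_quant_observer_(model)

# 2. run a few calibration batches so each observer can compute its qparams
for _ in range(8):
    model(torch.randn(4, 64))

# 3. quantize: a default config makes _smooth_quant_transform pull the smoothing
#    factor and scales from each layer's observer
quantize_(
    model,
    SmoothQuantConfig(),
    lambda m, fqn: isinstance(m, SmoothQuantObservedLinear),
)
print(model)
```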
