|
5 | 5 | import torch
|
6 | 6 |
|
7 | 7 | from torchao.prototype.smoothquant import (
|
| 8 | + SmoothQuantConfig, |
8 | 9 | SmoothQuantObservedLinear,
|
9 | 10 | insert_smooth_quant_observer_,
|
10 | 11 | load_smooth_quant_recipe,
|
11 | 12 | save_smooth_quant_recipe,
|
12 |
| - smooth_quant, |
13 | 13 | )
|
14 | 14 | from torchao.quantization import quantize_
|
15 | 15 | from torchao.quantization.utils import (
|
@@ -85,7 +85,7 @@ def forward(self, x):
|
85 | 85 | m(data)
|
86 | 86 | # quantize
|
87 | 87 | is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
|
88 |
| - quantize_(m, smooth_quant(), is_observed_linear) |
| 88 | + quantize_(m, SmoothQuantConfig(), is_observed_linear) |
89 | 89 | with torch.inference_mode():
|
90 | 90 | if TORCH_VERSION_AT_LEAST_2_5:
|
91 | 91 | m = torch.compile(m, fullgraph=True)
|
@@ -173,7 +173,7 @@ def test_save_load_recipe(alpha, quant_mode, device, idtype):
|
173 | 173 |
|
174 | 174 | # quantize
|
175 | 175 | is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
|
176 |
| - quantize_(m, smooth_quant(), is_observed_linear) |
| 176 | + quantize_(m, SmoothQuantConfig(), is_observed_linear) |
177 | 177 | if TORCH_VERSION_AT_LEAST_2_5:
|
178 | 178 | # earlier versions are not compatible
|
179 | 179 | m = torch.compile(m, fullgraph=True)
|
|
0 commit comments