Commit ba3b8ce

[quant] fix int16 quantization scale in conv weight
Summary:
fix int16 quantization scale in conv weight

Test Plan:
python3 test/test_quantization.py TestQuantizeEagerOps.test_int16_reference_module

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 8466547
Pull Request resolved: #74665
1 parent 7f996b8 commit ba3b8ce
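
For context: reference quantized modules route weight handling on the weight dtype, and torch.qint32 was missing from the dtype lists in torch/nn/quantized/_reference/modules/utils.py. Weights quantized to an int16 range (stored as qint32 with int16 quant_min/quant_max) therefore fell into the placeholder branch and were given weight_scale = 1.0 instead of the observer-computed scale. The sketch below hand-computes the scale a per_tensor_symmetric observer should report for such a weight; it uses only the standard symmetric-quantization formula, and the tensor values are illustrative.

    import torch

    # Hand-computed symmetric scale for an int16 range (illustrative values).
    # This mirrors the formula a per_tensor_symmetric MinMax-style observer uses:
    #     scale = max(-min_val, max_val) / ((quant_max - quant_min) / 2)
    quant_min, quant_max = -(2 ** 15), (2 ** 15) - 1

    w = torch.randn(1, 1, 1, 1)             # stand-in for the ConvTranspose2d weight
    min_val, max_val = w.min(), w.max()
    max_abs = torch.max(-min_val, max_val)  # symmetric: larger magnitude wins
    scale = max_abs / ((quant_max - quant_min) / 2)
    zero_point = 0                          # symmetric signed dtypes keep zero_point at 0

    print(scale.item(), zero_point)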

2 files changed: +74 −4 lines changed

test/quantization/eager/test_quantize_eager_ptq.py

Lines changed: 71 additions & 0 deletions
@@ -199,6 +199,77 @@ def test_linear(self):
             (16, 5)
         )
 
+    @override_qengines
+    def test_int16_reference_module(self):
+
+        class RefM(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = nn.ConvTranspose2d(1, 1, 1)
+                self.quant1 = QuantStub()
+                self.dequant1 = DeQuantStub()
+                self.quant2 = QuantStub()
+                self.dequant2 = DeQuantStub()
+
+            def forward(self, x):
+                x = self.quant1(x)
+                x = self.dequant1(x)
+                x = self.conv(x)
+                x = self.quant2(x)
+                x = self.dequant2(x)
+                return x
+
+
+        input_size = (16, 1, 10, 10)
+        data = torch.randn(*input_size, dtype=torch.float)
+
+        original_ref_m = RefM()
+        rand_w = torch.randn_like(original_ref_m.conv.weight)
+        rand_b = torch.randn_like(original_ref_m.conv.bias)
+        original_ref_m.conv.weight = torch.nn.Parameter(rand_w, requires_grad=False)
+        original_ref_m.conv.bias = torch.nn.Parameter(rand_b, requires_grad=False)
+
+        qengine = torch.backends.quantized.engine
+        if qengine not in supported_qengines:
+            return
+        from torch.ao.quantization.observer import MovingAverageMinMaxObserver
+
+        weight_obs = MovingAverageMinMaxObserver.with_args(
+            dtype=torch.qint32,
+            # set qmin and qmax to represent qint16
+            quant_min=-1 * (2 ** 15),
+            quant_max=(2 ** 15) - 1,
+            qscheme=torch.per_tensor_symmetric,
+        )
+        act_obs = MovingAverageMinMaxObserver.with_args(
+            dtype=torch.qint32,
+            quant_min=-1 * (2 ** 15),
+            quant_max=(2 ** 15) - 1,
+        )
+        custom_qconfig = QConfig(activation=act_obs, weight=weight_obs)
+
+        # quantize the reference model
+        original_ref_m.eval()
+        original_ref_m.qconfig = custom_qconfig
+
+        ref_m = prepare(original_ref_m)
+        # calibration
+        ref_m(torch.randn(*input_size, dtype=torch.float))
+
+        ref_m = convert(ref_m, is_reference=True)
+
+        myobs = MovingAverageMinMaxObserver(averaging_constant=0.5,
+                                            dtype=torch.qint32,
+                                            # set qmin and qmax to represent qint16
+                                            quant_min=-1 * (2 ** 15),
+                                            quant_max=(2 ** 15) - 1,
+                                            qscheme=torch.per_tensor_symmetric,
+                                            )
+        result = myobs(rand_w)
+        qparams = myobs.calculate_qparams()
+        self.assertEqual(ref_m.conv.weight_scale, qparams[0])
+
+
     def _test_activation_op_impl(
             self, float_module_class, quantized_module_class, extra_module_kwargs):
         """ Implementation for testing common activation ops like leaky relu

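A condensed, standalone version of the flow the new test exercises may help readers unfamiliar with eager-mode quantization: prepare() attaches observers according to the module's qconfig, a calibration forward pass feeds them data, and convert(..., is_reference=True) swaps in reference modules carrying weight_scale/weight_zero_point buffers. This is a sketch distilled from the test above, assuming the same torch.ao.quantization APIs the test file imports:

    import torch
    import torch.nn as nn
    from torch.ao.quantization import (
        DeQuantStub, QConfig, QuantStub, convert, prepare,
    )
    from torch.ao.quantization.observer import MovingAverageMinMaxObserver

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = QuantStub()
            self.conv = nn.ConvTranspose2d(1, 1, 1)
            self.dequant = DeQuantStub()

        def forward(self, x):
            return self.dequant(self.conv(self.quant(x)))

    # int16 range emulated on top of the qint32 storage dtype, as in the test
    weight_obs = MovingAverageMinMaxObserver.with_args(
        dtype=torch.qint32, quant_min=-(2 ** 15), quant_max=(2 ** 15) - 1,
        qscheme=torch.per_tensor_symmetric)
    act_obs = MovingAverageMinMaxObserver.with_args(
        dtype=torch.qint32, quant_min=-(2 ** 15), quant_max=(2 ** 15) - 1)

    m = M().eval()
    m.qconfig = QConfig(activation=act_obs, weight=weight_obs)
    m = prepare(m)
    m(torch.randn(16, 1, 10, 10))    # calibration pass feeds the observers
    m = convert(m, is_reference=True)

    # with this fix the reference conv reports the observed scale,
    # not the 1.0 placeholder it fell back to before
    print(m.conv.weight_scale)
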
torch/nn/quantized/_reference/modules/utils.py

Lines changed: 3 additions & 4 deletions
@@ -16,7 +16,7 @@ def _init_weight_qparams(self, weight_qparams, device):
         None, torch.per_tensor_affine, torch.per_channel_affine,
         torch.per_channel_affine_float_qparams], \
         Exception(f"qscheme: {self.weight_qscheme} is not support in reference quantized {self._get_name()}")
-    if self.weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2]:
+    if self.weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]:
         zero_point_dtype = weight_qparams["zero_point"].dtype if \
             isinstance(weight_qparams["zero_point"], torch.Tensor) else \
             torch.int
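
For orientation: _init_weight_qparams keys its behavior off weight_qparams["dtype"]. Dtypes on the allow-list get real weight_scale/weight_zero_point buffers registered from the dict; anything else falls through to placeholders. A hypothetical qparams dict that now takes the buffer-registering path (the "scale" key is inferred from context, not shown in this hunk; values are illustrative):

    import torch

    weight_qparams = {
        "qscheme": torch.per_tensor_affine,
        "dtype": torch.qint32,   # now matches the allow-list above
        "scale": 0.0035,         # observer-computed value, no longer dropped
        "zero_point": 0,
    }
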
@@ -35,13 +35,12 @@ def _init_weight_qparams(self, weight_qparams, device):
         self.register_buffer(
             "weight_axis", torch.tensor(0, dtype=torch.int, device=device))
     else:
-        # added for TorchScriptability, not used
+        # added for TorchScriptability, and for torch.float
         self.register_buffer("weight_scale", torch.tensor(1.0, dtype=torch.float, device=device))
         self.register_buffer("weight_zero_point", torch.tensor(0, dtype=torch.int, device=device))
         self.register_buffer(
             "weight_axis", torch.tensor(0, dtype=torch.int, device=device))
 
-
 def get_weight(self):
     """
     Fake quantize (quantize and dequantize) the weight with
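
The else branch above is the one qint32 weights used to hit: it registers fixed placeholder buffers regardless of what the observer saw, which is exactly where the bad scale came from. A quick illustration with a bare nn.Module (not this file's code):

    import torch
    import torch.nn as nn

    m = nn.Module()
    # the placeholder branch registers constants, regardless of observed stats
    m.register_buffer("weight_scale", torch.tensor(1.0, dtype=torch.float))
    m.register_buffer("weight_zero_point", torch.tensor(0, dtype=torch.int))
    print(m.weight_scale)   # tensor(1.) -- the scale every qint32 weight got pre-fix
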
@@ -105,7 +104,7 @@ def _quantize_weight(
         weight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, weight_dtype)
         return weight
     elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
-        if weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2]:
+        if weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]:
             weight = torch.quantize_per_channel(
                 weight, weight_scale,
                 weight_zero_point, weight_axis.item(), weight_dtype)  # type: ignore[arg-type]
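
The same allow-list guards _quantize_weight, so with torch.qint32 added, int16-range weights are actually quantized with their computed qparams instead of being passed through as float. A minimal sketch of the per-tensor path using plain torch ops (scale and zero_point values are illustrative):

    import torch

    w = torch.randn(1, 1, 1, 1)        # stand-in weight tensor
    scale, zero_point = 0.0035, 0      # illustrative qparams
    qw = torch.quantize_per_tensor(w, scale, zero_point, torch.qint32)

    # quantize-then-dequantize round trip, as get_weight() does above
    print(qw.int_repr())               # int32 storage holding int16-range values
    print(qw.dequantize())             # fake-quantized float weight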
