
Commit 2c057c5

Match torch.fake_quantize numerics in 8da4w QAT
Summary: There are two subtle differences between the 8da4w quant primitives and `torch.fake_quantize_per_channel_affine` today:

1. 8da4w uses float32 zero points; torch.fake_quantize uses int32 zero points.
2. 8da4w uses input.div(scales); torch.fake_quantize uses input.mul(1.0 / scales).

Of these two differences, the second one is smaller and only resulted in 0.1% of elements mismatching in unit tests, but it is a source of numerical divergence nonetheless. This commit changes the 8da4w QAT quant primitives to match the torch.fake_quantize behavior for both of these differences. In a future commit, we will change the 8da4w PTQ quant primitives as well so PTQ and QAT remain consistent.

Test Plan: python test/quantization/test_qat.py -k test_qat_generic_fake_quantize

Reviewers: jerryzh168, cpuhrsch

Subscribers: jerryzh168, cpuhrsch, supriyar
1 parent 3dd16c9 commit 2c057c5
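To make the two differences above concrete, here is a small standalone sketch (not part of the commit; shapes and values are made up for illustration) contrasting the zero point dtypes and the two scaling formulations:

import torch

x = torch.randn(4, 8)
scales = torch.rand(4, 1) + 0.1

# Difference 1: dtype of the zero points (integer-valued either way).
zp_float = torch.randint(0, 8, (4, 1)).float()  # what the 8da4w primitives carry today
zp_int = zp_float.to(torch.int32)               # what torch.fake_quantize expects

# Difference 2: dividing by the scales vs. multiplying by their reciprocal.
# 1.0 / scales is itself rounded to float32 before the multiply, so after
# rounding, a small fraction of elements can land on a different quantized value.
q_div = x.div(scales).add(zp_float).round()
q_mul = x.mul(1.0 / scales).add(zp_int).round()
print((q_div != q_mul).float().mean())  # typically a very small mismatch rate

On random data the mismatch rate from the div/mul difference is tiny, consistent with the roughly 0.1% of mismatched elements mentioned above.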

3 files changed: +77 -31 lines

test/quantization/test_qat.py

Lines changed: 34 additions & 11 deletions
@@ -14,6 +14,7 @@
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 from torchao.quantization.prototype.qat import (
     _choose_qparams_per_token_asymmetric,
+    _GenericFakeQuantize,
     fake_quantize_per_channel_group,
     fake_quantize_per_token,
 )
@@ -58,7 +59,7 @@ def _get_qmin_qmax(self, n_bit: int):
         qmax = 2 ** (n_bit - 1) - 1
         return (qmin, qmax)
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_fake_quantize_per_channel_group(self):
         n_bit = 4
         (qmin, qmax) = self._get_qmin_qmax(n_bit)
@@ -67,6 +68,7 @@ def test_fake_quantize_per_channel_group(self):
         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256).requires_grad_()
         (s, zp) = get_group_qparams_symmetric(x, n_bit, group_size)
+        zp = zp.to(torch.int32)
         x2 = copy.deepcopy(x)
 
         # fake quant op
@@ -84,18 +86,15 @@ def test_fake_quantize_per_channel_group(self):
         )
         torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_fake_quantize_per_token(self):
         (qmin, qmax) = self._get_qmin_qmax(8)
 
         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256).requires_grad_()
         x2 = copy.deepcopy(x)
         # TODO: use torch.ops.aten.quantized_decomposed version instead
-        (s, zp) = _choose_qparams_per_token_asymmetric(
-            x,
-            torch.int8,  # not used
-        )
+        (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)
 
         # fake quant op
         out = fake_quantize_per_token(x, s, zp, qmin, qmax)
@@ -130,7 +129,7 @@ def _set_ptq_weight(
         ptq_linear.scales = s
         ptq_linear.zeros = zp
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_linear(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATLinear
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
@@ -155,7 +154,7 @@ def test_qat_8da4w_linear(self):
         ptq_out = ptq_linear(x2)
         torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
         from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
@@ -189,7 +188,7 @@ def test_qat_8da4w_quantizer(self):
         for k in ptq_state_dict.keys():
             torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_meta_weights(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
 
@@ -201,7 +200,7 @@ def test_qat_8da4w_quantizer_meta_weights(self):
         qat_model = qat_quantizer.prepare(m)
         self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_disable_fake_quant(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
@@ -254,7 +253,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
         qat_out2 = qat_model2(*x2)
         torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0)
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
@@ -299,6 +298,30 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
         torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    def test_qat_generic_fake_quantize(self):
+        """
+        Test that the generic fake quantize used in 8da4w QAT matches
+        the numerics of existing fake quantize ops in Pytorch in both
+        the forward and the backward passes.
+        """
+        (qmin, qmax) = self._get_qmin_qmax(4)
+        py_input = torch.randn(16, 64).float().requires_grad_()
+        py_s = torch.randn(16).float()
+        py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32)
+        py_out = torch.fake_quantize_per_channel_affine(py_input, py_s, py_zp, 0, qmin, qmax)
+        py_out.sum().backward()
+
+        ao_input = copy.deepcopy(py_input)
+        ao_input.grad.data.zero_()
+        ao_s = copy.deepcopy(py_s).reshape(-1, 1)
+        ao_zp = copy.deepcopy(py_zp).reshape(-1, 1)
+        ao_out = _GenericFakeQuantize.apply(ao_input, ao_s, ao_zp, qmin, qmax)
+        ao_out.sum().backward()
+
+        torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
+        torch.testing.assert_close(py_input.grad, ao_input.grad, atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()

torchao/quantization/prototype/qat.py

Lines changed: 42 additions & 19 deletions
@@ -4,18 +4,18 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Optional, Tuple
+from typing import Any, Tuple
 
 import torch
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
 from torch.library import impl
 
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
+from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
 from torchao.quantization.quant_primitives import get_group_qparams_symmetric
 from torchao.quantization.unified import TwoStepQuantizer
 
 
-if TORCH_VERSION_AFTER_2_3:
+if TORCH_VERSION_AFTER_2_4:
     from torchao.quantization.GPTQ import (
         _replace_linear_8da4w,
         Int8DynActInt4WeightLinear,
@@ -54,7 +54,7 @@ def prepare(
                 self.precision,
                 self.scales_precision,
                 Int8DynActInt4WeightQATLinear,
-                copy_weights = True,
+                copy_weights=True,
             )
             return model
 
@@ -95,7 +95,7 @@ def _convert_qat_linear_8da4w(module: torch.nn.Module):
                 quantized_linear.zeros = zp
             else:
                 _convert_qat_linear_8da4w(child)
-
+
     class Int8DynActInt4WeightQATLinear(torch.nn.Linear):
         """
         This module implements a linear layer with int8 dynamic per token fake
@@ -131,6 +131,8 @@ def __init__(
             self.groupsize = groupsize
             self.precision = precision
             self.scales_precision = scales_precision
+            # TODO: make this configurable?
+            self.zero_points_precision = torch.int32
             self._fake_quant_enabled = True
 
         def enable_fake_quant(self, enabled: bool = True):
@@ -142,8 +144,8 @@ def disable_fake_quant(self):
         def forward(self, x: torch.Tensor) -> torch.Tensor:
             # activations: int8 dynamic asymmetric quant
             if self._fake_quant_enabled:
-                (act_scales, act_zp) =_choose_qparams_per_token_asymmetric(
-                    x, torch.int8,  # dtype not used
+                (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
+                    x, self.scales_precision, self.zero_points_precision,
                 )
                 (act_qmin, act_qmax) = self._get_qmin_qmax(8)
                 x_fq = fake_quantize_per_token(
@@ -157,6 +159,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 (weight_scales, weight_zp) = get_group_qparams_symmetric(
                     self.weight, 4, self.groupsize, self.scales_precision,
                 )
+                # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
+                weight_zp = weight_zp.to(self.zero_points_precision)
                 (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
                 w_fq = fake_quantize_per_channel_group(
                     self.weight,
@@ -190,6 +194,20 @@ def disable_8da4w_fake_quant(mod: torch.nn.Module):
         if isinstance(mod, Int8DynActInt4WeightQATLinear):
             mod.disable_fake_quant()
 
+else:  # not TORCH_VERSION_AFTER_2_4
+
+    class Int8DynActInt4WeightQATQuantizer:
+        def __init__(*args, **kwargs):
+            raise ValueError(
+                "Int8DynActInt4WeightQATQuantizer is only supported after PyTorch 2.4+"
+            )
+
+    class Int8DynActInt4WeightQATLinear:
+        def __init__(*args, **kwargs):
+            raise ValueError(
+                "Int8DynActInt4WeightQATQuantizer is only supported after PyTorch 2.4+"
+            )
+
 
 # ========================
 # | QUANT PRIMITIVES |
@@ -205,13 +223,15 @@ class _GenericFakeQuantize(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, input, scales, zero_points, quant_min, quant_max):
+        assert input.dtype == torch.float32
+        assert scales.dtype == torch.float32
+        assert zero_points.dtype == torch.int32
         # Note: this diverges from `torch.fake_quantize_per_channel_affine`,
-        # which rounds first before adding the zero points. However, this
-        # is what `quantize_per_channel_group` and `quantize_per_token`
-        # do and here we try to match that behavior as closely as possible.
+        # which rounds first before adding the zero points. However, since
+        # zero points are integers here, the ordering of these two ops
+        # shouldn't matter in practice.
         q = input.mul(1.0 / scales).add(zero_points).round()
         dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
-        # TODO: do we need this mask?
         mask = torch.logical_and((q >= quant_min), (q <= quant_max))
         ctx.save_for_backward(mask)
         return dq
@@ -239,14 +259,13 @@ def fake_quantize_per_channel_group(
     assert group_size > 1
    assert input.shape[-1] % group_size == 0
     assert input.dim() == 2
-    assert torch.isnan(input).sum() == 0
-    grouped_input = input.reshape(-1, group_size)
+    grouped_input = input.reshape(-1, group_size).to(torch.float32)
     scales = scales.reshape(-1, 1)
     zero_points = zero_points.reshape(-1, 1)
     fq = _GenericFakeQuantize.apply(
         grouped_input, scales, zero_points, quant_min, quant_max,
     )
-    return fq.reshape_as(input)
+    return fq.reshape_as(input).to(input.dtype)
 
 # TODO: move this to core
 quantized_decomposed_lib.define(
@@ -266,17 +285,20 @@ def fake_quantize_per_token(
     from torch.ao.quantization.fx._decomposed import _per_token_quant_qparam_dim_check
 
     _per_token_quant_qparam_dim_check(input, scales, zero_points)
-    return _GenericFakeQuantize.apply(
-        input, scales, zero_points, quant_min, quant_max,
+    fq_input = input.to(torch.float32)
+    fq = _GenericFakeQuantize.apply(
+        fq_input, scales, zero_points, quant_min, quant_max,
     )
+    return fq.reshape_as(input).to(input.dtype)
 
 # TODO: This is copied from torch/ao/quantization/fx/_decomposed.py.
 # The version in pytorch does not have backward support yet so we add
 # it here for now until https://github.com/pytorch/pytorch/pull/123452
 # is landed.
 def _choose_qparams_per_token_asymmetric(
     input: torch.Tensor,
-    dtype: torch.dtype,
+    scales_precision: torch.dtype = torch.float32,
+    zero_points_precision: torch.dtype = torch.float32,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Choose quantization parameters for per token quantization. This means for a N dimension Tensor
     (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
@@ -285,7 +307,8 @@ def _choose_qparams_per_token_asymmetric(
 
     Args:
        input (torch.Tensor): original float32/float16 Tensor
-       dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
+       scales_precision (torch.dtype): precision of returned scales
+       zero_points_precision (torch.dtype): precision of returned zero points
 
     Returns:
       scales and zero_points, both float32 Tensors
@@ -314,4 +337,4 @@ def _choose_qparams_per_token_asymmetric(
     )
     zero_point = torch.clamp(zero_point, qmin, qmax).round()
 
-    return scale.to(torch.float32), zero_point.to(torch.float32)
+    return scale.to(scales_precision), zero_point.to(zero_points_precision)
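For reference, a minimal usage sketch of the updated `_choose_qparams_per_token_asymmetric` signature, mirroring the call in the updated test; the (-128, 127) activation range below is assumed from `_get_qmin_qmax(8)`:

import torch
from torchao.quantization.prototype.qat import (
    _choose_qparams_per_token_asymmetric,
    fake_quantize_per_token,
)

x = torch.randn(100, 256)
# Request float32 scales and int32 zero points, matching torch.fake_quantize numerics.
scales, zero_points = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)
# Fake quantize the activations over the int8 range.
x_fq = fake_quantize_per_token(x, scales, zero_points, -128, 127)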

torchao/quantization/quant_primitives.py

Lines changed: 1 addition & 1 deletion
@@ -764,7 +764,7 @@ def groupwise_affine_dequantize_tensor(
     )
 
 
-# TODO: replace this with torch.ao.quantization.PerChannelMinMaxObserver
+# TODO: separate scale and zero point precision
 def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float32):
     # needed for GPTQ with padding
     if groupsize > w.shape[-1]:
