Skip to content

Commit 802199f

Browse files
Fake quantization enhancements for QAT/PTQ support
Pull Request resolved: #26420. Adds flags for enabling/disabling the observer and the fake quantization independently of each other, and improves the repr of the fake-quant module. ghstack-source-id: 90692999. Differential Revision: [D17458232](https://our.internmc.facebook.com/intern/diff/D17458232/)
1 parent a76dc2c commit 802199f

File tree

1 file changed

+27
-18
lines changed

1 file changed

+27
-18
lines changed

torch/quantization/fake_quantize.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import absolute_import, division, print_function, unicode_literals
22
import torch
33
from torch.nn import Module
4-
from .observer import default_observer, _with_args
4+
from .observer import MinMaxObserver, _with_args
55

66
class FakeQuantize(Module):
77
''' Simulate the quantize and dequantize operations in training time.
@@ -14,7 +14,7 @@ class FakeQuantize(Module):
1414
'''
1515

1616
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
17-
quant_min=0, quant_max=255):
17+
quant_min=0, quant_max=255, reduce_range=False):
1818
super(FakeQuantize, self).__init__()
1919
assert torch.iinfo(dtype).min <= quant_min, 'quant_min out of bound'
2020
assert quant_min <= quant_max, \
@@ -24,36 +24,45 @@ def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
2424
self.qscheme = qscheme
2525
self.quant_min = quant_min
2626
self.quant_max = quant_max
27-
self.enabled = True
28-
self.observer = default_observer(dtype=dtype, qscheme=qscheme)
27+
self.fake_quant_enabled = True
28+
self.observer_enabled = True
29+
self.observer = MinMaxObserver.with_args(dtype=dtype, qscheme=qscheme, reduce_range=reduce_range)()
2930
self.scale = None
3031
self.zero_point = None
3132

32-
def enable(self, enabled=True):
33-
self.enabled = enabled
33+
def enable_fake_quant(self, enabled=True):
    """Turn the fake-quantization step of ``forward`` on or off.

    Args:
        enabled: value stored in ``self.fake_quant_enabled``.

    Returns:
        ``self``, so calls can be chained.
    """
    self.fake_quant_enabled = enabled
    return self
3536

36-
def disable(self):
37-
return self.enable(False)
37+
def disable_fake_quant(self):
    """Switch fake quantization off.

    Shorthand for ``enable_fake_quant(False)``; returns whatever that
    call returns.
    """
    return self.enable_fake_quant(False)
39+
40+
def enable_observer(self, enabled=True):
    """Turn the observer step of ``forward`` on or off.

    Args:
        enabled: value stored in ``self.observer_enabled``.

    Returns:
        ``self``, so calls can be chained, matching ``enable_fake_quant``.
    """
    self.observer_enabled = enabled
    # Return self like enable_fake_quant does; without this, both this
    # method and disable_observer() hand back None and break chaining.
    return self
42+
43+
def disable_observer(self):
    """Switch the observer off.

    Shorthand for ``enable_observer(False)``; returns whatever that
    call returns.
    """
    return self.enable_observer(False)
3845

3946
def calculate_qparams(self):
    """Return the quantization parameters computed by ``self.observer``.

    Pure delegation: the result of ``self.observer.calculate_qparams()``
    is passed through unchanged.
    """
    observer = self.observer
    return observer.calculate_qparams()
4148

4249
def forward(self, X):
    """Optionally observe ``X``, then optionally fake-quantize it.

    When ``self.observer_enabled`` is set, ``X`` is passed through
    ``self.observer`` (whose return value replaces ``X``) and the cached
    ``self.scale`` / ``self.zero_point`` are refreshed from
    ``self.calculate_qparams()``. When ``self.fake_quant_enabled`` is set,
    ``X`` is run through ``torch.fake_quantize_per_tensor_affine`` using
    the cached parameters and ``self.quant_min`` / ``self.quant_max``.

    Returns:
        The (possibly observed and fake-quantized) ``X``.
    """
    if self.observer_enabled:
        X = self.observer(X)
        new_scale, new_zero_point = self.calculate_qparams()
        # Both conversions happen before either attribute is written.
        self.scale, self.zero_point = float(new_scale), int(new_zero_point)

    if self.fake_quant_enabled:
        # Uses the most recently cached qparams; they are only refreshed
        # while the observer is enabled.
        X = torch.fake_quantize_per_tensor_affine(
            X,
            self.scale,
            self.zero_point,
            self.quant_min,
            self.quant_max,
        )

    return X
5157

5258
with_args = classmethod(_with_args)
5359

54-
default_fake_quant = FakeQuantize
60+
def extra_repr(self):
    """Return the summary of this module's state shown in its repr.

    Reports the two enable flags and the cached ``scale`` /
    ``zero_point`` values.
    """
    # Adjacent string literals instead of a backslash continuation inside
    # the literal: the continuation pulled the next source line's leading
    # whitespace into the output, so the text depended on indentation.
    template = (
        'fake_quant_enabled={}, observer_enabled={}, '
        'scale={}, zero_point={}'
    )
    return template.format(
        self.fake_quant_enabled, self.observer_enabled,
        self.scale, self.zero_point)
5565

56-
default_weight_fake_quant = FakeQuantize.with_args(dtype=torch.qint8,
57-
qscheme=torch.per_tensor_symmetric,
58-
quant_min=-128,
59-
quant_max=127)
66+
# Default fake-quant factory: the FakeQuantize class itself, so instances
# are built with the constructor's default arguments.
default_fake_quant = FakeQuantize

# Fake-quant factory preconfigured for signed 8-bit, per-tensor symmetric
# quantization over the range [-128, 127].
# NOTE(review): presumably intended for weights, going by the name -- confirm.
default_weight_fake_quant = FakeQuantize.with_args(
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
    quant_min=-128,
    quant_max=127,
)

0 commit comments

Comments
 (0)