@@ -1,7 +1,7 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 import torch
 from torch.nn import Module
-from .observer import default_observer, _with_args
+from .observer import MinMaxObserver, _with_args

 class FakeQuantize(Module):
     ''' Simulate the quantize and dequantize operations in training time.
@@ -14,7 +14,7 @@ class FakeQuantize(Module):
     '''

     def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
-                 quant_min=0, quant_max=255):
+                 quant_min=0, quant_max=255, reduce_range=False):
         super(FakeQuantize, self).__init__()
         assert torch.iinfo(dtype).min <= quant_min, 'quant_min out of bound'
         assert quant_min <= quant_max, \
@@ -24,36 +24,45 @@ def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
         self.qscheme = qscheme
         self.quant_min = quant_min
         self.quant_max = quant_max
-        self.enabled = True
-        self.observer = default_observer(dtype=dtype, qscheme=qscheme)
+        self.fake_quant_enabled = True
+        self.observer_enabled = True
+        self.observer = MinMaxObserver.with_args(dtype=dtype, qscheme=qscheme, reduce_range=reduce_range)()
         self.scale = None
         self.zero_point = None

-    def enable(self, enabled=True):
-        self.enabled = enabled
+    def enable_fake_quant(self, enabled=True):
+        self.fake_quant_enabled = enabled
         return self

-    def disable(self):
-        return self.enable(False)
+    def disable_fake_quant(self):
+        return self.enable_fake_quant(False)
+
+    def enable_observer(self, enabled=True):
+        self.observer_enabled = enabled
+
+    def disable_observer(self):
+        return self.enable_observer(False)

     def calculate_qparams(self):
         return self.observer.calculate_qparams()

     def forward(self, X):
-        if self.enabled:
-            self.observer(X)
+        if self.observer_enabled:
+            X = self.observer(X)
         scale, zero_point = self.calculate_qparams()
         self.scale, self.zero_point = float(scale), int(zero_point)
-        X = torch.fake_quantize_per_tensor_affine(
-            X, self.scale, self.zero_point, self.quant_min,
-            self.quant_max)
+        if self.fake_quant_enabled:
+            X = torch.fake_quantize_per_tensor_affine(X, self.scale, self.zero_point, self.quant_min, self.quant_max)
         return X

     with_args = classmethod(_with_args)

-default_fake_quant = FakeQuantize
+    def extra_repr(self):
+        return 'fake_quant_enabled={}, observer_enabled={},\
+            scale={}, zero_point={}'.format(
+            self.fake_quant_enabled, self.observer_enabled,
+            self.scale, self.zero_point)

-default_weight_fake_quant = FakeQuantize.with_args(dtype=torch.qint8,
-                                                   qscheme=torch.per_tensor_symmetric,
-                                                   quant_min=-128,
-                                                   quant_max=127)
+default_fake_quant = FakeQuantize
+default_weight_fake_quant = FakeQuantize.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric,
+                                                   quant_min=-128, quant_max=127)
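
For reference, a minimal usage sketch of the API after this change, showing the new split between the observer toggle and the fake-quant toggle. The import path torch.quantization.fake_quantize is an assumption (the diff does not show the file's location); everything else follows the code above.

# Usage sketch; import path is assumed, not shown in the diff.
import torch
from torch.quantization.fake_quantize import FakeQuantize, default_weight_fake_quant

fq = FakeQuantize()               # defaults: quint8, per-tensor affine, range [0, 255]
for _ in range(10):
    y = fq(torch.randn(4, 8))     # observer tracks min/max; output is fake-quantized

fq.disable_observer()             # freeze the collected statistics
y = fq(torch.randn(4, 8))         # still fake-quantizes, using the frozen qparams

fq.disable_fake_quant()           # with both toggles off, forward passes X through unchanged

# with_args returns a pre-configured constructor; call it to build the module.
weight_fq = default_weight_fake_quant()   # qint8, per-tensor symmetric, range [-128, 127]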