@@ -199,6 +199,77 @@ def test_linear(self):
                 (16, 5)
             )
 
+    @override_qengines
+    def test_int16_reference_module(self):
+
+        class RefM(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = nn.ConvTranspose2d(1, 1, 1)
+                self.quant1 = QuantStub()
+                self.dequant1 = DeQuantStub()
+                self.quant2 = QuantStub()
+                self.dequant2 = DeQuantStub()
+
+            def forward(self, x):
+                x = self.quant1(x)
+                x = self.dequant1(x)
+                x = self.conv(x)
+                x = self.quant2(x)
+                x = self.dequant2(x)
+                return x
+
+
+        input_size = (16, 1, 10, 10)
+        data = torch.randn(*input_size, dtype=torch.float)
+
+        original_ref_m = RefM()
+        rand_w = torch.randn_like(original_ref_m.conv.weight)
+        rand_b = torch.randn_like(original_ref_m.conv.bias)
+        original_ref_m.conv.weight = torch.nn.Parameter(rand_w, requires_grad=False)
+        original_ref_m.conv.bias = torch.nn.Parameter(rand_b, requires_grad=False)
+
+        qengine = torch.backends.quantized.engine
+        if qengine not in supported_qengines:
+            return
+        from torch.ao.quantization.observer import MovingAverageMinMaxObserver
+
+        weight_obs = MovingAverageMinMaxObserver.with_args(
+            dtype=torch.qint32,
+            # set qmin and qmax to represent qint16
+            quant_min=-1 * (2 ** 15),
+            quant_max=(2 ** 15) - 1,
+            qscheme=torch.per_tensor_symmetric,
+        )
+        act_obs = MovingAverageMinMaxObserver.with_args(
+            dtype=torch.qint32,
+            quant_min=-1 * (2 ** 15),
+            quant_max=(2 ** 15) - 1,
+        )
+        custom_qconfig = QConfig(activation=act_obs, weight=weight_obs)
+
+        # quantize the reference model
+        original_ref_m.eval()
+        original_ref_m.qconfig = custom_qconfig
+
+        ref_m = prepare(original_ref_m)
+        # calibration
+        ref_m(torch.randn(*input_size, dtype=torch.float))
+
+        ref_m = convert(ref_m, is_reference=True)
+
+        myobs = MovingAverageMinMaxObserver(averaging_constant=0.5,
+                                            dtype=torch.qint32,
+                                            # set qmin and qmax to represent qint16
+                                            quant_min=-1 * (2 ** 15),
+                                            quant_max=(2 ** 15) - 1,
+                                            qscheme=torch.per_tensor_symmetric,
+                                            )
+        result = myobs(rand_w)
+        qparams = myobs.calculate_qparams()
+        self.assertEqual(ref_m.conv.weight_scale, qparams[0])
+
+
     def _test_activation_op_impl(
             self, float_module_class, quantized_module_class, extra_module_kwargs):
         """ Implementation for testing common activation ops like leaky relu