4
4
import time
5
5
6
6
import torch
7
- import torch .quantization
7
+ import torch .ao . quantization
8
8
import torch .utils .data
9
9
import torchvision
10
10
import utils
@@ -62,8 +62,8 @@ def main(args):
62
62
63
63
if not (args .test_only or args .post_training_quantize ):
64
64
model .fuse_model ()
65
- model .qconfig = torch .quantization .get_default_qat_qconfig (args .backend )
66
- torch .quantization .prepare_qat (model , inplace = True )
65
+ model .qconfig = torch .ao . quantization .get_default_qat_qconfig (args .backend )
66
+ torch .ao . quantization .prepare_qat (model , inplace = True )
67
67
68
68
if args .distributed and args .sync_bn :
69
69
model = torch .nn .SyncBatchNorm .convert_sync_batchnorm (model )
@@ -96,12 +96,12 @@ def main(args):
96
96
)
97
97
model .eval ()
98
98
model .fuse_model ()
99
- model .qconfig = torch .quantization .get_default_qconfig (args .backend )
100
- torch .quantization .prepare (model , inplace = True )
99
+ model .qconfig = torch .ao . quantization .get_default_qconfig (args .backend )
100
+ torch .ao . quantization .prepare (model , inplace = True )
101
101
# Calibrate first
102
102
print ("Calibrating" )
103
103
evaluate (model , criterion , data_loader_calibration , device = device , print_freq = 1 )
104
- torch .quantization .convert (model , inplace = True )
104
+ torch .ao . quantization .convert (model , inplace = True )
105
105
if args .output_dir :
106
106
print ("Saving quantized model" )
107
107
if utils .is_main_process ():
@@ -114,8 +114,8 @@ def main(args):
114
114
evaluate (model , criterion , data_loader_test , device = device )
115
115
return
116
116
117
- model .apply (torch .quantization .enable_observer )
118
- model .apply (torch .quantization .enable_fake_quant )
117
+ model .apply (torch .ao . quantization .enable_observer )
118
+ model .apply (torch .ao . quantization .enable_fake_quant )
119
119
start_time = time .time ()
120
120
for epoch in range (args .start_epoch , args .epochs ):
121
121
if args .distributed :
@@ -126,7 +126,7 @@ def main(args):
126
126
with torch .inference_mode ():
127
127
if epoch >= args .num_observer_update_epochs :
128
128
print ("Disabling observer for subseq epochs, epoch = " , epoch )
129
- model .apply (torch .quantization .disable_observer )
129
+ model .apply (torch .ao . quantization .disable_observer )
130
130
if epoch >= args .num_batch_norm_update_epochs :
131
131
print ("Freezing BN for subseq epochs, epoch = " , epoch )
132
132
model .apply (torch .nn .intrinsic .qat .freeze_bn_stats )
@@ -136,7 +136,7 @@ def main(args):
136
136
quantized_eval_model = copy .deepcopy (model_without_ddp )
137
137
quantized_eval_model .eval ()
138
138
quantized_eval_model .to (torch .device ("cpu" ))
139
- torch .quantization .convert (quantized_eval_model , inplace = True )
139
+ torch .ao . quantization .convert (quantized_eval_model , inplace = True )
140
140
141
141
print ("Evaluate Quantized model" )
142
142
evaluate (quantized_eval_model , criterion , data_loader_test , device = torch .device ("cpu" ))
0 commit comments