 import datetime
 import os
 import time
+from enum import Enum

 import torch
 import torch.ao.quantization
+import torch.ao.quantization.quantize_fx
 import torch.utils.data
 import torchvision
 import utils
 from torch import nn
 from train import train_one_epoch, evaluate, load_data


+class QuantizationWorkflowType(Enum):
+    EAGER_MODE_QUANTIZATION = 1
+    FX_GRAPH_MODE_QUANTIZATION = 2
+
+
 def main(args):
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -22,6 +29,15 @@ def main(args):
     if args.post_training_quantize and args.distributed:
         raise RuntimeError("Post training quantization example should not be performed on distributed mode")

+    # Validate quantization workflow type
+    quantization_workflow_type = args.quantization_workflow_type.upper()
+    if quantization_workflow_type not in QuantizationWorkflowType.__members__:
+        raise RuntimeError(
+            "Unknown workflow type '%s', please choose from: %s"
+            % (args.quantization_workflow_type, str(tuple([t.lower() for t in QuantizationWorkflowType.__members__])))
+        )
+    use_fx_graph_mode_quantization = quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE_QUANTIZATION.name
+
     # Set backend engine to ensure that quantized model runs on the correct kernels
     if args.backend not in torch.backends.quantized.supported_engines:
         raise RuntimeError("Quantized backend not supported: " + str(args.backend))
@@ -45,14 +61,23 @@ def main(args):
     )

     print("Creating model", args.model)
+    if use_fx_graph_mode_quantization:
+        model_namespace = torchvision.models
+    else:
+        model_namespace = torchvision.models.quantization
     # when training quantized models, we always start from a pre-trained fp32 reference model
-    model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
+    model = model_namespace.__dict__[args.model](weights=args.weights, quantize=args.test_only)
     model.to(device)

     if not (args.test_only or args.post_training_quantize):
-        model.fuse_model(is_qat=True)
-        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
-        torch.ao.quantization.prepare_qat(model, inplace=True)
+        qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
+        if use_fx_graph_mode_quantization:
+            qconfig_dict = {"": qconfig}
+            model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=True)
+            model.qconfig = qconfig
+            torch.ao.quantization.prepare_qat(model, inplace=True)

     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
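
As a side note, the minimal sketch below (not part of the diff) contrasts the two QAT preparation paths used above. It assumes the same two-argument prepare_qat_fx/qconfig_dict API that this change targets; the toy model and the "fbgemm" backend string are illustrative only.

import torch
import torch.ao.quantization
import torch.ao.quantization.quantize_fx
from torch import nn

float_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).train()
qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")

# FX graph mode: the model is symbolically traced, fusion and fake-quant
# insertion happen automatically, and a new GraphModule is returned.
fx_prepared = torch.ao.quantization.quantize_fx.prepare_qat_fx(float_model, {"": qconfig})

# Eager mode: fusion is the model's own responsibility (torchvision's quantizable
# models expose fuse_model()), the qconfig is attached as an attribute, and
# preparation mutates the model in place.
eager_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).train()
eager_model.qconfig = qconfig
torch.ao.quantization.prepare_qat(eager_model, inplace=True)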
@@ -84,13 +109,21 @@ def main(args):
             ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
         )
         model.eval()
-        model.fuse_model(is_qat=False)
-        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
-        torch.ao.quantization.prepare(model, inplace=True)
+        qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
+        if use_fx_graph_mode_quantization:
+            qconfig_dict = {"": qconfig}
+            model = torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=False)
+            model.qconfig = qconfig
+            torch.ao.quantization.prepare(model, inplace=True)
         # Calibrate first
         print("Calibrating")
         evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
-        torch.ao.quantization.convert(model, inplace=True)
+        if use_fx_graph_mode_quantization:
+            model = torch.ao.quantization.quantize_fx.convert_fx(model)
+        else:
+            torch.ao.quantization.convert(model, inplace=True)
         if args.output_dir:
             print("Saving quantized model")
             if utils.is_main_process():
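
Along the same lines, here is a minimal post-training quantization sketch of the prepare_fx, calibrate, convert_fx sequence added above (again assuming the two-argument prepare_fx signature; the random tensors stand in for data_loader_calibration):

import torch
import torch.ao.quantization
import torch.ao.quantization.quantize_fx
from torch import nn

float_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()
qconfig_dict = {"": torch.ao.quantization.get_default_qconfig("fbgemm")}

prepared = torch.ao.quantization.quantize_fx.prepare_fx(float_model, qconfig_dict)
with torch.no_grad():
    for _ in range(4):  # calibration: run representative inputs through the observers
        prepared(torch.randn(1, 3, 32, 32))
quantized = torch.ao.quantization.quantize_fx.convert_fx(prepared)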
@@ -125,7 +158,10 @@ def main(args):
             quantized_eval_model = copy.deepcopy(model_without_ddp)
             quantized_eval_model.eval()
             quantized_eval_model.to(torch.device("cpu"))
-            torch.ao.quantization.convert(quantized_eval_model, inplace=True)
+            if use_fx_graph_mode_quantization:
+                quantized_eval_model = torch.ao.quantization.quantize_fx.convert_fx(quantized_eval_model)
+            else:
+                torch.ao.quantization.convert(quantized_eval_model, inplace=True)

         print("Evaluate Quantized model")
         evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
@@ -233,6 +269,12 @@ def get_args_parser(add_help=True):
         help="Post training quantize the model",
         action="store_true",
     )
+    parser.add_argument(
+        "--quantization-workflow-type",
+        default="eager_mode_quantization",
+        type=str,
+        help="The quantization workflow type to use, either 'eager_mode_quantization' (default) or 'fx_graph_mode_quantization'",
+    )

     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
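
With this change, enabling the FX workflow would presumably amount to passing --quantization-workflow-type fx_graph_mode_quantization to the training script (assumed here to be torchvision's references/classification/train_quantization.py). Omitting the flag keeps the default eager mode path, and any other value fails the validation added at the top of main().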