 import datetime
 import os
 import time
+from enum import Enum
 
 import torch
 import torch.ao.quantization
+import torch.ao.quantization.quantize_fx
 import torch.utils.data
 import torchvision
 import utils
 from torch import nn
 from train import train_one_epoch, evaluate, load_data
 
 
+class QuantizationWorkflowType(Enum):
+    EAGER_MODE_QUANTIZATION = 1
+    FX_GRAPH_MODE_QUANTIZATION = 2
+
+
 def main(args):
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -22,6 +29,17 @@ def main(args):
     if args.post_training_quantize and args.distributed:
         raise RuntimeError("Post training quantization example should not be performed on distributed mode")
 
+    # Validate quantization workflow type
+    quantization_workflow_type = args.quantization_workflow_type.upper()
+    if quantization_workflow_type not in QuantizationWorkflowType.__members__:
+        raise RuntimeError(
+            "Unknown workflow type '%s', please choose from: %s"
+            % (args.quantization_workflow_type, str(tuple([t.lower() for t in QuantizationWorkflowType.__members__])))
+        )
+    use_fx_graph_mode_quantization = (
+        QuantizationWorkflowType[quantization_workflow_type] == QuantizationWorkflowType.FX_GRAPH_MODE_QUANTIZATION
+    )
+
     # Set backend engine to ensure that quantized model runs on the correct kernels
     if args.backend not in torch.backends.quantized.supported_engines:
         raise RuntimeError("Quantized backend not supported: " + str(args.backend))
@@ -45,14 +63,22 @@ def main(args):
     )
 
     print("Creating model", args.model)
+    if use_fx_graph_mode_quantization:
+        model_namespace = torchvision.models
+    else:
+        model_namespace = torchvision.models.quantization
     # when training quantized models, we always start from a pre-trained fp32 reference model
-    model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
+    model = model_namespace.__dict__[args.model](weights=args.weights, quantize=args.test_only)
     model.to(device)
 
     if not (args.test_only or args.post_training_quantize):
-        model.fuse_model(is_qat=True)
-        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
-        torch.ao.quantization.prepare_qat(model, inplace=True)
+        if use_fx_graph_mode_quantization:
+            qconfig_dict = torch.ao.quantization.get_default_qat_qconfig_dict(args.backend)
+            model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=True)
+            model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
+            torch.ao.quantization.prepare_qat(model, inplace=True)
 
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -84,13 +110,20 @@ def main(args):
             ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
         )
         model.eval()
-        model.fuse_model(is_qat=False)
-        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
-        torch.ao.quantization.prepare(model, inplace=True)
+        if use_fx_graph_mode_quantization:
+            qconfig_dict = torch.ao.quantization.get_default_qconfig_dict(args.backend)
+            model = torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=False)
+            model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
+            torch.ao.quantization.prepare(model, inplace=True)
         # Calibrate first
         print("Calibrating")
         evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
-        torch.ao.quantization.convert(model, inplace=True)
+        if use_fx_graph_mode_quantization:
+            model = torch.ao.quantization.quantize_fx.convert_fx(model)
+        else:
+            torch.ao.quantization.convert(model, inplace=True)
         if args.output_dir:
             print("Saving quantized model")
             if utils.is_main_process():
@@ -125,7 +158,10 @@ def main(args):
            quantized_eval_model = copy.deepcopy(model_without_ddp)
            quantized_eval_model.eval()
            quantized_eval_model.to(torch.device("cpu"))
-           torch.ao.quantization.convert(quantized_eval_model, inplace=True)
+           if use_fx_graph_mode_quantization:
+               quantized_eval_model = torch.ao.quantization.quantize_fx.convert_fx(quantized_eval_model)
+           else:
+               torch.ao.quantization.convert(quantized_eval_model, inplace=True)
 
            print("Evaluate Quantized model")
            evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
@@ -233,6 +269,12 @@ def get_args_parser(add_help=True):
         help="Post training quantize the model",
         action="store_true",
     )
+    parser.add_argument(
+        "--quantization-workflow-type",
+        default="eager_mode_quantization",
+        type=str,
+        help="The quantization workflow type to use, either 'eager_mode_quantization' (default) or 'fx_graph_mode_quantization'",
+    )
 
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
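For context, the two workflows this patch switches between can be read as the standalone sketch below (not part of the PR). It mirrors the calls used in the diff, including the two-argument prepare_qat_fx/prepare_fx form and get_default_qat_qconfig_dict, whose exact signatures depend on the PyTorch release this change targets; resnet18 and the fbgemm backend are placeholder choices, not something the patch fixes.

import torch
import torch.ao.quantization
import torch.ao.quantization.quantize_fx
import torchvision

backend = "fbgemm"  # placeholder backend engine
use_fx_graph_mode_quantization = True

if use_fx_graph_mode_quantization:
    # FX graph mode QAT operates on the plain torchvision model;
    # operator fusion and observer insertion are handled by the FX passes.
    model = torchvision.models.resnet18(weights=None)
    qconfig_dict = torch.ao.quantization.get_default_qat_qconfig_dict(backend)
    model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
else:
    # Eager mode QAT needs the quantization-aware model variant, an explicit
    # fuse_model() call and a manually assigned qconfig.
    model = torchvision.models.quantization.resnet18(weights=None, quantize=False)
    model.fuse_model(is_qat=True)
    model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
    torch.ao.quantization.prepare_qat(model, inplace=True)

# ... the QAT training loop would run here ...

model.eval()
if use_fx_graph_mode_quantization:
    model = torch.ao.quantization.quantize_fx.convert_fx(model)
else:
    torch.ao.quantization.convert(model, inplace=True)

With the new argument, the FX path would be selected from the command line via --quantization-workflow-type fx_graph_mode_quantization, while eager_mode_quantization remains the default.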