Skip to content

Commit d9a6386

Browse files
committed
[Quant] Add FX support in quantization examples
Summary: Previously, the quantization examples use only eager mode quantization. This commit adds support for FX mode quantization as well. TODO: provide accuracy comparison. Test Plan: python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' --model='$MODEL' model: $MODEL is one of googlenet, inception_v3, resnet18, resnet50, resnext101_32x8d, shufflenet_v2_x0_5 and shufflenet_v2_x1_0 Reviewers: jerryzh168, vkuzo Subscribers: jerryzh168, vkuzo ghstack-source-id: 7e77c46 Pull Request resolved: #5797
1 parent 467b841 commit d9a6386

File tree

1 file changed

+51
-9
lines changed

1 file changed

+51
-9
lines changed

references/classification/train_quantization.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,23 @@
22
import datetime
33
import os
44
import time
5+
from enum import Enum
56

67
import torch
78
import torch.ao.quantization
9+
import torch.ao.quantization.quantize_fx
810
import torch.utils.data
911
import torchvision
1012
import utils
1113
from torch import nn
1214
from train import train_one_epoch, evaluate, load_data
1315

1416

17+
class QuantizationWorkflowType(Enum):
18+
EAGER_MODE_QUANTIZATION = 1
19+
FX_GRAPH_MODE_QUANTIZATION = 2
20+
21+
1522
def main(args):
1623
if args.output_dir:
1724
utils.mkdir(args.output_dir)
@@ -22,6 +29,15 @@ def main(args):
2229
if args.post_training_quantize and args.distributed:
2330
raise RuntimeError("Post training quantization example should not be performed on distributed mode")
2431

32+
# Validate quantization workflow type
33+
quantization_workflow_type = args.quantization_workflow_type.upper()
34+
if quantization_workflow_type not in QuantizationWorkflowType.__members__:
35+
raise RuntimeError(
36+
"Unknown workflow type '%s', please choose from: %s"
37+
% (args.quantization_workflow_type, str(tuple([t.lower() for t in QuantizationWorkflowType.__members__])))
38+
)
39+
use_fx_graph_mode_quantization = quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE_QUANTIZATION
40+
2541
# Set backend engine to ensure that quantized model runs on the correct kernels
2642
if args.backend not in torch.backends.quantized.supported_engines:
2743
raise RuntimeError("Quantized backend not supported: " + str(args.backend))
@@ -45,14 +61,23 @@ def main(args):
4561
)
4662

4763
print("Creating model", args.model)
64+
if use_fx_graph_mode_quantization:
65+
model_namespace = torchvision.models
66+
else:
67+
model_namespace = torchvision.models.quantization
4868
# when training quantized models, we always start from a pre-trained fp32 reference model
49-
model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
69+
model = model_namespace.__dict__[args.model](weights=args.weights, quantize=args.test_only)
5070
model.to(device)
5171

5272
if not (args.test_only or args.post_training_quantize):
53-
model.fuse_model(is_qat=True)
54-
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
55-
torch.ao.quantization.prepare_qat(model, inplace=True)
73+
qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
74+
if use_fx_graph_mode_quantization:
75+
qconfig_dict = {"": qconfig}
76+
model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
77+
else:
78+
model.fuse_model(is_qat=True)
79+
model.qconfig = qconfig
80+
torch.ao.quantization.prepare_qat(model, inplace=True)
5681

5782
if args.distributed and args.sync_bn:
5883
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -84,13 +109,21 @@ def main(args):
84109
ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
85110
)
86111
model.eval()
87-
model.fuse_model(is_qat=False)
88-
model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
89-
torch.ao.quantization.prepare(model, inplace=True)
112+
qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
113+
if use_fx_graph_mode_quantization:
114+
qconfig_dict = {"": qconfig}
115+
model = torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_dict)
116+
else:
117+
model.fuse_model(is_qat=False)
118+
model.qconfig = qconfig
119+
torch.ao.quantization.prepare(model, inplace=True)
90120
# Calibrate first
91121
print("Calibrating")
92122
evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
93-
torch.ao.quantization.convert(model, inplace=True)
123+
if use_fx_graph_mode_quantization:
124+
model = torch.ao.quantization.quantize_fx.convert_fx(model)
125+
else:
126+
torch.ao.quantization.convert(model, inplace=True)
94127
if args.output_dir:
95128
print("Saving quantized model")
96129
if utils.is_main_process():
@@ -125,7 +158,10 @@ def main(args):
125158
quantized_eval_model = copy.deepcopy(model_without_ddp)
126159
quantized_eval_model.eval()
127160
quantized_eval_model.to(torch.device("cpu"))
128-
torch.ao.quantization.convert(quantized_eval_model, inplace=True)
161+
if use_fx_graph_mode_quantization:
162+
quantized_eval_model = torch.ao.quantization.quantize_fx.convert_fx(quantized_eval_model)
163+
else:
164+
torch.ao.quantization.convert(quantized_eval_model, inplace=True)
129165

130166
print("Evaluate Quantized model")
131167
evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
@@ -233,6 +269,12 @@ def get_args_parser(add_help=True):
233269
help="Post training quantize the model",
234270
action="store_true",
235271
)
272+
parser.add_argument(
273+
"--quantization-workflow-type",
274+
default="eager_mode_quantization",
275+
type=str,
276+
help="The quantization workflow type to use, either 'eager_mode_quantization' (default) or 'fx_graph_mode_quantization'",
277+
)
236278

237279
# distributed training parameters
238280
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")

0 commit comments

Comments
 (0)