Commit 8d872b4

[Quant] Add FX support in quantization examples

Summary: Previously, the quantization examples used only eager mode quantization. This commit adds support for FX graph mode quantization as well.

Test Plan:

```
python train_quantization.py --device="cpu" --post-training-quantize --backend="fbgemm"\
  --model="$MODEL" --weights="IMAGENET1K_V1" --quantization-workflow-type="eager_mode_quantization"

python train_quantization.py --device="cpu" --post-training-quantize --backend="fbgemm"\
  --model="$MODEL" --weights="IMAGENET1K_V1" --quantization-workflow-type="fx_graph_mode_quantization"

python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v2"\
  --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.0001 --weight-decay=0.0001\
  --quantization-workflow-type="eager_mode_quantization"

python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v2"\
  --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.0001 --weight-decay=0.0001\
  --quantization-workflow-type="fx_graph_mode_quantization"

python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v3_large"\
  --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.001 --weight-decay=0.00001\
  --quantization-workflow-type="eager_mode_quantization"

python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v3_large"\
  --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.001 --weight-decay=0.00001\
  --quantization-workflow-type="fx_graph_mode_quantization"
```

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo

ghstack-source-id: 6308d7b
Pull Request resolved: #5797
1 parent 467b841 commit 8d872b4

2 files changed: +57 −15 lines changed

docs/source/models.rst

Lines changed: 6 additions & 6 deletions
```diff
@@ -486,13 +486,13 @@ Model Acc@1 Acc@5
 ================================ ============= =============
 MobileNet V2                     71.658        90.150
 MobileNet V3 Large               73.004        90.858
-ShuffleNet V2 x1.0               68.360        87.582
-ShuffleNet V2 x0.5               57.972        79.780
-ResNet 18                        69.494        88.882
-ResNet 50                        75.920        92.814
-ResNext 101 32x8d                78.986        94.480
+ShuffleNet V2 x1.0               67.886        87.332
+ShuffleNet V2 x0.5               57.784        79.458
+ResNet 18                        69.458        88.902
+ResNet 50                        75.712        92.782
+ResNext 101 32x8d                78.982        94.422
 Inception V3                     77.176        93.354
-GoogleNet                        69.826        89.404
+GoogleNet                        69.598        89.398
 ================================ ============= =============
 
 
```

references/classification/train_quantization.py

Lines changed: 51 additions & 9 deletions
```diff
@@ -2,16 +2,23 @@
 import datetime
 import os
 import time
+from enum import Enum
 
 import torch
 import torch.ao.quantization
+import torch.ao.quantization.quantize_fx
 import torch.utils.data
 import torchvision
 import utils
 from torch import nn
 from train import train_one_epoch, evaluate, load_data
 
 
+class QuantizationWorkflowType(Enum):
+    EAGER_MODE_QUANTIZATION = 1
+    FX_GRAPH_MODE_QUANTIZATION = 2
+
+
 def main(args):
     if args.output_dir:
         utils.mkdir(args.output_dir)
```

```diff
@@ -22,6 +29,17 @@ def main(args):
     if args.post_training_quantize and args.distributed:
         raise RuntimeError("Post training quantization example should not be performed on distributed mode")
 
+    # Validate quantization workflow type
+    quantization_workflow_type = args.quantization_workflow_type.upper()
+    if quantization_workflow_type not in QuantizationWorkflowType.__members__:
+        raise RuntimeError(
+            "Unknown workflow type '%s', please choose from: %s"
+            % (args.quantization_workflow_type, str(tuple([t.lower() for t in QuantizationWorkflowType.__members__])))
+        )
+    use_fx_graph_mode_quantization = (
+        QuantizationWorkflowType[quantization_workflow_type] == QuantizationWorkflowType.FX_GRAPH_MODE_QUANTIZATION
+    )
+
     # Set backend engine to ensure that quantized model runs on the correct kernels
     if args.backend not in torch.backends.quantized.supported_engines:
         raise RuntimeError("Quantized backend not supported: " + str(args.backend))
```

```diff
@@ -45,14 +63,22 @@
     )
 
     print("Creating model", args.model)
+    if use_fx_graph_mode_quantization:
+        model_namespace = torchvision.models
+    else:
+        model_namespace = torchvision.models.quantization
     # when training quantized models, we always start from a pre-trained fp32 reference model
-    model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
+    model = model_namespace.__dict__[args.model](weights=args.weights, quantize=args.test_only)
     model.to(device)
 
     if not (args.test_only or args.post_training_quantize):
-        model.fuse_model(is_qat=True)
-        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
-        torch.ao.quantization.prepare_qat(model, inplace=True)
+        if use_fx_graph_mode_quantization:
+            qconfig_dict = torch.ao.quantization.get_default_qat_qconfig_dict(args.backend)
+            model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=True)
+            model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
+            torch.ao.quantization.prepare_qat(model, inplace=True)
 
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
```

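For the QAT path, the FX branch replaces the manual `fuse_model()` / `qconfig` / `prepare_qat` sequence with `prepare_qat_fx`. Below is a minimal, self-contained sketch of that flow on a toy model (not part of the commit; it uses the two-argument `prepare_qat_fx(model, qconfig_dict)` form the diff relies on, while newer PyTorch releases also expect an `example_inputs` argument, and it assumes the chosen backend's kernels are available):

```python
import torch
from torch import nn
import torch.ao.quantization
from torch.ao.quantization import quantize_fx

backend = "fbgemm"  # assumption: x86 build; the reference script also supports "qnnpack"
torch.backends.quantized.engine = backend

# Toy fp32 model standing in for the torchvision reference model.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
model.train()

# FX graph mode: fusion and fake-quant insertion are derived from the traced
# graph, so no manual fuse_model() or .qconfig assignment is needed.
qconfig_dict = torch.ao.quantization.get_default_qat_qconfig_dict(backend)
model = quantize_fx.prepare_qat_fx(model, qconfig_dict)

# ... the usual QAT training loop would run here ...

model.eval()
quantized_model = quantize_fx.convert_fx(model)
print(quantized_model(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 10])
```
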
```diff
@@ -84,13 +110,20 @@
             ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
         )
         model.eval()
-        model.fuse_model(is_qat=False)
-        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
-        torch.ao.quantization.prepare(model, inplace=True)
+        if use_fx_graph_mode_quantization:
+            qconfig_dict = torch.ao.quantization.get_default_qconfig_dict(args.backend)
+            model = torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_dict)
+        else:
+            model.fuse_model(is_qat=False)
+            model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
+            torch.ao.quantization.prepare(model, inplace=True)
         # Calibrate first
         print("Calibrating")
         evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
-        torch.ao.quantization.convert(model, inplace=True)
+        if use_fx_graph_mode_quantization:
+            model = torch.ao.quantization.quantize_fx.convert_fx(model)
+        else:
+            torch.ao.quantization.convert(model, inplace=True)
         if args.output_dir:
             print("Saving quantized model")
             if utils.is_main_process():
```

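The post-training path is the same prepare → calibrate → convert sequence. A minimal standalone sketch (again not part of the commit; toy model, random calibration batches instead of the script's calibration data loader, and the two-argument `prepare_fx` call used in the diff, which newer PyTorch versions extend with `example_inputs`):

```python
import torch
from torch import nn
import torch.ao.quantization
from torch.ao.quantization import quantize_fx

backend = "fbgemm"  # assumption: x86 build with fbgemm kernels available
torch.backends.quantized.engine = backend

# Toy fp32 model in place of the torchvision classification model.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
model.eval()

qconfig_dict = torch.ao.quantization.get_default_qconfig_dict(backend)
prepared = quantize_fx.prepare_fx(model, qconfig_dict)  # inserts observers into the traced graph

# "Calibrate first": run representative inputs so the observers record activation ranges
# (the reference script does this with evaluate() over a subset of the training set).
with torch.no_grad():
    for _ in range(4):
        prepared(torch.randn(2, 3, 32, 32))

quantized = quantize_fx.convert_fx(prepared)  # returns a new, quantized GraphModule
print(quantized(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 10])
```
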
```diff
@@ -125,7 +158,10 @@ def main(args):
         quantized_eval_model = copy.deepcopy(model_without_ddp)
         quantized_eval_model.eval()
         quantized_eval_model.to(torch.device("cpu"))
-        torch.ao.quantization.convert(quantized_eval_model, inplace=True)
+        if use_fx_graph_mode_quantization:
+            quantized_eval_model = torch.ao.quantization.quantize_fx.convert_fx(quantized_eval_model)
+        else:
+            torch.ao.quantization.convert(quantized_eval_model, inplace=True)
 
         print("Evaluate Quantized model")
         evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
```

```diff
@@ -233,6 +269,12 @@ def get_args_parser(add_help=True):
         help="Post training quantize the model",
         action="store_true",
     )
+    parser.add_argument(
+        "--quantization-workflow-type",
+        default="eager_mode_quantization",
+        type=str,
+        help="The quantization workflow type to use, either 'eager_mode_quantization' (default) or 'fx_graph_mode_quantization'",
+    )
 
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
```

