Skip to content

Commit 7cfdc44

Browse files
committed
[Quant] Add FX support in quantization examples
Summary: Previously, the quantization examples use only eager mode quantization. This commit adds support for FX mode quantization as well. Test Plan: ``` python train_quantization.py --device="cpu" --post-training-quantize --backend="fbgemm"\ --model="$MODEL" --weights="IMAGENET1K_V1" --quantization-workflow-type="eager_mode_quantization" python train_quantization.py --device="cpu" --post-training-quantize --backend="fbgemm"\ --model="$MODEL" --weights="IMAGENET1K_V1" --quantization-workflow-type="eager_mode_quantization" python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v2"\ --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.0001 --weight-decay=0.0001\ --quantization-workflow-type="eager_mode_quantization" python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v2"\ --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.0001 --weight-decay=0.0001\ --quantization-workflow-type="fx_graph_mode_quantization" python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v3_large"\ --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.001 --weight-decay=0.00001\ --quantization-workflow-type="eager_mode_quantization" python train_quantization.py --device="cuda" --backend="qnnpack" --model="mobilenet_v3_large"\ --epochs=10 --workers=64 --weights="IMAGENET1K_V1" --lr=0.001 --weight-decay=0.00001\ --quantization-workflow-type="fx_graph_mode_quantization" ``` Reviewers: jerryzh168, vkuzo Subscribers: jerryzh168, vkuzo ghstack-source-id: 6308d7b Pull Request resolved: #5797
1 parent 3122ea1 commit 7cfdc44

File tree

9 files changed

+151
-62
lines changed

9 files changed

+151
-62
lines changed

docs/source/models.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -486,13 +486,13 @@ Model Acc@1 Acc@5
486486
================================ ============= =============
487487
MobileNet V2 71.658 90.150
488488
MobileNet V3 Large 73.004 90.858
489-
ShuffleNet V2 x1.0 68.360 87.582
490-
ShuffleNet V2 x0.5 57.972 79.780
491-
ResNet 18 69.494 88.882
492-
ResNet 50 75.920 92.814
493-
ResNext 101 32x8d 78.986 94.480
489+
ShuffleNet V2 x1.0 67.886 87.332
490+
ShuffleNet V2 x0.5 57.784 79.458
491+
ResNet 18 69.458 88.902
492+
ResNet 50 75.712 92.782
493+
ResNext 101 32x8d 78.982 94.422
494494
Inception V3 77.176 93.354
495-
GoogleNet 69.826 89.404
495+
GoogleNet 69.598 89.398
496496
================================ ============= =============
497497

498498

references/classification/train_quantization.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55

66
import torch
77
import torch.ao.quantization
8+
import torch.ao.quantization.quantize_fx
89
import torch.utils.data
910
import torchvision
1011
import utils
1112
from torch import nn
13+
from torchvision.models.quantization.utils import QuantizationWorkflowType
1214
from train import train_one_epoch, evaluate, load_data
1315

1416

@@ -22,6 +24,15 @@ def main(args):
2224
if args.post_training_quantize and args.distributed:
2325
raise RuntimeError("Post training quantization example should not be performed on distributed mode")
2426

27+
# Validate quantization workflow type
28+
all_quantization_workflow_types = [t.value for t in QuantizationWorkflowType]
29+
if args.quantization_workflow_type not in all_quantization_workflow_types:
30+
raise RuntimeError(
31+
"Unknown quantization workflow type '%s', must be one of: %s"
32+
% (args.quantization_workflow_type, all_quantization_workflow_types)
33+
)
34+
quantization_workflow_type = QuantizationWorkflowType(args.quantization_workflow_type)
35+
2536
# Set backend engine to ensure that quantized model runs on the correct kernels
2637
if args.backend not in torch.backends.quantized.supported_engines:
2738
raise RuntimeError("Quantized backend not supported: " + str(args.backend))
@@ -46,13 +57,21 @@ def main(args):
4657

4758
print("Creating model", args.model)
4859
# when training quantized models, we always start from a pre-trained fp32 reference model
49-
model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
60+
model = torchvision.models.quantization.__dict__[args.model](
61+
weights=args.weights,
62+
quantize=args.test_only,
63+
quantization_workflow_type=quantization_workflow_type,
64+
)
5065
model.to(device)
5166

5267
if not (args.test_only or args.post_training_quantize):
53-
model.fuse_model(is_qat=True)
54-
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
55-
torch.ao.quantization.prepare_qat(model, inplace=True)
68+
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
69+
qconfig_dict = torch.ao.quantization.get_default_qat_qconfig_dict(args.backend)
70+
model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
71+
else:
72+
model.fuse_model(is_qat=True)
73+
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
74+
torch.ao.quantization.prepare_qat(model, inplace=True)
5675

5776
if args.distributed and args.sync_bn:
5877
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -84,13 +103,20 @@ def main(args):
84103
ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
85104
)
86105
model.eval()
87-
model.fuse_model(is_qat=False)
88-
model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
89-
torch.ao.quantization.prepare(model, inplace=True)
106+
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
107+
qconfig_dict = torch.ao.quantization.get_default_qconfig_dict(args.backend)
108+
model = torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_dict)
109+
else:
110+
model.fuse_model(is_qat=False)
111+
model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
112+
torch.ao.quantization.prepare(model, inplace=True)
90113
# Calibrate first
91114
print("Calibrating")
92115
evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
93-
torch.ao.quantization.convert(model, inplace=True)
116+
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
117+
model = torch.ao.quantization.quantize_fx.convert_fx(model)
118+
else:
119+
torch.ao.quantization.convert(model, inplace=True)
94120
if args.output_dir:
95121
print("Saving quantized model")
96122
if utils.is_main_process():
@@ -125,7 +151,10 @@ def main(args):
125151
quantized_eval_model = copy.deepcopy(model_without_ddp)
126152
quantized_eval_model.eval()
127153
quantized_eval_model.to(torch.device("cpu"))
128-
torch.ao.quantization.convert(quantized_eval_model, inplace=True)
154+
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
155+
quantized_eval_model = torch.ao.quantization.quantize_fx.convert_fx(quantized_eval_model)
156+
else:
157+
torch.ao.quantization.convert(quantized_eval_model, inplace=True)
129158

130159
print("Evaluate Quantized model")
131160
evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
@@ -233,6 +262,12 @@ def get_args_parser(add_help=True):
233262
help="Post training quantize the model",
234263
action="store_true",
235264
)
265+
parser.add_argument(
266+
"--quantization-workflow-type",
267+
default="eager_mode",
268+
type=str,
269+
help="The quantization workflow type to use, either 'eager_mode' (default) or 'fx_graph_mode'",
270+
)
236271

237272
# distributed training parameters
238273
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")

torchvision/models/quantization/googlenet.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .._meta import _IMAGENET_CATEGORIES
1313
from .._utils import handle_legacy_interface, _ovewrite_named_param
1414
from ..googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, GoogLeNet_Weights
15-
from .utils import _fuse_modules, _replace_relu, quantize_model
15+
from .utils import _fuse_modules, _replace_relu, quantize_model, QuantizationWorkflowType
1616

1717

1818
__all__ = [
@@ -170,11 +170,16 @@ def googlenet(
170170
if "backend" in weights.meta:
171171
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
172172
backend = kwargs.pop("backend", "fbgemm")
173+
quantization_workflow_type = kwargs.pop("quantization_workflow_type", QuantizationWorkflowType.EAGER_MODE)
174+
175+
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
176+
model = GoogLeNet(**kwargs)
177+
else:
178+
model = QuantizableGoogLeNet(**kwargs)
179+
_replace_relu(model)
173180

174-
model = QuantizableGoogLeNet(**kwargs)
175-
_replace_relu(model)
176181
if quantize:
177-
quantize_model(model, backend)
182+
model = quantize_model(model, backend, quantization_workflow_type)
178183

179184
if weights is not None:
180185
model.load_state_dict(weights.get_state_dict(progress=progress))

torchvision/models/quantization/inception.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
import torch.nn.functional as F
88
from torch import Tensor
99
from torchvision.models import inception as inception_module
10-
from torchvision.models.inception import InceptionOutputs, Inception_V3_Weights
10+
from torchvision.models.inception import Inception3, InceptionOutputs, Inception_V3_Weights
1111

1212
from ...transforms._presets import ImageClassification, InterpolationMode
1313
from .._api import WeightsEnum, Weights
1414
from .._meta import _IMAGENET_CATEGORIES
1515
from .._utils import handle_legacy_interface, _ovewrite_named_param
16-
from .utils import _fuse_modules, _replace_relu, quantize_model
16+
from .utils import _fuse_modules, _replace_relu, quantize_model, QuantizationWorkflowType
1717

1818

1919
__all__ = [
@@ -239,11 +239,16 @@ def inception_v3(
239239
if "backend" in weights.meta:
240240
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
241241
backend = kwargs.pop("backend", "fbgemm")
242+
quantization_workflow_type = kwargs.pop("quantization_workflow_type", QuantizationWorkflowType.EAGER_MODE)
243+
244+
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
245+
model = Inception3(**kwargs)
246+
else:
247+
model = QuantizableInception3(**kwargs)
248+
_replace_relu(model)
242249

243-
model = QuantizableInception3(**kwargs)
244-
_replace_relu(model)
245250
if quantize:
246-
quantize_model(model, backend)
251+
model = quantize_model(model, backend, quantization_workflow_type)
247252

248253
if weights is not None:
249254
if quantize and not original_aux_logits:

torchvision/models/quantization/mobilenetv2.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from .._api import WeightsEnum, Weights
1212
from .._meta import _IMAGENET_CATEGORIES
1313
from .._utils import handle_legacy_interface, _ovewrite_named_param
14-
from .utils import _fuse_modules, _replace_relu, quantize_model
14+
from .utils import _fuse_modules, _replace_relu, quantize_model, QuantizationWorkflowType
1515

1616

1717
__all__ = [
@@ -125,11 +125,16 @@ def mobilenet_v2(
125125
if "backend" in weights.meta:
126126
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
127127
backend = kwargs.pop("backend", "qnnpack")
128+
quantization_workflow_type = kwargs.pop("quantization_workflow_type", QuantizationWorkflowType.EAGER_MODE)
129+
130+
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
131+
model = MobileNetV2(**kwargs)
132+
else:
133+
model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
134+
_replace_relu(model)
128135

129-
model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
130-
_replace_relu(model)
131136
if quantize:
132-
quantize_model(model, backend)
137+
model = quantize_model(model, backend, quantization_workflow_type)
133138

134139
if weights is not None:
135140
model.load_state_dict(weights.get_state_dict(progress=progress))

torchvision/models/quantization/mobilenetv3.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any, List, Optional, Union
33

44
import torch
5+
import torch.ao.quantization.quantize_fx
56
from torch import nn, Tensor
67
from torch.ao.quantization import QuantStub, DeQuantStub
78

@@ -17,7 +18,7 @@
1718
_mobilenet_v3_conf,
1819
MobileNet_V3_Large_Weights,
1920
)
20-
from .utils import _fuse_modules, _replace_relu
21+
from .utils import _fuse_modules, _replace_relu, QuantizationWorkflowType
2122

2223

2324
__all__ = [
@@ -135,20 +136,32 @@ def _mobilenet_v3_model(
135136
if "backend" in weights.meta:
136137
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
137138
backend = kwargs.pop("backend", "qnnpack")
139+
quantization_workflow_type = kwargs.pop("quantization_workflow_type", QuantizationWorkflowType.EAGER_MODE)
138140

139-
model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
140-
_replace_relu(model)
141+
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
142+
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
143+
else:
144+
model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
145+
_replace_relu(model)
141146

142147
if quantize:
143-
model.fuse_model(is_qat=True)
144-
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
145-
torch.ao.quantization.prepare_qat(model, inplace=True)
148+
# TODO: This shouldn't be QAT?
149+
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
150+
qconfig_dict = torch.ao.quantization.get_default_qat_qconfig_dict(backend)
151+
model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
152+
else:
153+
model.fuse_model(is_qat=True)
154+
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
155+
torch.ao.quantization.prepare_qat(model, inplace=True)
146156

147157
if weights is not None:
148158
model.load_state_dict(weights.get_state_dict(progress=progress))
149159

150160
if quantize:
151-
torch.ao.quantization.convert(model, inplace=True)
161+
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
162+
model = torch.ao.quantization.quantize_fx.convert_fx(model)
163+
else:
164+
torch.ao.quantization.convert(model, inplace=True)
152165
model.eval()
153166

154167
return model

torchvision/models/quantization/resnet.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .._api import WeightsEnum, Weights
1818
from .._meta import _IMAGENET_CATEGORIES
1919
from .._utils import handle_legacy_interface, _ovewrite_named_param
20-
from .utils import _fuse_modules, _replace_relu, quantize_model
20+
from .utils import _fuse_modules, _replace_relu, quantize_model, QuantizationWorkflowType
2121

2222

2323
__all__ = [
@@ -134,11 +134,16 @@ def _resnet(
134134
if "backend" in weights.meta:
135135
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
136136
backend = kwargs.pop("backend", "fbgemm")
137+
quantization_workflow_type = kwargs.pop("quantization_workflow_type", QuantizationWorkflowType.EAGER_MODE)
138+
139+
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
140+
model = ResNet(block, layers, **kwargs)
141+
else:
142+
model = QuantizableResNet(block, layers, **kwargs)
143+
_replace_relu(model)
137144

138-
model = QuantizableResNet(block, layers, **kwargs)
139-
_replace_relu(model)
140145
if quantize:
141-
quantize_model(model, backend)
146+
model = quantize_model(model, backend, quantization_workflow_type)
142147

143148
if weights is not None:
144149
model.load_state_dict(weights.get_state_dict(progress=progress))

torchvision/models/quantization/shufflenetv2.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from .._api import WeightsEnum, Weights
1111
from .._meta import _IMAGENET_CATEGORIES
1212
from .._utils import handle_legacy_interface, _ovewrite_named_param
13-
from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights
14-
from .utils import _fuse_modules, _replace_relu, quantize_model
13+
from ..shufflenetv2 import ShuffleNetV2, ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights
14+
from .utils import _fuse_modules, _replace_relu, quantize_model, QuantizationWorkflowType
1515

1616

1717
__all__ = [
@@ -40,7 +40,7 @@ def forward(self, x: Tensor) -> Tensor:
4040
return out
4141

4242

43-
class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
43+
class QuantizableShuffleNetV2(ShuffleNetV2):
4444
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
4545
def __init__(self, *args: Any, **kwargs: Any) -> None:
4646
super().__init__(*args, inverted_residual=QuantizableInvertedResidual, **kwargs) # type: ignore[misc]
@@ -89,11 +89,16 @@ def _shufflenetv2(
8989
if "backend" in weights.meta:
9090
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
9191
backend = kwargs.pop("backend", "fbgemm")
92+
quantization_workflow_type = kwargs.pop("quantization_workflow_type", QuantizationWorkflowType.EAGER_MODE)
93+
94+
if quantization_workflow_type == QuantizationWorkflowType.FX_GRAPH_MODE:
95+
model = ShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
96+
else:
97+
model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
98+
_replace_relu(model)
9299

93-
model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
94-
_replace_relu(model)
95100
if quantize:
96-
quantize_model(model, backend)
101+
model = quantize_model(model, backend, quantization_workflow_type)
97102

98103
if weights is not None:
99104
model.load_state_dict(weights.get_state_dict(progress=progress))

0 commit comments

Comments
 (0)