From beff9a29fc2cdae9ad8a2d6faa8d8e97e2b3c3b2 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Sat, 9 Apr 2022 20:15:35 -0700
Subject: [PATCH] [Quant] Add FX support in quantization examples

Summary: Previously, the quantization examples used only eager mode
quantization. This commit adds support for FX mode quantization as well.
TODO: provide accuracy comparison.

Test Plan:

python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' --model='$MODEL'

where $MODEL is one of googlenet, inception_v3, resnet18, resnet50,
resnext101_32x8d, shufflenet_v2_x0_5 and shufflenet_v2_x1_0

Reviewers: jerryzh168, vkuzo

Subscribers: jerryzh168, vkuzo

[ghstack-poisoned]
---
 .../classification/train_quantization.py     | 43 +++++++++++++++----
 torchvision/models/quantization/googlenet.py |  3 +-
 torchvision/models/quantization/inception.py |  3 +-
 .../models/quantization/mobilenetv2.py       |  3 +-
 .../models/quantization/mobilenetv3.py       | 19 ++++++--
 torchvision/models/quantization/resnet.py    |  3 +-
 .../models/quantization/shufflenetv2.py      |  3 +-
 torchvision/models/quantization/utils.py     | 25 +++++++----
 8 files changed, 76 insertions(+), 26 deletions(-)

diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py
index c0e5af1dcfc..cc5c1174782 100644
--- a/references/classification/train_quantization.py
+++ b/references/classification/train_quantization.py
@@ -5,6 +5,7 @@
 
 import torch
 import torch.ao.quantization
+import torch.ao.quantization.quantize_fx
 import torch.utils.data
 import torchvision
 import utils
@@ -46,13 +47,19 @@ def main(args):
 
     print("Creating model", args.model)
     # when training quantized models, we always start from a pre-trained fp32 reference model
-    model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
+    model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only, fx_mode=args.fx)
     model.to(device)
 
     if not (args.test_only or args.post_training_quantize):
-        model.fuse_model(is_qat=True)
-        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
-        torch.ao.quantization.prepare_qat(model, inplace=True)
+        qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
+        if args.fx:
+            qconfig_dict = {"": qconfig}
+            model = torch.ao.quantization.quantize_fx.fuse_fx(model)
+            model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
+        else:
+            model.qconfig = qconfig
+            model.fuse_model(is_qat=True)
+            torch.ao.quantization.prepare_qat(model, inplace=True)
 
         if args.distributed and args.sync_bn:
             model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -84,13 +91,22 @@ def main(args):
             ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
         )
         model.eval()
-        model.fuse_model(is_qat=False)
-        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
-        torch.ao.quantization.prepare(model, inplace=True)
+        qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
+        if args.fx:
+            qconfig_dict = {"": qconfig}
+            model = torch.ao.quantization.quantize_fx.fuse_fx(model)
+            model = torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_dict)
+        else:
+            model.qconfig = qconfig
+            model.fuse_model(is_qat=False)
+            torch.ao.quantization.prepare(model, inplace=True)
         # Calibrate first
         print("Calibrating")
         evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
-        torch.ao.quantization.convert(model, inplace=True)
+        if args.fx:
+            model = torch.ao.quantization.quantize_fx.convert_fx(model)
+        else:
+            torch.ao.quantization.convert(model, inplace=True)
         if args.output_dir:
             print("Saving quantized model")
             if utils.is_main_process():
@@ -125,7 +141,10 @@ def main(args):
         quantized_eval_model = copy.deepcopy(model_without_ddp)
         quantized_eval_model.eval()
         quantized_eval_model.to(torch.device("cpu"))
-        torch.ao.quantization.convert(quantized_eval_model, inplace=True)
+        if args.fx:
+            quantized_eval_model = torch.ao.quantization.quantize_fx.convert_fx(quantized_eval_model)
+        else:
+            torch.ao.quantization.convert(quantized_eval_model, inplace=True)
 
         print("Evaluate Quantized model")
         evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
@@ -233,6 +252,12 @@ def get_args_parser(add_help=True):
         help="Post training quantize the model",
         action="store_true",
     )
+    parser.add_argument(
+        "--fx",
+        dest="fx",
+        help="Use FX quantization",
+        action="store_true",
+    )
 
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py
index 1794c834eea..cdfa3db80b3 100644
--- a/torchvision/models/quantization/googlenet.py
+++ b/torchvision/models/quantization/googlenet.py
@@ -170,11 +170,12 @@ def googlenet(
         if "backend" in weights.meta:
             _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
     backend = kwargs.pop("backend", "fbgemm")
+    fx_mode = kwargs.pop("fx_mode", False)
 
     model = QuantizableGoogLeNet(**kwargs)
     _replace_relu(model)
     if quantize:
-        quantize_model(model, backend)
+        quantize_model(model, backend, fx_mode)
 
     if weights is not None:
         model.load_state_dict(weights.get_state_dict(progress=progress))
diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py
index ff5c9a37365..97d5abeac6b 100644
--- a/torchvision/models/quantization/inception.py
+++ b/torchvision/models/quantization/inception.py
@@ -239,11 +239,12 @@ def inception_v3(
         if "backend" in weights.meta:
             _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
     backend = kwargs.pop("backend", "fbgemm")
+    fx_mode = kwargs.pop("fx_mode", False)
 
     model = QuantizableInception3(**kwargs)
     _replace_relu(model)
     if quantize:
-        quantize_model(model, backend)
+        quantize_model(model, backend, fx_mode)
 
     if weights is not None:
         if quantize and not original_aux_logits:
diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py
index d9554e0ba9f..c1114f7b7fb 100644
--- a/torchvision/models/quantization/mobilenetv2.py
+++ b/torchvision/models/quantization/mobilenetv2.py
@@ -125,11 +125,12 @@ def mobilenet_v2(
         if "backend" in weights.meta:
             _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
     backend = kwargs.pop("backend", "qnnpack")
+    fx_mode = kwargs.pop("fx_mode", False)
 
     model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
     _replace_relu(model)
     if quantize:
-        quantize_model(model, backend)
+        quantize_model(model, backend, fx_mode)
 
     if weights is not None:
         model.load_state_dict(weights.get_state_dict(progress=progress))
diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py
index 88907ec210a..4cab30ad907 100644
--- a/torchvision/models/quantization/mobilenetv3.py
+++ b/torchvision/models/quantization/mobilenetv3.py
@@ -4,6 +4,7 @@
 import torch
 from torch import nn, Tensor
 from torch.ao.quantization import QuantStub, DeQuantStub
+import torch.ao.quantization.quantize_fx
 
 from ...ops.misc import Conv2dNormActivation, SqueezeExcitation
 from ...transforms._presets import ImageClassification, InterpolationMode
@@ -135,20 +136,30 @@ def _mobilenet_v3_model(
         if "backend" in weights.meta:
             _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
     backend = kwargs.pop("backend", "qnnpack")
+    fx_mode = kwargs.pop("fx_mode", False)
 
     model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
     _replace_relu(model)
 
     if quantize:
-        model.fuse_model(is_qat=True)
-        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
-        torch.ao.quantization.prepare_qat(model, inplace=True)
+        qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
+        if fx_mode:
+            qconfig_dict = {"": qconfig}
+            model = torch.ao.quantization.quantize_fx.fuse_fx(model)
+            model = torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
+        else:
+            model.qconfig = qconfig
+            model.fuse_model(is_qat=True)
+            torch.ao.quantization.prepare_qat(model, inplace=True)
 
     if weights is not None:
         model.load_state_dict(weights.get_state_dict(progress=progress))
 
     if quantize:
-        torch.ao.quantization.convert(model, inplace=True)
+        if fx_mode:
+            model = torch.ao.quantization.quantize_fx.convert_fx(model)
+        else:
+            torch.ao.quantization.convert(model, inplace=True)
         model.eval()
 
     return model
diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py
index a781f320000..bd0a68c689b 100644
--- a/torchvision/models/quantization/resnet.py
+++ b/torchvision/models/quantization/resnet.py
@@ -134,11 +134,12 @@ def _resnet(
         if "backend" in weights.meta:
             _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
     backend = kwargs.pop("backend", "fbgemm")
+    fx_mode = kwargs.pop("fx_mode", False)
 
     model = QuantizableResNet(block, layers, **kwargs)
     _replace_relu(model)
     if quantize:
-        quantize_model(model, backend)
+        quantize_model(model, backend, fx_mode)
 
     if weights is not None:
         model.load_state_dict(weights.get_state_dict(progress=progress))
diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py
index 1f4f1890e07..6b6cb81c7e0 100644
--- a/torchvision/models/quantization/shufflenetv2.py
+++ b/torchvision/models/quantization/shufflenetv2.py
@@ -89,11 +89,12 @@ def _shufflenetv2(
         if "backend" in weights.meta:
             _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
     backend = kwargs.pop("backend", "fbgemm")
+    fx_mode = kwargs.pop("fx_mode", False)
 
     model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
     _replace_relu(model)
     if quantize:
-        quantize_model(model, backend)
+        quantize_model(model, backend, fx_mode)
 
     if weights is not None:
         model.load_state_dict(weights.get_state_dict(progress=progress))
diff --git a/torchvision/models/quantization/utils.py b/torchvision/models/quantization/utils.py
index a21e2af8e01..27a6f9fb8bc 100644
--- a/torchvision/models/quantization/utils.py
+++ b/torchvision/models/quantization/utils.py
@@ -2,6 +2,7 @@
 
 import torch
 from torch import nn
+import torch.ao.quantization.quantize_fx
 
 
 def _replace_relu(module: nn.Module) -> None:
@@ -18,7 +19,7 @@ def _replace_relu(module: nn.Module) -> None:
             module._modules[key] = value
 
 
-def quantize_model(model: nn.Module, backend: str) -> None:
+def quantize_model(model: nn.Module, backend: str, fx_mode: bool) -> None:
     _dummy_input_data = torch.rand(1, 3, 299, 299)
     if backend not in torch.backends.quantized.supported_engines:
         raise RuntimeError("Quantized backend not supported ")
@@ -26,20 +27,28 @@ def quantize_model(model: nn.Module, backend: str) -> None:
     model.eval()
     # Make sure that weight qconfig matches that of the serialized models
    if backend == "fbgemm":
-        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
+        qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
             activation=torch.ao.quantization.default_observer,
             weight=torch.ao.quantization.default_per_channel_weight_observer,
         )
     elif backend == "qnnpack":
-        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
+        qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
             activation=torch.ao.quantization.default_observer, weight=torch.ao.quantization.default_weight_observer
         )
 
-    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
-    model.fuse_model()  # type: ignore[operator]
-    torch.ao.quantization.prepare(model, inplace=True)
-    model(_dummy_input_data)
-    torch.ao.quantization.convert(model, inplace=True)
+    if fx_mode:
+        qconfig_dict = {"": qconfig}
+        model = torch.ao.quantization.quantize_fx.fuse_fx(model)
+        model = torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_dict)
+        model(_dummy_input_data)
+        model = torch.ao.quantization.quantize_fx.convert_fx(model)
+    else:
+        model.qconfig = qconfig
+        # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
+        model.fuse_model()  # type: ignore[operator]
+        torch.ao.quantization.prepare(model, inplace=True)
+        model(_dummy_input_data)
+        torch.ao.quantization.convert(model, inplace=True)
 
 
 def _fuse_modules(
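
Note (not part of the patch): for readers unfamiliar with FX graph mode quantization, the post-training flow that quantize_model() runs when fx_mode=True reduces to the sketch below. The model choice (a plain torchvision resnet18) and the random calibration input are illustrative only; the quantize_fx calls and the two-argument prepare_fx signature mirror the diff above and assume the PyTorch release this patch targets (newer releases also accept an example_inputs argument).

import torch
import torch.ao.quantization
import torch.ao.quantization.quantize_fx as quantize_fx
import torchvision

# Start from a float model in eval mode; FX rewrites the module, so reassign the result.
model = torchvision.models.resnet18().eval()
qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}

model = quantize_fx.fuse_fx(model)                   # fuse conv/bn/relu patterns
model = quantize_fx.prepare_fx(model, qconfig_dict)  # insert observers
model(torch.rand(1, 3, 224, 224))                    # calibrate (dummy data here)
model = quantize_fx.convert_fx(model)                # produce the quantized model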