From 60858f8e49e929e4f46502215a18b0fe39fe0b57 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 15 Oct 2021 14:29:49 +0100 Subject: [PATCH 1/2] Fixing minor issue on typing. --- torchvision/models/quantization/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index 596ae56d85b..65a05d8558b 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -110,7 +110,7 @@ def fuse_model(self) -> None: def _resnet( arch: str, - block: Type[Union[BasicBlock, Bottleneck]], + block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]], layers: List[int], pretrained: bool, progress: bool, From 81e1f14eb3ca9e718e37d12308b6d9a66e63e50c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 15 Oct 2021 15:09:48 +0100 Subject: [PATCH 2/2] Sample implementation for quantized resnet50. --- torchvision/prototype/models/__init__.py | 1 + .../prototype/models/quantization/__init__.py | 1 + .../prototype/models/quantization/resnet.py | 84 +++++++++++++++++++ 3 files changed, 86 insertions(+) create mode 100644 torchvision/prototype/models/quantization/__init__.py create mode 100644 torchvision/prototype/models/quantization/resnet.py diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py index 2879a37dbf9..c77d27e8009 100644 --- a/torchvision/prototype/models/__init__.py +++ b/torchvision/prototype/models/__init__.py @@ -1,2 +1,3 @@ from .resnet import * from . import detection +from . import quantization diff --git a/torchvision/prototype/models/quantization/__init__.py b/torchvision/prototype/models/quantization/__init__.py new file mode 100644 index 00000000000..b792ca6ecf7 --- /dev/null +++ b/torchvision/prototype/models/quantization/__init__.py @@ -0,0 +1 @@ +from .resnet import * diff --git a/torchvision/prototype/models/quantization/resnet.py b/torchvision/prototype/models/quantization/resnet.py new file mode 100644 index 00000000000..0f68a994c84 --- /dev/null +++ b/torchvision/prototype/models/quantization/resnet.py @@ -0,0 +1,84 @@ +import warnings +from functools import partial +from typing import Any, List, Optional, Type, Union + +from ....models.quantization.resnet import ( + QuantizableBasicBlock, + QuantizableBottleneck, + QuantizableResNet, + _replace_relu, + quantize_model, +) +from ...transforms.presets import ImageNetEval +from .._api import Weights, WeightEntry +from .._meta import _IMAGENET_CATEGORIES +from ..resnet import ResNet50Weights + + +__all__ = ["QuantizableResNet", "QuantizedResNet50Weights", "resnet50"] + + +def _resnet( + block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]], + layers: List[int], + weights: Optional[Weights], + progress: bool, + quantize: bool, + **kwargs: Any, +) -> QuantizableResNet: + if weights is not None: + kwargs["num_classes"] = len(weights.meta["categories"]) + if "backend" in weights.meta: + kwargs["backend"] = weights.meta["backend"] + backend = kwargs.pop("backend", "fbgemm") + + model = QuantizableResNet(block, layers, **kwargs) + _replace_relu(model) + if quantize: + quantize_model(model, backend) + + if weights is not None: + model.load_state_dict(weights.state_dict(progress=progress)) + + return model + + +_common_meta = { + "size": (224, 224), + "categories": _IMAGENET_CATEGORIES, + "backend": "fbgemm", +} + + +class QuantizedResNet50Weights(Weights): + ImageNet1K_FBGEMM_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#quantized", + "acc@1": 75.920, + "acc@5": 92.814, + }, + ) + + +def resnet50( + weights: Optional[Union[QuantizedResNet50Weights, ResNet50Weights]] = None, + progress: bool = True, + quantize: bool = False, + **kwargs: Any, +) -> QuantizableResNet: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + if kwargs.pop("pretrained"): + weights = QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet50Weights.ImageNet1K_RefV1 + else: + weights = None + + if quantize: + weights = QuantizedResNet50Weights.verify(weights) + else: + weights = ResNet50Weights.verify(weights) + + return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)