diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py
index f9c2f351af1..8b41bd00bb0 100644
--- a/torchvision/models/quantization/inception.py
+++ b/torchvision/models/quantization/inception.py
@@ -24,72 +24,6 @@
 }
 
 
-def inception_v3(
-    pretrained: bool = False,
-    progress: bool = True,
-    quantize: bool = False,
-    **kwargs: Any,
-) -> "QuantizableInception3":
-
-    r"""Inception v3 model architecture from
-    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
-
-    .. note::
-        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
-        N x 3 x 299 x 299, so ensure your images are sized accordingly.
-
-    Note that quantize = True returns a quantized model with 8 bit
-    weights. Quantized models only support inference and run on CPUs.
-    GPU inference is not yet supported
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-        progress (bool): If True, displays a progress bar of the download to stderr
-        quantize (bool): If True, return a quantized version of the model
-        aux_logits (bool): If True, add an auxiliary branch that can improve training.
-            Default: *True*
-        transform_input (bool): If True, preprocesses the input according to the method with which it
-            was trained on ImageNet. Default: *False*
-    """
-    if pretrained:
-        if "transform_input" not in kwargs:
-            kwargs["transform_input"] = True
-        if "aux_logits" in kwargs:
-            original_aux_logits = kwargs["aux_logits"]
-            kwargs["aux_logits"] = True
-        else:
-            original_aux_logits = False
-
-    model = QuantizableInception3(**kwargs)
-    _replace_relu(model)
-
-    if quantize:
-        # TODO use pretrained as a string to specify the backend
-        backend = "fbgemm"
-        quantize_model(model, backend)
-    else:
-        assert pretrained in [True, False]
-
-    if pretrained:
-        if quantize:
-            if not original_aux_logits:
-                model.aux_logits = False
-                model.AuxLogits = None
-            model_url = quant_model_urls["inception_v3_google_" + backend]
-        else:
-            model_url = inception_module.model_urls["inception_v3_google"]
-
-        state_dict = load_state_dict_from_url(model_url, progress=progress)
-
-        model.load_state_dict(state_dict)
-
-        if not quantize:
-            if not original_aux_logits:
-                model.aux_logits = False
-                model.AuxLogits = None
-    return model
-
-
 class QuantizableBasicConv2d(inception_module.BasicConv2d):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
@@ -237,3 +171,68 @@ def fuse_model(self) -> None:
         for m in self.modules():
             if type(m) is QuantizableBasicConv2d:
                 m.fuse_model()
+
+
+def inception_v3(
+    pretrained: bool = False,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableInception3:
+    r"""Inception v3 model architecture from
+    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
+
+    .. note::
+        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
+        N x 3 x 299 x 299, so ensure your images are sized accordingly.
+
+    Note that quantize = True returns a quantized model with 8 bit
+    weights. Quantized models only support inference and run on CPUs.
+    GPU inference is not yet supported
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+        quantize (bool): If True, return a quantized version of the model
+        aux_logits (bool): If True, add an auxiliary branch that can improve training.
+            Default: *True*
+        transform_input (bool): If True, preprocesses the input according to the method with which it
+            was trained on ImageNet. Default: *False*
+    """
+    if pretrained:
+        if "transform_input" not in kwargs:
+            kwargs["transform_input"] = True
+        if "aux_logits" in kwargs:
+            original_aux_logits = kwargs["aux_logits"]
+            kwargs["aux_logits"] = True
+        else:
+            original_aux_logits = False
+
+    model = QuantizableInception3(**kwargs)
+    _replace_relu(model)
+
+    if quantize:
+        # TODO use pretrained as a string to specify the backend
+        backend = "fbgemm"
+        quantize_model(model, backend)
+    else:
+        assert pretrained in [True, False]
+
+    if pretrained:
+        if quantize:
+            if not original_aux_logits:
+                model.aux_logits = False
+                model.AuxLogits = None
+            model_url = quant_model_urls["inception_v3_google_" + backend]
+        else:
+            model_url = inception_module.model_urls["inception_v3_google"]
+
+        state_dict = load_state_dict_from_url(model_url, progress=progress)
+
+        model.load_state_dict(state_dict)
+
+        if not quantize:
+            if not original_aux_logits:
+                model.aux_logits = False
+                model.AuxLogits = None
+    return model
diff --git a/torchvision/prototype/models/inception.py b/torchvision/prototype/models/inception.py
index 9de499bed2c..daac42a3d3f 100644
--- a/torchvision/prototype/models/inception.py
+++ b/torchvision/prototype/models/inception.py
@@ -10,10 +10,10 @@
 from ._meta import _IMAGENET_CATEGORIES
 
 
-__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception3Weights", "inception_v3"]
+__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "InceptionV3Weights", "inception_v3"]
 
 
-class Inception3Weights(Weights):
+class InceptionV3Weights(Weights):
     ImageNet1K_TFV1 = WeightEntry(
         url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
         transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
@@ -28,11 +28,11 @@ class Inception3Weights(Weights):
     )
 
 
-def inception_v3(weights: Optional[Inception3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
+def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
     if "pretrained" in kwargs:
         warnings.warn("The argument pretrained is deprecated, please use weights instead.")
-        weights = Inception3Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
-    weights = Inception3Weights.verify(weights)
+        weights = InceptionV3Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
+    weights = InceptionV3Weights.verify(weights)
 
     original_aux_logits = kwargs.get("aux_logits", True)
     if weights is not None:
diff --git a/torchvision/prototype/models/quantization/__init__.py b/torchvision/prototype/models/quantization/__init__.py
index e82fed54a9c..92bf41ed968 100644
--- a/torchvision/prototype/models/quantization/__init__.py
+++ b/torchvision/prototype/models/quantization/__init__.py
@@ -1,2 +1,3 @@
 from .googlenet import *
+from .inception import *
 from .resnet import *
diff --git a/torchvision/prototype/models/quantization/inception.py b/torchvision/prototype/models/quantization/inception.py
new file mode 100644
index 00000000000..a783f33d177
--- /dev/null
+++ b/torchvision/prototype/models/quantization/inception.py
@@ -0,0 +1,87 @@
+import warnings
+from functools import partial
+from typing import Any, Optional, Union
+
+from torchvision.transforms.functional import InterpolationMode
+
+from ....models.quantization.inception import (
+    QuantizableInception3,
+    _replace_relu,
+    quantize_model,
+)
+from ...transforms.presets import ImageNetEval
+from .._api import Weights, WeightEntry
+from .._meta import _IMAGENET_CATEGORIES
+from ..inception import InceptionV3Weights
+
+
+__all__ = [
+    "QuantizableInception3",
+    "QuantizedInceptionV3Weights",
+    "inception_v3",
+]
+
+
+class QuantizedInceptionV3Weights(Weights):
+    ImageNet1K_FBGEMM_TFV1 = WeightEntry(
+        url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth",
+        transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
+        meta={
+            "size": (299, 299),
+            "categories": _IMAGENET_CATEGORIES,
+            "interpolation": InterpolationMode.BILINEAR,
+            "backend": "fbgemm",
+            "quantization": "ptq",
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
+            "unquantized": InceptionV3Weights.ImageNet1K_TFV1,
+            "acc@1": 77.176,
+            "acc@5": 93.354,
+        },
+    )
+
+
+def inception_v3(
+    weights: Optional[Union[QuantizedInceptionV3Weights, InceptionV3Weights]] = None,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableInception3:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        if kwargs.pop("pretrained"):
+            weights = (
+                QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1 if quantize else InceptionV3Weights.ImageNet1K_TFV1
+            )
+        else:
+            weights = None
+
+    if quantize:
+        weights = QuantizedInceptionV3Weights.verify(weights)
+    else:
+        weights = InceptionV3Weights.verify(weights)
+
+    original_aux_logits = kwargs.get("aux_logits", False)
+    if weights is not None:
+        if "transform_input" not in kwargs:
+            kwargs["transform_input"] = True
+        kwargs["aux_logits"] = True
+        kwargs["num_classes"] = len(weights.meta["categories"])
+        if "backend" in weights.meta:
+            kwargs["backend"] = weights.meta["backend"]
+    backend = kwargs.pop("backend", "fbgemm")
+
+    model = QuantizableInception3(**kwargs)
+    _replace_relu(model)
+    if quantize:
+        quantize_model(model, backend)
+
+    if weights is not None:
+        if quantize and not original_aux_logits:
+            model.aux_logits = False
+            model.AuxLogits = None
+        model.load_state_dict(weights.state_dict(progress=progress))
+        if not quantize and not original_aux_logits:
+            model.aux_logits = False
+            model.AuxLogits = None
+
+    return model
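
Below is a minimal usage sketch, not part of the diff above, showing how the two builders touched by this change would typically be called. It assumes a torchvision build in which both the stable torchvision.models.quantization.inception_v3 and the prototype torchvision.prototype.models.quantization.inception_v3 introduced here are importable; the import paths and weight-enum names come from the diff, everything else is illustrative.

import torch
from torchvision.models.quantization import inception_v3 as stable_inception_v3
from torchvision.prototype.models.quantization import (
    QuantizedInceptionV3Weights,
    inception_v3 as prototype_inception_v3,
)

# Stable API: boolean flags select pretrained weights and the quantized variant.
stable_model = stable_inception_v3(pretrained=True, quantize=True)
stable_model.eval()

# Prototype API: an explicit weights enum replaces the deprecated `pretrained` flag.
proto_model = prototype_inception_v3(
    weights=QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1,
    quantize=True,
)
proto_model.eval()

# Inception v3 expects N x 3 x 299 x 299 inputs; quantized models run on CPU only.
batch = torch.rand(1, 3, 299, 299)
with torch.inference_mode():
    logits = proto_model(batch)
print(logits.shape)  # expected: torch.Size([1, 1000])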