From bbe6c8965ede714db6eaa92f0241fabf8601fcdf Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Mon, 1 Nov 2021 17:25:57 +0000
Subject: [PATCH 1/3] Moving the original builder to the bottom of the file to
 use proper typing.

---
 torchvision/models/inception.py | 74 ++++++++++++++++----------------
 1 file changed, 37 insertions(+), 37 deletions(-)

diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py
index d63f94119e1..322c2370bdd 100644
--- a/torchvision/models/inception.py
+++ b/torchvision/models/inception.py
@@ -26,43 +26,6 @@
 _InceptionOutputs = InceptionOutputs
 
 
-def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":
-    r"""Inception v3 model architecture from
-    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
-    The required minimum input size of the model is 75x75.
-
-    .. note::
-        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
-        N x 3 x 299 x 299, so ensure your images are sized accordingly.
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-        progress (bool): If True, displays a progress bar of the download to stderr
-        aux_logits (bool): If True, add an auxiliary branch that can improve training.
-            Default: *True*
-        transform_input (bool): If True, preprocesses the input according to the method with which it
-            was trained on ImageNet. Default: *False*
-    """
-    if pretrained:
-        if "transform_input" not in kwargs:
-            kwargs["transform_input"] = True
-        if "aux_logits" in kwargs:
-            original_aux_logits = kwargs["aux_logits"]
-            kwargs["aux_logits"] = True
-        else:
-            original_aux_logits = True
-        kwargs["init_weights"] = False  # we are loading weights from a pretrained model
-        model = Inception3(**kwargs)
-        state_dict = load_state_dict_from_url(model_urls["inception_v3_google"], progress=progress)
-        model.load_state_dict(state_dict)
-        if not original_aux_logits:
-            model.aux_logits = False
-            model.AuxLogits = None
-        return model
-
-    return Inception3(**kwargs)
-
-
 class Inception3(nn.Module):
     def __init__(
         self,
@@ -442,3 +405,40 @@ def forward(self, x: Tensor) -> Tensor:
         x = self.conv(x)
         x = self.bn(x)
         return F.relu(x, inplace=True)
+
+
+def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> Inception3:
+    r"""Inception v3 model architecture from
+    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
+    The required minimum input size of the model is 75x75.
+
+    .. note::
+        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
+        N x 3 x 299 x 299, so ensure your images are sized accordingly.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+        aux_logits (bool): If True, add an auxiliary branch that can improve training.
+            Default: *True*
+        transform_input (bool): If True, preprocesses the input according to the method with which it
+            was trained on ImageNet. Default: *False*
+    """
+    if pretrained:
+        if "transform_input" not in kwargs:
+            kwargs["transform_input"] = True
+        if "aux_logits" in kwargs:
+            original_aux_logits = kwargs["aux_logits"]
+            kwargs["aux_logits"] = True
+        else:
+            original_aux_logits = True
+        kwargs["init_weights"] = False  # we are loading weights from a pretrained model
+        model = Inception3(**kwargs)
+        state_dict = load_state_dict_from_url(model_urls["inception_v3_google"], progress=progress)
+        model.load_state_dict(state_dict)
+        if not original_aux_logits:
+            model.aux_logits = False
+            model.AuxLogits = None
+        return model
+
+    return Inception3(**kwargs)
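Reviewer note: the move above is behavior-preserving; only the return-type annotation changes from the string forward reference "Inception3" to the real class, which is now defined above the builder. A minimal sketch of how the legacy builder is exercised (standard torchvision usage, not part of the patch):

```python
import torch
from torchvision.models import inception_v3

# pretrained=True loads the ImageNet weights; passing aux_logits=False asks the
# builder to strip the auxiliary classifier after the state dict is loaded.
model = inception_v3(pretrained=True, aux_logits=False)
model.eval()

# inception_v3 expects N x 3 x 299 x 299 inputs.
batch = torch.rand(1, 3, 299, 299)
with torch.no_grad():
    logits = model(batch)
print(logits.shape)  # torch.Size([1, 1000])
```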
From 56e2b6464d8b844892af7b637dc35235c8a2f43b Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Mon, 1 Nov 2021 17:43:11 +0000
Subject: [PATCH 2/3] Adding multiweight support to inception.

---
 torchvision/prototype/models/__init__.py  |  1 +
 torchvision/prototype/models/inception.py | 54 +++++++++++++++++++++++
 2 files changed, 55 insertions(+)
 create mode 100644 torchvision/prototype/models/inception.py

diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py
index 76f904c2ff3..ef0288f60b0 100644
--- a/torchvision/prototype/models/__init__.py
+++ b/torchvision/prototype/models/__init__.py
@@ -2,6 +2,7 @@
 from .densenet import *
 from .efficientnet import *
 from .googlenet import *
+from .inception import *
 from .mnasnet import *
 from .mobilenetv2 import *
 from .mobilenetv3 import *
diff --git a/torchvision/prototype/models/inception.py b/torchvision/prototype/models/inception.py
new file mode 100644
index 00000000000..8e9816bbfb7
--- /dev/null
+++ b/torchvision/prototype/models/inception.py
@@ -0,0 +1,54 @@
+import warnings
+from functools import partial
+from typing import Any, Optional
+
+from torchvision.transforms.functional import InterpolationMode
+
+from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs
+from ..transforms.presets import ImageNetEval
+from ._api import Weights, WeightEntry
+from ._meta import _IMAGENET_CATEGORIES
+
+
+__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception3Weights", "inception_v3"]
+
+
+_common_meta = {"size": (299, 299), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
+
+
+class Inception3Weights(Weights):
+    ImageNet1K_TFV1 = WeightEntry(
+        url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
+        transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
+        meta={
+            **_common_meta,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3",
+            "acc@1": 77.294,
+            "acc@5": 93.450,
+        },
+    )
+
+
+def inception_v3(weights: Optional[Inception3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        weights = Inception3Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
+    weights = Inception3Weights.verify(weights)
+
+    original_aux_logits = kwargs.get("aux_logits", True)
+    if weights is not None:
+        if "transform_input" not in kwargs:
+            kwargs["transform_input"] = True
+        kwargs["aux_logits"] = True
+        kwargs["init_weights"] = False
+        kwargs["num_classes"] = len(weights.meta["categories"])
+
+    model = Inception3(**kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.state_dict(progress=progress))
+        if not original_aux_logits:
+            model.aux_logits = False
+            model.AuxLogits = None
+
+    return model
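Reviewer note: a minimal sketch of the new multi-weight interface added above, assuming the prototype namespace is importable (not part of the patch):

```python
import torch
from torchvision.prototype.models import Inception3Weights, inception_v3

# Weights are selected explicitly instead of via the boolean pretrained flag.
weights = Inception3Weights.ImageNet1K_TFV1
model = inception_v3(weights=weights)
model.eval()

# The deprecated flag still works but emits a warning:
#   model = inception_v3(pretrained=True)

batch = torch.rand(1, 3, 299, 299)
with torch.no_grad():
    logits = model(batch)
print(logits.shape)  # torch.Size([1, 1000])
```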
From ae3e47f4822b88f95cbb5bf850e80a00d6a7fb9c Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Mon, 1 Nov 2021 18:04:26 +0000
Subject: [PATCH 3/3] Update doc.

---
 references/classification/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/references/classification/README.md b/references/classification/README.md
index 006e9c398b1..b53b4331b2f 100644
--- a/references/classification/README.md
+++ b/references/classification/README.md
@@ -38,7 +38,7 @@ The weights of the Inception V3 model are ported from the original paper rather
 
 Since it expects tensors with a size of N x 3 x 299 x 299, to validate the model use the following command:
 
 ```
-torchrun --nproc_per_node=8 train.py --model inception_v3
+torchrun --nproc_per_node=8 train.py --model inception_v3 \
 --val-resize-size 342 --val-crop-size 299 --train-crop-size 299 --test-only --pretrained
 ```
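Reviewer note: the validation flags above mirror the eval preset registered for the new weights in PATCH 2/3; a minimal sketch of the correspondence, assuming the prototype transforms module:

```python
from torchvision.prototype.transforms.presets import ImageNetEval

# Equivalent to --val-resize-size 342 --val-crop-size 299 in the command above.
preprocess = ImageNetEval(crop_size=299, resize_size=342)
```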