From 69352353d17c071a8a4ddad8927f6d4f8e60411d Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 1 Nov 2021 15:38:42 +0000 Subject: [PATCH 1/2] Change enum name for weights contributed by community. --- torchvision/prototype/models/googlenet.py | 4 ++-- torchvision/prototype/models/mnasnet.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torchvision/prototype/models/googlenet.py b/torchvision/prototype/models/googlenet.py index d9abfbefd8d..8bed730457d 100644 --- a/torchvision/prototype/models/googlenet.py +++ b/torchvision/prototype/models/googlenet.py @@ -17,7 +17,7 @@ class GoogLeNetWeights(Weights): - ImageNet1K_TheCodezV1 = WeightEntry( + ImageNet1K_Community = WeightEntry( url="https://download.pytorch.org/models/googlenet-1378be20.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -32,7 +32,7 @@ class GoogLeNetWeights(Weights): def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: if "pretrained" in kwargs: warnings.warn("The argument pretrained is deprecated, please use weights instead.") - weights = GoogLeNetWeights.ImageNet1K_TheCodezV1 if kwargs.pop("pretrained") else None + weights = GoogLeNetWeights.ImageNet1K_Community if kwargs.pop("pretrained") else None weights = GoogLeNetWeights.verify(weights) original_aux_logits = kwargs.get("aux_logits", False) diff --git a/torchvision/prototype/models/mnasnet.py b/torchvision/prototype/models/mnasnet.py index f7a15532f79..c0e7a8c6b8f 100644 --- a/torchvision/prototype/models/mnasnet.py +++ b/torchvision/prototype/models/mnasnet.py @@ -27,7 +27,7 @@ class MNASNet0_5Weights(Weights): - ImageNet1K_TrainerV1 = WeightEntry( + ImageNet1K_Community = WeightEntry( url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -45,7 +45,7 @@ class MNASNet0_75Weights(Weights): class MNASNet1_0Weights(Weights): - ImageNet1K_TrainerV1 = WeightEntry( + ImageNet1K_Community = WeightEntry( url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -77,7 +77,7 @@ def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs: def mnasnet0_5(weights: Optional[MNASNet0_5Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: if "pretrained" in kwargs: warnings.warn("The argument pretrained is deprecated, please use weights instead.") - weights = MNASNet0_5Weights.ImageNet1K_TrainerV1 if kwargs.pop("pretrained") else None + weights = MNASNet0_5Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None weights = MNASNet0_5Weights.verify(weights) @@ -98,7 +98,7 @@ def mnasnet0_75(weights: Optional[MNASNet0_75Weights] = None, progress: bool = T def mnasnet1_0(weights: Optional[MNASNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: if "pretrained" in kwargs: warnings.warn("The argument pretrained is deprecated, please use weights instead.") - weights = MNASNet1_0Weights.ImageNet1K_TrainerV1 if kwargs.pop("pretrained") else None + weights = MNASNet1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None weights = MNASNet1_0Weights.verify(weights) return _mnasnet(1.0, weights, progress, **kwargs) From ef25dcf70d115385629c561b74547c8b1c172b10 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 1 Nov 2021 16:00:39 +0000 Subject: [PATCH 2/2] Adding multiweight support to squeezenet. --- torchvision/prototype/models/__init__.py | 1 + torchvision/prototype/models/squeezenet.py | 74 ++++++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 torchvision/prototype/models/squeezenet.py diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py index 9c023dbcbbf..76f904c2ff3 100644 --- a/torchvision/prototype/models/__init__.py +++ b/torchvision/prototype/models/__init__.py @@ -8,6 +8,7 @@ from .regnet import * from .resnet import * from .shufflenetv2 import * +from .squeezenet import * from .vgg import * from . import detection from . import quantization diff --git a/torchvision/prototype/models/squeezenet.py b/torchvision/prototype/models/squeezenet.py new file mode 100644 index 00000000000..9f9c09b8de7 --- /dev/null +++ b/torchvision/prototype/models/squeezenet.py @@ -0,0 +1,74 @@ +import warnings +from functools import partial +from typing import Any, Optional + +from torchvision.transforms.functional import InterpolationMode + +from ...models.squeezenet import SqueezeNet +from ..transforms.presets import ImageNetEval +from ._api import Weights, WeightEntry +from ._meta import _IMAGENET_CATEGORIES + + +__all__ = ["SqueezeNet", "SqueezeNet1_0Weights", "SqueezeNet1_1Weights", "squeezenet1_0", "squeezenet1_1"] + + +_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR} + + +class SqueezeNet1_0Weights(Weights): + ImageNet1K_Community = WeightEntry( + url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717", + "acc@1": 58.092, + "acc@5": 80.420, + }, + ) + + +class SqueezeNet1_1Weights(Weights): + ImageNet1K_Community = WeightEntry( + url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717", + "acc@1": 58.178, + "acc@5": 80.624, + }, + ) + + +def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = SqueezeNet1_0Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None + weights = SqueezeNet1_0Weights.verify(weights) + if weights is not None: + kwargs["num_classes"] = len(weights.meta["categories"]) + + model = SqueezeNet("1_0", **kwargs) + + if weights is not None: + model.load_state_dict(weights.state_dict(progress=progress)) + + return model + + +def squeezenet1_1(weights: Optional[SqueezeNet1_1Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = SqueezeNet1_1Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None + weights = SqueezeNet1_1Weights.verify(weights) + if weights is not None: + kwargs["num_classes"] = len(weights.meta["categories"]) + + model = SqueezeNet("1_1", **kwargs) + + if weights is not None: + model.load_state_dict(weights.state_dict(progress=progress)) + + return model