From d44e44d4059157c8a24d05b5a8c0e26926dc6cab Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 11 Mar 2022 15:16:22 +0000 Subject: [PATCH 01/45] Moving basefiles outside of prototype and porting Alexnet, ConvNext, Densenet and EfficientNet. --- test/test_prototype_models.py | 2 +- torchvision/{prototype => }/models/_api.py | 2 +- torchvision/{prototype => }/models/_meta.py | 0 torchvision/models/_utils.py | 108 ++++- torchvision/models/alexnet.py | 56 ++- torchvision/models/convnext.py | 144 ++++-- torchvision/models/densenet.py | 137 +++++- torchvision/models/efficientnet.py | 450 +++++++++++++---- torchvision/prototype/models/__init__.py | 20 - torchvision/prototype/models/_utils.py | 108 ----- torchvision/prototype/models/alexnet.py | 49 -- torchvision/prototype/models/convnext.py | 169 ------- torchvision/prototype/models/densenet.py | 159 ------ .../prototype/models/detection/faster_rcnn.py | 4 +- .../prototype/models/detection/fcos.py | 4 +- .../models/detection/keypoint_rcnn.py | 4 +- .../prototype/models/detection/mask_rcnn.py | 4 +- .../prototype/models/detection/retinanet.py | 4 +- torchvision/prototype/models/detection/ssd.py | 4 +- .../prototype/models/detection/ssdlite.py | 4 +- torchvision/prototype/models/efficientnet.py | 453 ------------------ torchvision/prototype/models/googlenet.py | 4 +- torchvision/prototype/models/inception.py | 4 +- torchvision/prototype/models/mnasnet.py | 4 +- torchvision/prototype/models/mobilenet.py | 6 - torchvision/prototype/models/mobilenetv2.py | 4 +- torchvision/prototype/models/mobilenetv3.py | 4 +- .../prototype/models/optical_flow/raft.py | 4 +- .../models/quantization/googlenet.py | 4 +- .../models/quantization/inception.py | 4 +- .../models/quantization/mobilenetv2.py | 4 +- .../models/quantization/mobilenetv3.py | 4 +- .../prototype/models/quantization/resnet.py | 4 +- .../models/quantization/shufflenetv2.py | 4 +- torchvision/prototype/models/regnet.py | 4 +- torchvision/prototype/models/resnet.py | 4 +- .../models/segmentation/deeplabv3.py | 4 +- .../prototype/models/segmentation/fcn.py | 4 +- .../prototype/models/segmentation/lraspp.py | 4 +- torchvision/prototype/models/shufflenetv2.py | 4 +- torchvision/prototype/models/squeezenet.py | 4 +- torchvision/prototype/models/vgg.py | 4 +- torchvision/prototype/models/video/resnet.py | 4 +- .../prototype/models/vision_transformer.py | 4 +- torchvision/prototype/transforms/__init__.py | 9 - torchvision/transforms/__init__.py | 7 + .../{prototype => }/transforms/_presets.py | 2 +- 47 files changed, 807 insertions(+), 1190 deletions(-) rename torchvision/{prototype => }/models/_api.py (98%) rename torchvision/{prototype => }/models/_meta.py (100%) delete mode 100644 torchvision/prototype/models/__init__.py delete mode 100644 torchvision/prototype/models/_utils.py delete mode 100644 torchvision/prototype/models/alexnet.py delete mode 100644 torchvision/prototype/models/convnext.py delete mode 100644 torchvision/prototype/models/densenet.py delete mode 100644 torchvision/prototype/models/efficientnet.py delete mode 100644 torchvision/prototype/models/mobilenet.py rename torchvision/{prototype => }/transforms/_presets.py (98%) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 6c7234e2ef0..6f1abdfd466 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -6,7 +6,7 @@ import torch from common_utils import cpu_and_gpu, needs_cuda from torchvision.prototype import models -from torchvision.prototype.models._api import 
WeightsEnum, Weights +from torchvision.models._api import WeightsEnum, Weights from torchvision.prototype.models._utils import handle_legacy_interface run_if_test_with_prototype = pytest.mark.skipif( diff --git a/torchvision/prototype/models/_api.py b/torchvision/models/_api.py similarity index 98% rename from torchvision/prototype/models/_api.py rename to torchvision/models/_api.py index 85b280a7dfc..d841415a45a 100644 --- a/torchvision/prototype/models/_api.py +++ b/torchvision/models/_api.py @@ -7,7 +7,7 @@ from torchvision._utils import StrEnum -from ..._internally_replaced_utils import load_state_dict_from_url +from .._internally_replaced_utils import load_state_dict_from_url __all__ = ["WeightsEnum", "Weights", "get_weight"] diff --git a/torchvision/prototype/models/_meta.py b/torchvision/models/_meta.py similarity index 100% rename from torchvision/prototype/models/_meta.py rename to torchvision/models/_meta.py diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index f4e1cd84508..1deff604b4a 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -1,7 +1,12 @@ +import functools +import warnings from collections import OrderedDict -from typing import Dict, Optional +from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union from torch import nn +from torchvision.prototype.utils._internal import kwonly_to_pos_or_kw + +from ._api import WeightsEnum class IntermediateLayerGetter(nn.ModuleDict): @@ -81,3 +86,104 @@ def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> if new_v < 0.9 * v: new_v += divisor return new_v + + +W = TypeVar("W", bound=WeightsEnum) +M = TypeVar("M", bound=nn.Module) +V = TypeVar("V") + + +def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): + """Decorates a model builder with the new interface to make it compatible with the old. + + In particular this handles two things: + + 1. Allows positional parameters again, but emits a deprecation warning in case they are used. See + :func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details. + 2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to + ``weights=Weights`` and emits a deprecation warning with instructions for the new interface. + + Args: + **weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter + name and default value for the legacy ``pretrained=True``. The default value can be a callable in which + case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in + the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters + should be accessed with :meth:`~dict.get`. + """ + + def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]: + @kwonly_to_pos_or_kw + @functools.wraps(builder) + def inner_wrapper(*args: Any, **kwargs: Any) -> M: + for weights_param, (pretrained_param, default) in weights.items(): # type: ignore[union-attr] + # If neither the weights nor the pretrained parameter as passed, or the weights argument already use + # the new style arguments, there is nothing to do. Note that we cannot use `None` as sentinel for the + # weight argument, since it is a valid value. 
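                # Illustrative walk-through (not part of the patched file, names taken from the alexnet
                # builder below): for a builder decorated with
                # @handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1)),
                # a legacy call alexnet(pretrained=True) arrives here with kwargs == {"pretrained": True};
                # the code below maps it to kwargs == {"weights": AlexNet_Weights.IMAGENET1K_V1} and emits
                # deprecation warnings, while alexnet(weights=...) and alexnet(weights=None) skip the branch
                # via the `continue` above.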
+ sentinel = object() + weights_arg = kwargs.get(weights_param, sentinel) + if ( + (weights_param not in kwargs and pretrained_param not in kwargs) + or isinstance(weights_arg, WeightsEnum) + or (isinstance(weights_arg, str) and weights_arg != "legacy") + or weights_arg is None + ): + continue + + # If the pretrained parameter was passed as positional argument, it is now mapped to + # `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current + # signature to infer the names of positionally passed arguments and thus has no knowledge that there + # used to be a pretrained parameter. + pretrained_positional = weights_arg is not sentinel + if pretrained_positional: + # We put the pretrained argument under its legacy name in the keyword argument dictionary to have a + # unified access to the value if the default value is a callable. + kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param) + else: + pretrained_arg = kwargs[pretrained_param] + + if pretrained_arg: + default_weights_arg = default(kwargs) if callable(default) else default + if not isinstance(default_weights_arg, WeightsEnum): + raise ValueError(f"No weights available for model {builder.__name__}") + else: + default_weights_arg = None + + if not pretrained_positional: + warnings.warn( + f"The parameter '{pretrained_param}' is deprecated, please use '{weights_param}' instead." + ) + + msg = ( + f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated. " + f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`." + ) + if pretrained_arg: + msg = ( + f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` " + f"to get the most up-to-date weights." + ) + warnings.warn(msg) + + del kwargs[pretrained_param] + kwargs[weights_param] = default_weights_arg + + return builder(*args, **kwargs) + + return inner_wrapper + + return outer_wrapper + + +def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None: + if param in kwargs: + if kwargs[param] != new_value: + raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.") + else: + kwargs[param] = new_value + + +def _ovewrite_value_param(param: Optional[V], new_value: V) -> V: + if param is not None: + if param != new_value: + raise ValueError(f"The parameter '{param}' expected value {new_value} but got {param} instead.") + return new_value diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index bb812febdc4..ba8d6eef37a 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -1,18 +1,18 @@ -from typing import Any +from functools import partial +from typing import Any, Optional import torch import torch.nn as nn -from .._internally_replaced_utils import load_state_dict_from_url from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param +from torchvision.transforms import ImageClassificationEval +from torchvision.transforms.functional import InterpolationMode -__all__ = ["AlexNet", "alexnet"] - - -model_urls = { - "alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", -} +__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"] class AlexNet(nn.Module): @@ -53,17 +53,45 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def alexnet(pretrained: bool = False, progress: bool = True, 
**kwargs: Any) -> AlexNet: +class AlexNet_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + "task": "image_classification", + "architecture": "AlexNet", + "publication_year": 2012, + "num_params": 61100840, + "size": (224, 224), + "min_size": (63, 63), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg", + "acc@1": 56.522, + "acc@5": 79.066, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1)) +def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: r"""AlexNet model architecture from the `"One weird trick..." `_ paper. The required minimum input size of the model is 63x63. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (AlexNet_Weights): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = AlexNet_Weights.verify(weights) + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = AlexNet(**kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls["alexnet"], progress=progress) - model.load_state_dict(state_dict) - return model + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + + return model \ No newline at end of file diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 3a0dcdb31cd..40811841b53 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -1,18 +1,27 @@ from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence +from typing import Any, Callable, List, Optional, Sequence import torch from torch import nn, Tensor from torch.nn import functional as F -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import Conv2dNormActivation from ..ops.stochastic_depth import StochasticDepth from ..utils import _log_api_usage_once +from ..transforms import ImageClassificationEval, InterpolationMode + +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param + __all__ = [ "ConvNeXt", + "ConvNeXt_Tiny_Weights", + "ConvNeXt_Small_Weights", + "ConvNeXt_Base_Weights", + "ConvNeXt_Large_Weights", "convnext_tiny", "convnext_small", "convnext_base", @@ -20,14 +29,6 @@ ] -_MODELS_URLS: Dict[str, Optional[str]] = { - "convnext_tiny": "https://download.pytorch.org/models/convnext_tiny-983f1562.pth", - "convnext_small": "https://download.pytorch.org/models/convnext_small-0c510722.pth", - "convnext_base": "https://download.pytorch.org/models/convnext_base-6075fbad.pth", - "convnext_large": "https://download.pytorch.org/models/convnext_large-ea097f82.pth", -} - - class LayerNorm2d(nn.LayerNorm): def forward(self, x: Tensor) -> Tensor: x = x.permute(0, 2, 3, 1) @@ -187,29 +188,101 @@ def forward(self, x: Tensor) -> Tensor: def _convnext( - arch: str, block_setting: List[CNBlockConfig], stochastic_depth_prob: float, - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> ConvNeXt: + if weights is not None: + _ovewrite_named_param(kwargs, 
"num_classes", len(weights.meta["categories"])) + model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs) - if pretrained: - if arch not in _MODELS_URLS: - raise ValueError(f"No checkpoint is available for model type {arch}") - state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model -def convnext_tiny(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: +_COMMON_META = { + "task": "image_classification", + "architecture": "ConvNeXt", + "publication_year": 2022, + "size": (224, 224), + "min_size": (32, 32), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext", +} + + +class ConvNeXt_Tiny_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=236), + meta={ + **_COMMON_META, + "num_params": 28589128, + "acc@1": 82.520, + "acc@5": 96.146, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ConvNeXt_Small_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/convnext_small-0c510722.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=230), + meta={ + **_COMMON_META, + "num_params": 50223688, + "acc@1": 83.616, + "acc@5": 96.650, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ConvNeXt_Base_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/convnext_base-6075fbad.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 88591464, + "acc@1": 84.062, + "acc@5": 96.870, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ConvNeXt_Large_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/convnext_large-ea097f82.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 197767336, + "acc@1": 84.414, + "acc@5": 96.976, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1)) +def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: r"""ConvNeXt Tiny model architecture from the `"A ConvNet for the 2020s" `_ paper. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ConvNeXt_Tiny_Weights): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ConvNeXt_Tiny_Weights.verify(weights) + block_setting = [ CNBlockConfig(96, 192, 3), CNBlockConfig(192, 384, 3), @@ -217,16 +290,21 @@ def convnext_tiny(*, pretrained: bool = False, progress: bool = True, **kwargs: CNBlockConfig(768, None, 3), ] stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1) - return _convnext("convnext_tiny", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) + return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) -def convnext_small(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: +@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1)) +def convnext_small( + *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any +) -> ConvNeXt: r"""ConvNeXt Small model architecture from the `"A ConvNet for the 2020s" `_ paper. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ConvNeXt_Small_Weights): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ConvNeXt_Small_Weights.verify(weights) + block_setting = [ CNBlockConfig(96, 192, 3), CNBlockConfig(192, 384, 3), @@ -234,16 +312,19 @@ def convnext_small(*, pretrained: bool = False, progress: bool = True, **kwargs: CNBlockConfig(768, None, 3), ] stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4) - return _convnext("convnext_small", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) + return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) -def convnext_base(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: +@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1)) +def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: r"""ConvNeXt Base model architecture from the `"A ConvNet for the 2020s" `_ paper. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ConvNeXt_Base_Weights): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ConvNeXt_Base_Weights.verify(weights) + block_setting = [ CNBlockConfig(128, 256, 3), CNBlockConfig(256, 512, 3), @@ -251,16 +332,21 @@ def convnext_base(*, pretrained: bool = False, progress: bool = True, **kwargs: CNBlockConfig(1024, None, 3), ] stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) - return _convnext("convnext_base", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) + return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) -def convnext_large(*, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: +@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1)) +def convnext_large( + *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any +) -> ConvNeXt: r"""ConvNeXt Large model architecture from the `"A ConvNet for the 2020s" `_ paper. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ConvNeXt_Large_Weights): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ConvNeXt_Large_Weights.verify(weights) + block_setting = [ CNBlockConfig(192, 384, 3), CNBlockConfig(384, 768, 3), @@ -268,4 +354,4 @@ def convnext_large(*, pretrained: bool = False, progress: bool = True, **kwargs: CNBlockConfig(1536, None, 3), ] stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) - return _convnext("convnext_large", block_setting, stochastic_depth_prob, pretrained, progress, **kwargs) + return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index 14e318360af..4da690a979f 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -1,6 +1,7 @@ import re +from functools import partial from collections import OrderedDict -from typing import Any, List, Tuple +from typing import Any, List, Optional, Tuple import torch import torch.nn as nn @@ -8,18 +9,25 @@ import torch.utils.checkpoint as cp from torch import Tensor -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["DenseNet", "densenet121", "densenet169", "densenet201", "densenet161"] -model_urls = { - "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth", - "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", - "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", - "densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth", -} +__all__ = [ + "DenseNet", + "DenseNet121_Weights", + "DenseNet161_Weights", + "DenseNet169_Weights", + "DenseNet201_Weights", + "densenet121", + "densenet161", + "densenet169", + "densenet201", +] class _DenseLayer(nn.Module): @@ -220,7 +228,7 @@ def forward(self, x: Tensor) -> Tensor: return out -def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None: +def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None: # '.'s are no longer allowed in module names, but previous _DenseLayer # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. # They are also in the checkpoints in model_urls. 
This pattern is used @@ -229,7 +237,7 @@ def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None: r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" ) - state_dict = load_state_dict_from_url(model_url, progress=progress) + state_dict = weights.get_state_dict(progress=progress) for key in list(state_dict.keys()): res = pattern.match(key) if res: @@ -240,21 +248,93 @@ def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None: def _densenet( - arch: str, growth_rate: int, block_config: Tuple[int, int, int, int], num_init_features: int, - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> DenseNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) - if pretrained: - _load_state_dict(model, model_urls[arch], progress) + + if weights is not None: + _load_state_dict(model=model, weights=weights, progress=progress) + return model -def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: +_COMMON_META = { + "task": "image_classification", + "architecture": "DenseNet", + "publication_year": 2016, + "size": (224, 224), + "min_size": (29, 29), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/pull/116", +} + +class DenseNet121_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/densenet121-a639ec97.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 7978856, + "acc@1": 74.434, + "acc@5": 91.972, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class DenseNet161_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/densenet161-8d451a50.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 28681000, + "acc@1": 77.138, + "acc@5": 93.560, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class DenseNet169_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 14149480, + "acc@1": 75.600, + "acc@5": 92.806, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class DenseNet201_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/densenet201-c1103571.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 20013928, + "acc@1": 76.896, + "acc@5": 93.370, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1)) +def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: r"""Densenet-121 model from `"Densely Connected Convolutional Networks" `_. The required minimum input size of the model is 29x29. @@ -265,10 +345,13 @@ def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. 
""" - return _densenet("densenet121", 32, (6, 12, 24, 16), 64, pretrained, progress, **kwargs) + weights = DenseNet121_Weights.verify(weights) + return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs) -def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: + +@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1)) +def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: r"""Densenet-161 model from `"Densely Connected Convolutional Networks" `_. The required minimum input size of the model is 29x29. @@ -279,10 +362,13 @@ def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. """ - return _densenet("densenet161", 48, (6, 12, 36, 24), 96, pretrained, progress, **kwargs) + weights = DenseNet161_Weights.verify(weights) + + return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs) -def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: +@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1)) +def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: r"""Densenet-169 model from `"Densely Connected Convolutional Networks" `_. The required minimum input size of the model is 29x29. @@ -293,10 +379,13 @@ def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. """ - return _densenet("densenet169", 32, (6, 12, 32, 32), 64, pretrained, progress, **kwargs) + weights = DenseNet169_Weights.verify(weights) + return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs) -def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: + +@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1)) +def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: r"""Densenet-201 model from `"Densely Connected Convolutional Networks" `_. The required minimum input size of the model is 29x29. @@ -307,4 +396,6 @@ def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. 
""" - return _densenet("densenet201", 32, (6, 12, 48, 32), 64, pretrained, progress, **kwargs) + weights = DenseNet201_Weights.verify(weights) + + return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs) \ No newline at end of file diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index f8238912ffd..21e19c7d01a 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -9,14 +9,28 @@ from torch import nn, Tensor from torchvision.ops import StochasticDepth -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import Conv2dNormActivation, SqueezeExcitation +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once -from ._utils import _make_divisible + +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible __all__ = [ "EfficientNet", + "EfficientNet_B0_Weights", + "EfficientNet_B1_Weights", + "EfficientNet_B2_Weights", + "EfficientNet_B3_Weights", + "EfficientNet_B4_Weights", + "EfficientNet_B5_Weights", + "EfficientNet_B6_Weights", + "EfficientNet_B7_Weights", + "EfficientNet_V2_S_Weights", + "EfficientNet_V2_M_Weights", + "EfficientNet_V2_L_Weights", "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", @@ -31,25 +45,6 @@ ] -model_urls = { - # Weights ported from https://github.com/rwightman/pytorch-image-models/ - "efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", - "efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", - "efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", - "efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", - "efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", - # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/ - "efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", - "efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", - "efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", - # Weights trained with TorchVision - "efficientnet_v2_s": "https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth", - "efficientnet_v2_m": "https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth", - # Weights ported from TF - "efficientnet_v2_l": "https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth", -} - - @dataclass class _MBConvConfig: expand_ratio: float @@ -362,20 +357,21 @@ def forward(self, x: Tensor) -> Tensor: def _efficientnet( - arch: str, inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]], dropout: float, last_channel: Optional[int], - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> EfficientNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs) - if pretrained: - if model_urls.get(arch, None) is None: - raise ValueError(f"No checkpoint is available for model type {arch}") - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + + if 
weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model @@ -434,208 +430,484 @@ def _efficientnet_conf( return inverted_residual_setting, last_channel -def efficientnet_b0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +_COMMON_META = { + "task": "image_classification", + "categories": _IMAGENET_CATEGORIES, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet", +} + + +_COMMON_META_V1 = { + **_COMMON_META, + "architecture": "EfficientNet", + "publication_year": 2019, + "interpolation": InterpolationMode.BICUBIC, + "min_size": (1, 1), +} + + +_COMMON_META_V2 = { + **_COMMON_META, + "architecture": "EfficientNetV2", + "publication_year": 2021, + "interpolation": InterpolationMode.BILINEAR, + "min_size": (33, 33), +} + + +class EfficientNet_B0_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", + transforms=partial( + ImageClassificationEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 5288548, + "size": (224, 224), + "acc@1": 77.692, + "acc@5": 93.532, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B1_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", + transforms=partial( + ImageClassificationEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 7794184, + "size": (240, 240), + "acc@1": 78.642, + "acc@5": 94.186, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth", + transforms=partial( + ImageClassificationEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR + ), + meta={ + **_COMMON_META_V1, + "num_params": 7794184, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning", + "interpolation": InterpolationMode.BILINEAR, + "size": (240, 240), + "acc@1": 79.838, + "acc@5": 94.934, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class EfficientNet_B2_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", + transforms=partial( + ImageClassificationEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 9109994, + "size": (288, 288), + "acc@1": 80.608, + "acc@5": 95.310, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B3_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", + transforms=partial( + ImageClassificationEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 12233232, + "size": (300, 300), + "acc@1": 82.008, + "acc@5": 96.054, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B4_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", + transforms=partial( + ImageClassificationEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 19341616, + "size": (380, 380), + "acc@1": 83.384, + "acc@5": 96.594, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class 
EfficientNet_B5_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", + transforms=partial( + ImageClassificationEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 30389784, + "size": (456, 456), + "acc@1": 83.444, + "acc@5": 96.628, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B6_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", + transforms=partial( + ImageClassificationEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 43040704, + "size": (528, 528), + "acc@1": 84.008, + "acc@5": 96.916, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B7_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", + transforms=partial( + ImageClassificationEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 66347960, + "size": (600, 600), + "acc@1": 84.122, + "acc@5": 96.908, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_V2_S_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth", + transforms=partial( + ImageClassificationEval, + crop_size=384, + resize_size=384, + interpolation=InterpolationMode.BILINEAR, + ), + meta={ + **_COMMON_META_V2, + "num_params": 21458488, + "size": (384, 384), + "acc@1": 84.228, + "acc@5": 96.878, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_V2_M_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth", + transforms=partial( + ImageClassificationEval, + crop_size=480, + resize_size=480, + interpolation=InterpolationMode.BILINEAR, + ), + meta={ + **_COMMON_META_V2, + "num_params": 54139356, + "size": (480, 480), + "acc@1": 85.112, + "acc@5": 97.156, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_V2_L_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth", + transforms=partial( + ImageClassificationEval, + crop_size=480, + resize_size=480, + interpolation=InterpolationMode.BICUBIC, + mean=(0.5, 0.5, 0.5), + std=(0.5, 0.5, 0.5), + ), + meta={ + **_COMMON_META_V2, + "num_params": 118515272, + "size": (480, 480), + "acc@1": 85.808, + "acc@5": 97.788, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1)) +def efficientnet_b0( + *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B0 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B0_Weights): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b0" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.0, depth_mult=1.0) - return _efficientnet(arch, inverted_residual_setting, 0.2, last_channel, pretrained, progress, **kwargs) + weights = EfficientNet_B0_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0) + return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) -def efficientnet_b1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1)) +def efficientnet_b1( + *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B1 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B1_Weights): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b1" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.0, depth_mult=1.1) - return _efficientnet(arch, inverted_residual_setting, 0.2, last_channel, pretrained, progress, **kwargs) + weights = EfficientNet_B1_Weights.verify(weights) + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1) + return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) -def efficientnet_b2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: + +@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1)) +def efficientnet_b2( + *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B2 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B2_Weights): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b2" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.1, depth_mult=1.2) - return _efficientnet(arch, inverted_residual_setting, 0.3, last_channel, pretrained, progress, **kwargs) + weights = EfficientNet_B2_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2) + return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) -def efficientnet_b3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1)) +def efficientnet_b3( + *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B3 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B3_Weights): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b3" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.2, depth_mult=1.4) - return _efficientnet(arch, inverted_residual_setting, 0.3, last_channel, pretrained, progress, **kwargs) + weights = EfficientNet_B3_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4) + return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) -def efficientnet_b4(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1)) +def efficientnet_b4( + *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B4 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B4_Weights): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b4" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.4, depth_mult=1.8) - return _efficientnet(arch, inverted_residual_setting, 0.4, last_channel, pretrained, progress, **kwargs) + weights = EfficientNet_B4_Weights.verify(weights) + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8) + return _efficientnet(inverted_residual_setting, 0.4, last_channel, weights, progress, **kwargs) -def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: + +@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1)) +def efficientnet_b5( + *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B5 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B5_Weights): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b5" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.6, depth_mult=2.2) + weights = EfficientNet_B5_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2) return _efficientnet( - arch, inverted_residual_setting, 0.4, last_channel, - pretrained, + weights, progress, norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs, ) -def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1)) +def efficientnet_b6( + *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B6 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B6_Weights): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b6" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.8, depth_mult=2.6) + weights = EfficientNet_B6_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6) return _efficientnet( - arch, inverted_residual_setting, 0.5, last_channel, - pretrained, + weights, progress, norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs, ) -def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1)) +def efficientnet_b7( + *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs a EfficientNet B7 architecture from `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_B7_Weights): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_b7" - inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=2.0, depth_mult=3.1) + weights = EfficientNet_B7_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1) return _efficientnet( - arch, inverted_residual_setting, 0.5, last_channel, - pretrained, + weights, progress, norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs, ) -def efficientnet_v2_s(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1)) +def efficientnet_v2_s( + *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs an EfficientNetV2-S architecture from `"EfficientNetV2: Smaller Models and Faster Training" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_V2_S_Weights): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_v2_s" - inverted_residual_setting, last_channel = _efficientnet_conf(arch) + weights = EfficientNet_V2_S_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s") return _efficientnet( - arch, inverted_residual_setting, 0.2, last_channel, - pretrained, + weights, progress, norm_layer=partial(nn.BatchNorm2d, eps=1e-03), **kwargs, ) -def efficientnet_v2_m(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1)) +def efficientnet_v2_m( + *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs an EfficientNetV2-M architecture from `"EfficientNetV2: Smaller Models and Faster Training" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_V2_M_Weights): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_v2_m" - inverted_residual_setting, last_channel = _efficientnet_conf(arch) + weights = EfficientNet_V2_M_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m") return _efficientnet( - arch, inverted_residual_setting, 0.3, last_channel, - pretrained, + weights, progress, norm_layer=partial(nn.BatchNorm2d, eps=1e-03), **kwargs, ) -def efficientnet_v2_l(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: +@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1)) +def efficientnet_v2_l( + *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: """ Constructs an EfficientNetV2-L architecture from `"EfficientNetV2: Smaller Models and Faster Training" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (EfficientNet_V2_L_Weights): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "efficientnet_v2_l" - inverted_residual_setting, last_channel = _efficientnet_conf(arch) + weights = EfficientNet_V2_L_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l") return _efficientnet( - arch, inverted_residual_setting, 0.4, last_channel, - pretrained, + weights, progress, norm_layer=partial(nn.BatchNorm2d, eps=1e-03), **kwargs, diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py deleted file mode 100644 index 83e49908348..00000000000 --- a/torchvision/prototype/models/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -from .alexnet import * -from .convnext import * -from .densenet import * -from .efficientnet import * -from .googlenet import * -from .inception import * -from .mnasnet import * -from .mobilenet import * -from .regnet import * -from .resnet import * -from .shufflenetv2 import * -from .squeezenet import * -from .vgg import * -from .vision_transformer import * -from . import detection -from . import optical_flow -from . import quantization -from . import segmentation -from . import video -from ._api import get_weight diff --git a/torchvision/prototype/models/_utils.py b/torchvision/prototype/models/_utils.py deleted file mode 100644 index cc9f7dcfc36..00000000000 --- a/torchvision/prototype/models/_utils.py +++ /dev/null @@ -1,108 +0,0 @@ -import functools -import warnings -from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union - -from torch import nn -from torchvision.prototype.utils._internal import kwonly_to_pos_or_kw - -from ._api import WeightsEnum - -W = TypeVar("W", bound=WeightsEnum) -M = TypeVar("M", bound=nn.Module) -V = TypeVar("V") - - -def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): - """Decorates a model builder with the new interface to make it compatible with the old. - - In particular this handles two things: - - 1. Allows positional parameters again, but emits a deprecation warning in case they are used. See - :func:`torchvision.prototype.utils._internal.kwonly_to_pos_or_kw` for details. - 2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to - ``weights=Weights`` and emits a deprecation warning with instructions for the new interface. - - Args: - **weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Deprecated parameter - name and default value for the legacy ``pretrained=True``. The default value can be a callable in which - case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in - the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters - should be accessed with :meth:`~dict.get`. - """ - - def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]: - @kwonly_to_pos_or_kw - @functools.wraps(builder) - def inner_wrapper(*args: Any, **kwargs: Any) -> M: - for weights_param, (pretrained_param, default) in weights.items(): # type: ignore[union-attr] - # If neither the weights nor the pretrained parameter as passed, or the weights argument already use - # the new style arguments, there is nothing to do. 
Note that we cannot use `None` as sentinel for the - # weight argument, since it is a valid value. - sentinel = object() - weights_arg = kwargs.get(weights_param, sentinel) - if ( - (weights_param not in kwargs and pretrained_param not in kwargs) - or isinstance(weights_arg, WeightsEnum) - or (isinstance(weights_arg, str) and weights_arg != "legacy") - or weights_arg is None - ): - continue - - # If the pretrained parameter was passed as positional argument, it is now mapped to - # `kwargs[weights_param]`. This happens because the @kwonly_to_pos_or_kw decorator uses the current - # signature to infer the names of positionally passed arguments and thus has no knowledge that there - # used to be a pretrained parameter. - pretrained_positional = weights_arg is not sentinel - if pretrained_positional: - # We put the pretrained argument under its legacy name in the keyword argument dictionary to have a - # unified access to the value if the default value is a callable. - kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param) - else: - pretrained_arg = kwargs[pretrained_param] - - if pretrained_arg: - default_weights_arg = default(kwargs) if callable(default) else default - if not isinstance(default_weights_arg, WeightsEnum): - raise ValueError(f"No weights available for model {builder.__name__}") - else: - default_weights_arg = None - - if not pretrained_positional: - warnings.warn( - f"The parameter '{pretrained_param}' is deprecated, please use '{weights_param}' instead." - ) - - msg = ( - f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated. " - f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`." - ) - if pretrained_arg: - msg = ( - f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` " - f"to get the most up-to-date weights." 
- ) - warnings.warn(msg) - - del kwargs[pretrained_param] - kwargs[weights_param] = default_weights_arg - - return builder(*args, **kwargs) - - return inner_wrapper - - return outer_wrapper - - -def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None: - if param in kwargs: - if kwargs[param] != new_value: - raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.") - else: - kwargs[param] = new_value - - -def _ovewrite_value_param(param: Optional[V], new_value: V) -> V: - if param is not None: - if param != new_value: - raise ValueError(f"The parameter '{param}' expected value {new_value} but got {param} instead.") - return new_value diff --git a/torchvision/prototype/models/alexnet.py b/torchvision/prototype/models/alexnet.py deleted file mode 100644 index 204a68236d3..00000000000 --- a/torchvision/prototype/models/alexnet.py +++ /dev/null @@ -1,49 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.alexnet import AlexNet -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"] - - -class AlexNet_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - "task": "image_classification", - "architecture": "AlexNet", - "publication_year": 2012, - "num_params": 61100840, - "size": (224, 224), - "min_size": (63, 63), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg", - "acc@1": 56.522, - "acc@5": 79.066, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1)) -def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: - weights = AlexNet_Weights.verify(weights) - - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = AlexNet(**kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py deleted file mode 100644 index 7d63ee155db..00000000000 --- a/torchvision/prototype/models/convnext.py +++ /dev/null @@ -1,169 +0,0 @@ -from functools import partial -from typing import Any, List, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.convnext import ConvNeXt, CNBlockConfig -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "ConvNeXt", - "ConvNeXt_Tiny_Weights", - "ConvNeXt_Small_Weights", - "ConvNeXt_Base_Weights", - "ConvNeXt_Large_Weights", - "convnext_tiny", - "convnext_small", - "convnext_base", - "convnext_large", -] - - -def _convnext( - block_setting: List[CNBlockConfig], - stochastic_depth_prob: float, - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> 
ConvNeXt: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "architecture": "ConvNeXt", - "publication_year": 2022, - "size": (224, 224), - "min_size": (32, 32), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext", -} - - -class ConvNeXt_Tiny_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=236), - meta={ - **_COMMON_META, - "num_params": 28589128, - "acc@1": 82.520, - "acc@5": 96.146, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ConvNeXt_Small_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_small-0c510722.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=230), - meta={ - **_COMMON_META, - "num_params": 50223688, - "acc@1": 83.616, - "acc@5": 96.650, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ConvNeXt_Base_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_base-6075fbad.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 88591464, - "acc@1": 84.062, - "acc@5": 96.870, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ConvNeXt_Large_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/convnext_large-ea097f82.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 197767336, - "acc@1": 84.414, - "acc@5": 96.976, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1)) -def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: - weights = ConvNeXt_Tiny_Weights.verify(weights) - - block_setting = [ - CNBlockConfig(96, 192, 3), - CNBlockConfig(192, 384, 3), - CNBlockConfig(384, 768, 9), - CNBlockConfig(768, None, 3), - ] - stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1) - return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1)) -def convnext_small( - *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any -) -> ConvNeXt: - weights = ConvNeXt_Small_Weights.verify(weights) - - block_setting = [ - CNBlockConfig(96, 192, 3), - CNBlockConfig(192, 384, 3), - CNBlockConfig(384, 768, 27), - CNBlockConfig(768, None, 3), - ] - stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4) - return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1)) -def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: - weights = ConvNeXt_Base_Weights.verify(weights) - - block_setting = [ - CNBlockConfig(128, 256, 3), - CNBlockConfig(256, 512, 3), - 
CNBlockConfig(512, 1024, 27), - CNBlockConfig(1024, None, 3), - ] - stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) - return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1)) -def convnext_large( - *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any -) -> ConvNeXt: - weights = ConvNeXt_Large_Weights.verify(weights) - - block_setting = [ - CNBlockConfig(192, 384, 3), - CNBlockConfig(384, 768, 3), - CNBlockConfig(768, 1536, 27), - CNBlockConfig(1536, None, 3), - ] - stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5) - return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/densenet.py b/torchvision/prototype/models/densenet.py deleted file mode 100644 index 4ad9be028e5..00000000000 --- a/torchvision/prototype/models/densenet.py +++ /dev/null @@ -1,159 +0,0 @@ -import re -from functools import partial -from typing import Any, Optional, Tuple - -import torch.nn as nn -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.densenet import DenseNet -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "DenseNet", - "DenseNet121_Weights", - "DenseNet161_Weights", - "DenseNet169_Weights", - "DenseNet201_Weights", - "densenet121", - "densenet161", - "densenet169", - "densenet201", -] - - -def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None: - # '.'s are no longer allowed in module names, but previous _DenseLayer - # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. - # They are also in the checkpoints in model_urls. This pattern is used - # to find such keys. 
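For context, the renaming described in the comment above maps legacy DenseNet checkpoint keys such as `norm.1.weight` onto the current `norm1.weight` layout. A minimal sketch of what the pattern below does (the key is illustrative):

    import re

    pattern = re.compile(
        r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
    )
    key = "features.denseblock1.denselayer1.norm.1.weight"  # legacy key with a dot before the index
    res = pattern.match(key)
    new_key = res.group(1) + res.group(2)  # -> "features.denseblock1.denselayer1.norm1.weight"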
- pattern = re.compile( - r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" - ) - - state_dict = weights.get_state_dict(progress=progress) - for key in list(state_dict.keys()): - res = pattern.match(key) - if res: - new_key = res.group(1) + res.group(2) - state_dict[new_key] = state_dict[key] - del state_dict[key] - model.load_state_dict(state_dict) - - -def _densenet( - growth_rate: int, - block_config: Tuple[int, int, int, int], - num_init_features: int, - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> DenseNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) - - if weights is not None: - _load_state_dict(model=model, weights=weights, progress=progress) - - return model - - -_COMMON_META = { - "task": "image_classification", - "architecture": "DenseNet", - "publication_year": 2016, - "size": (224, 224), - "min_size": (29, 29), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/pull/116", -} - - -class DenseNet121_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/densenet121-a639ec97.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 7978856, - "acc@1": 74.434, - "acc@5": 91.972, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class DenseNet161_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/densenet161-8d451a50.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 28681000, - "acc@1": 77.138, - "acc@5": 93.560, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class DenseNet169_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 14149480, - "acc@1": 75.600, - "acc@5": 92.806, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class DenseNet201_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/densenet201-c1103571.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 20013928, - "acc@1": 76.896, - "acc@5": 93.370, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1)) -def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: - weights = DenseNet121_Weights.verify(weights) - - return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1)) -def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: - weights = DenseNet161_Weights.verify(weights) - - return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1)) -def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: - weights = DenseNet169_Weights.verify(weights) - - return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs) - - 
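As a usage sketch for the DenseNet builders being ported in this patch (assuming the builders and weight enums are importable from torchvision.models.densenet once the move lands):

    from torchvision.models.densenet import densenet121, DenseNet121_Weights

    # New-style call: pass a weights enum member, or None for random initialization.
    model = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)

    # Legacy call: still accepted via handle_legacy_interface, but it emits a
    # deprecation warning and maps pretrained=True to the default weights.
    legacy_model = densenet121(pretrained=True)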
-@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1)) -def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: - weights = DenseNet201_Weights.verify(weights) - - return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index ecdd9bdb423..727da793605 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -13,8 +13,8 @@ misc_nn_ops, overwrite_eps, ) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_value_param from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large from ..resnet import ResNet50_Weights, resnet50 diff --git a/torchvision/prototype/models/detection/fcos.py b/torchvision/prototype/models/detection/fcos.py index db3a679a62d..153fa10c22c 100644 --- a/torchvision/prototype/models/detection/fcos.py +++ b/torchvision/prototype/models/detection/fcos.py @@ -11,8 +11,8 @@ LastLevelP6P7, misc_nn_ops, ) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_value_param from ..resnet import ResNet50_Weights, resnet50 diff --git a/torchvision/prototype/models/detection/keypoint_rcnn.py b/torchvision/prototype/models/detection/keypoint_rcnn.py index e0b4d7061fa..41142da2a34 100644 --- a/torchvision/prototype/models/detection/keypoint_rcnn.py +++ b/torchvision/prototype/models/detection/keypoint_rcnn.py @@ -11,8 +11,8 @@ misc_nn_ops, overwrite_eps, ) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES from .._utils import handle_legacy_interface, _ovewrite_value_param from ..resnet import ResNet50_Weights, resnet50 diff --git a/torchvision/prototype/models/detection/mask_rcnn.py b/torchvision/prototype/models/detection/mask_rcnn.py index 187bf6912b4..df598553686 100644 --- a/torchvision/prototype/models/detection/mask_rcnn.py +++ b/torchvision/prototype/models/detection/mask_rcnn.py @@ -11,8 +11,8 @@ misc_nn_ops, overwrite_eps, ) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_value_param from ..resnet import ResNet50_Weights, resnet50 diff --git a/torchvision/prototype/models/detection/retinanet.py b/torchvision/prototype/models/detection/retinanet.py index eadd6c635ca..7a021ec27c0 100644 --- a/torchvision/prototype/models/detection/retinanet.py +++ b/torchvision/prototype/models/detection/retinanet.py @@ -12,8 +12,8 @@ misc_nn_ops, overwrite_eps, ) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_CATEGORIES from .._utils import handle_legacy_interface, 
_ovewrite_value_param from ..resnet import ResNet50_Weights, resnet50 diff --git a/torchvision/prototype/models/detection/ssd.py b/torchvision/prototype/models/detection/ssd.py index 3cab044958d..5aae8c49055 100644 --- a/torchvision/prototype/models/detection/ssd.py +++ b/torchvision/prototype/models/detection/ssd.py @@ -10,8 +10,8 @@ DefaultBoxGenerator, SSD, ) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_value_param from ..vgg import VGG16_Weights, vgg16 diff --git a/torchvision/prototype/models/detection/ssdlite.py b/torchvision/prototype/models/detection/ssdlite.py index 6de34acb5ae..7623b4ea861 100644 --- a/torchvision/prototype/models/detection/ssdlite.py +++ b/torchvision/prototype/models/detection/ssdlite.py @@ -15,8 +15,8 @@ SSD, SSDLiteHead, ) -from .._api import WeightsEnum, Weights -from .._meta import _COCO_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_value_param from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large diff --git a/torchvision/prototype/models/efficientnet.py b/torchvision/prototype/models/efficientnet.py deleted file mode 100644 index cb6d2bb2b35..00000000000 --- a/torchvision/prototype/models/efficientnet.py +++ /dev/null @@ -1,453 +0,0 @@ -from functools import partial -from typing import Any, Optional, Sequence, Union - -from torch import nn -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.efficientnet import EfficientNet, MBConvConfig, FusedMBConvConfig, _efficientnet_conf -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "EfficientNet", - "EfficientNet_B0_Weights", - "EfficientNet_B1_Weights", - "EfficientNet_B2_Weights", - "EfficientNet_B3_Weights", - "EfficientNet_B4_Weights", - "EfficientNet_B5_Weights", - "EfficientNet_B6_Weights", - "EfficientNet_B7_Weights", - "EfficientNet_V2_S_Weights", - "EfficientNet_V2_M_Weights", - "EfficientNet_V2_L_Weights", - "efficientnet_b0", - "efficientnet_b1", - "efficientnet_b2", - "efficientnet_b3", - "efficientnet_b4", - "efficientnet_b5", - "efficientnet_b6", - "efficientnet_b7", - "efficientnet_v2_s", - "efficientnet_v2_m", - "efficientnet_v2_l", -] - - -def _efficientnet( - inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]], - dropout: float, - last_channel: Optional[int], - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> EfficientNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "categories": _IMAGENET_CATEGORIES, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet", -} - - -_COMMON_META_V1 = { - **_COMMON_META, - "architecture": "EfficientNet", - "publication_year": 2019, - "interpolation": InterpolationMode.BICUBIC, - "min_size": 
(1, 1), -} - - -_COMMON_META_V2 = { - **_COMMON_META, - "architecture": "EfficientNetV2", - "publication_year": 2021, - "interpolation": InterpolationMode.BILINEAR, - "min_size": (33, 33), -} - - -class EfficientNet_B0_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", - transforms=partial( - ImageClassificationEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 5288548, - "size": (224, 224), - "acc@1": 77.692, - "acc@5": 93.532, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_B1_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", - transforms=partial( - ImageClassificationEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 7794184, - "size": (240, 240), - "acc@1": 78.642, - "acc@5": 94.186, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth", - transforms=partial( - ImageClassificationEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR - ), - meta={ - **_COMMON_META_V1, - "num_params": 7794184, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning", - "interpolation": InterpolationMode.BILINEAR, - "size": (240, 240), - "acc@1": 79.838, - "acc@5": 94.934, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class EfficientNet_B2_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", - transforms=partial( - ImageClassificationEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 9109994, - "size": (288, 288), - "acc@1": 80.608, - "acc@5": 95.310, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_B3_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", - transforms=partial( - ImageClassificationEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 12233232, - "size": (300, 300), - "acc@1": 82.008, - "acc@5": 96.054, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_B4_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", - transforms=partial( - ImageClassificationEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 19341616, - "size": (380, 380), - "acc@1": 83.384, - "acc@5": 96.594, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_B5_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", - transforms=partial( - ImageClassificationEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 30389784, - "size": (456, 456), - "acc@1": 83.444, - "acc@5": 96.628, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_B6_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", - transforms=partial( - ImageClassificationEval, crop_size=528, resize_size=528, 
interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 43040704, - "size": (528, 528), - "acc@1": 84.008, - "acc@5": 96.916, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_B7_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", - transforms=partial( - ImageClassificationEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC - ), - meta={ - **_COMMON_META_V1, - "num_params": 66347960, - "size": (600, 600), - "acc@1": 84.122, - "acc@5": 96.908, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_V2_S_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth", - transforms=partial( - ImageClassificationEval, - crop_size=384, - resize_size=384, - interpolation=InterpolationMode.BILINEAR, - ), - meta={ - **_COMMON_META_V2, - "num_params": 21458488, - "size": (384, 384), - "acc@1": 84.228, - "acc@5": 96.878, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_V2_M_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth", - transforms=partial( - ImageClassificationEval, - crop_size=480, - resize_size=480, - interpolation=InterpolationMode.BILINEAR, - ), - meta={ - **_COMMON_META_V2, - "num_params": 54139356, - "size": (480, 480), - "acc@1": 85.112, - "acc@5": 97.156, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class EfficientNet_V2_L_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth", - transforms=partial( - ImageClassificationEval, - crop_size=480, - resize_size=480, - interpolation=InterpolationMode.BICUBIC, - mean=(0.5, 0.5, 0.5), - std=(0.5, 0.5, 0.5), - ), - meta={ - **_COMMON_META_V2, - "num_params": 118515272, - "size": (480, 480), - "acc@1": 85.808, - "acc@5": 97.788, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1)) -def efficientnet_b0( - *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B0_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0) - return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1)) -def efficientnet_b1( - *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B1_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1) - return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1)) -def efficientnet_b2( - *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B2_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2) - return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", 
EfficientNet_B3_Weights.IMAGENET1K_V1)) -def efficientnet_b3( - *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B3_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4) - return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1)) -def efficientnet_b4( - *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B4_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8) - return _efficientnet(inverted_residual_setting, 0.4, last_channel, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1)) -def efficientnet_b5( - *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B5_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2) - return _efficientnet( - inverted_residual_setting, - 0.4, - last_channel, - weights, - progress, - norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1)) -def efficientnet_b6( - *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B6_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6) - return _efficientnet( - inverted_residual_setting, - 0.5, - last_channel, - weights, - progress, - norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1)) -def efficientnet_b7( - *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_B7_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1) - return _efficientnet( - inverted_residual_setting, - 0.5, - last_channel, - weights, - progress, - norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1)) -def efficientnet_v2_s( - *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_V2_S_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s") - return _efficientnet( - inverted_residual_setting, - 0.2, - last_channel, - weights, - progress, - norm_layer=partial(nn.BatchNorm2d, eps=1e-03), - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1)) -def efficientnet_v2_m( - *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_V2_M_Weights.verify(weights) - - inverted_residual_setting, last_channel = 
_efficientnet_conf("efficientnet_v2_m") - return _efficientnet( - inverted_residual_setting, - 0.3, - last_channel, - weights, - progress, - norm_layer=partial(nn.BatchNorm2d, eps=1e-03), - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1)) -def efficientnet_v2_l( - *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any -) -> EfficientNet: - weights = EfficientNet_V2_L_Weights.verify(weights) - - inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l") - return _efficientnet( - inverted_residual_setting, - 0.4, - last_channel, - weights, - progress, - norm_layer=partial(nn.BatchNorm2d, eps=1e-03), - **kwargs, - ) diff --git a/torchvision/prototype/models/googlenet.py b/torchvision/prototype/models/googlenet.py index 70dc0d9db5c..b50d70e1694 100644 --- a/torchvision/prototype/models/googlenet.py +++ b/torchvision/prototype/models/googlenet.py @@ -6,8 +6,8 @@ from torchvision.transforms.functional import InterpolationMode from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/prototype/models/inception.py b/torchvision/prototype/models/inception.py index eec78a26236..f8e82e9d78c 100644 --- a/torchvision/prototype/models/inception.py +++ b/torchvision/prototype/models/inception.py @@ -5,8 +5,8 @@ from torchvision.transforms.functional import InterpolationMode from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/prototype/models/mnasnet.py b/torchvision/prototype/models/mnasnet.py index c48e34a7be5..142ca35002d 100644 --- a/torchvision/prototype/models/mnasnet.py +++ b/torchvision/prototype/models/mnasnet.py @@ -5,8 +5,8 @@ from torchvision.transforms.functional import InterpolationMode from ...models.mnasnet import MNASNet -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/prototype/models/mobilenet.py b/torchvision/prototype/models/mobilenet.py deleted file mode 100644 index 0a270d14d3a..00000000000 --- a/torchvision/prototype/models/mobilenet.py +++ /dev/null @@ -1,6 +0,0 @@ -from .mobilenetv2 import * # noqa: F401, F403 -from .mobilenetv3 import * # noqa: F401, F403 -from .mobilenetv2 import __all__ as mv2_all -from .mobilenetv3 import __all__ as mv3_all - -__all__ = mv2_all + mv3_all diff --git a/torchvision/prototype/models/mobilenetv2.py b/torchvision/prototype/models/mobilenetv2.py index 71b412898fe..babc455981c 100644 --- a/torchvision/prototype/models/mobilenetv2.py +++ b/torchvision/prototype/models/mobilenetv2.py @@ -5,8 +5,8 @@ from torchvision.transforms.functional import InterpolationMode from ...models.mobilenetv2 import MobileNetV2 -from ._api import WeightsEnum, Weights -from ._meta import 
_IMAGENET_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/prototype/models/mobilenetv3.py b/torchvision/prototype/models/mobilenetv3.py index aaf9c2c85a4..f5858f44aea 100644 --- a/torchvision/prototype/models/mobilenetv3.py +++ b/torchvision/prototype/models/mobilenetv3.py @@ -5,8 +5,8 @@ from torchvision.transforms.functional import InterpolationMode from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py index 24e87f3d4f9..327764f5107 100644 --- a/torchvision/prototype/models/optical_flow/raft.py +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -7,8 +7,8 @@ from torchvision.prototype.transforms import OpticalFlowEval from torchvision.transforms.functional import InterpolationMode -from .._api import WeightsEnum -from .._api import Weights +from torchvision.models._api import WeightsEnum +from torchvision.models._api import Weights from .._utils import handle_legacy_interface diff --git a/torchvision/prototype/models/quantization/googlenet.py b/torchvision/prototype/models/quantization/googlenet.py index cca6ba25060..12a21f973f6 100644 --- a/torchvision/prototype/models/quantization/googlenet.py +++ b/torchvision/prototype/models/quantization/googlenet.py @@ -10,8 +10,8 @@ _replace_relu, quantize_model, ) -from .._api import WeightsEnum, Weights -from .._meta import _IMAGENET_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param from ..googlenet import GoogLeNet_Weights diff --git a/torchvision/prototype/models/quantization/inception.py b/torchvision/prototype/models/quantization/inception.py index 2639b7de14f..d3a5d59fd01 100644 --- a/torchvision/prototype/models/quantization/inception.py +++ b/torchvision/prototype/models/quantization/inception.py @@ -9,8 +9,8 @@ _replace_relu, quantize_model, ) -from .._api import WeightsEnum, Weights -from .._meta import _IMAGENET_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param from ..inception import Inception_V3_Weights diff --git a/torchvision/prototype/models/quantization/mobilenetv2.py b/torchvision/prototype/models/quantization/mobilenetv2.py index a9789583fe6..271c7440a33 100644 --- a/torchvision/prototype/models/quantization/mobilenetv2.py +++ b/torchvision/prototype/models/quantization/mobilenetv2.py @@ -10,8 +10,8 @@ _replace_relu, quantize_model, ) -from .._api import WeightsEnum, Weights -from .._meta import _IMAGENET_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param from ..mobilenetv2 import MobileNet_V2_Weights diff --git a/torchvision/prototype/models/quantization/mobilenetv3.py 
b/torchvision/prototype/models/quantization/mobilenetv3.py index 915308d948f..482618b530c 100644 --- a/torchvision/prototype/models/quantization/mobilenetv3.py +++ b/torchvision/prototype/models/quantization/mobilenetv3.py @@ -11,8 +11,8 @@ QuantizableMobileNetV3, _replace_relu, ) -from .._api import WeightsEnum, Weights -from .._meta import _IMAGENET_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param from ..mobilenetv3 import MobileNet_V3_Large_Weights, _mobilenet_v3_conf diff --git a/torchvision/prototype/models/quantization/resnet.py b/torchvision/prototype/models/quantization/resnet.py index 9e2e29db0bf..5bd94567d3a 100644 --- a/torchvision/prototype/models/quantization/resnet.py +++ b/torchvision/prototype/models/quantization/resnet.py @@ -11,8 +11,8 @@ _replace_relu, quantize_model, ) -from .._api import WeightsEnum, Weights -from .._meta import _IMAGENET_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param from ..resnet import ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights diff --git a/torchvision/prototype/models/quantization/shufflenetv2.py b/torchvision/prototype/models/quantization/shufflenetv2.py index e21349ff8e0..ffc516db418 100644 --- a/torchvision/prototype/models/quantization/shufflenetv2.py +++ b/torchvision/prototype/models/quantization/shufflenetv2.py @@ -9,8 +9,8 @@ _replace_relu, quantize_model, ) -from .._api import WeightsEnum, Weights -from .._meta import _IMAGENET_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights diff --git a/torchvision/prototype/models/regnet.py b/torchvision/prototype/models/regnet.py index d5e2b535532..5129299d138 100644 --- a/torchvision/prototype/models/regnet.py +++ b/torchvision/prototype/models/regnet.py @@ -6,8 +6,8 @@ from torchvision.transforms.functional import InterpolationMode from ...models.regnet import RegNet, BlockParams -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/prototype/models/resnet.py b/torchvision/prototype/models/resnet.py index 35e30c0e760..d96eac5d1bb 100644 --- a/torchvision/prototype/models/resnet.py +++ b/torchvision/prototype/models/resnet.py @@ -5,8 +5,8 @@ from torchvision.transforms.functional import InterpolationMode from ...models.resnet import BasicBlock, Bottleneck, ResNet -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py index 7165078161f..dfd067974a3 100644 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ b/torchvision/prototype/models/segmentation/deeplabv3.py @@ -5,8 +5,8 @@ from 
torchvision.transforms.functional import InterpolationMode from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet -from .._api import WeightsEnum, Weights -from .._meta import _VOC_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _VOC_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_value_param from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large from ..resnet import resnet50, resnet101 diff --git a/torchvision/prototype/models/segmentation/fcn.py b/torchvision/prototype/models/segmentation/fcn.py index 1dfc251844f..13c197888e7 100644 --- a/torchvision/prototype/models/segmentation/fcn.py +++ b/torchvision/prototype/models/segmentation/fcn.py @@ -5,8 +5,8 @@ from torchvision.transforms.functional import InterpolationMode from ....models.segmentation.fcn import FCN, _fcn_resnet -from .._api import WeightsEnum, Weights -from .._meta import _VOC_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _VOC_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_value_param from ..resnet import ResNet50_Weights, ResNet101_Weights, resnet50, resnet101 diff --git a/torchvision/prototype/models/segmentation/lraspp.py b/torchvision/prototype/models/segmentation/lraspp.py index 2c0fa6f0aff..75a827c3610 100644 --- a/torchvision/prototype/models/segmentation/lraspp.py +++ b/torchvision/prototype/models/segmentation/lraspp.py @@ -5,8 +5,8 @@ from torchvision.transforms.functional import InterpolationMode from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3 -from .._api import WeightsEnum, Weights -from .._meta import _VOC_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _VOC_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_value_param from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py index 48047a70c60..617603759bd 100644 --- a/torchvision/prototype/models/shufflenetv2.py +++ b/torchvision/prototype/models/shufflenetv2.py @@ -5,8 +5,8 @@ from torchvision.transforms.functional import InterpolationMode from ...models.shufflenetv2 import ShuffleNetV2 -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/prototype/models/squeezenet.py b/torchvision/prototype/models/squeezenet.py index 7f6a034ed6c..b3c8817e7be 100644 --- a/torchvision/prototype/models/squeezenet.py +++ b/torchvision/prototype/models/squeezenet.py @@ -5,8 +5,8 @@ from torchvision.transforms.functional import InterpolationMode from ...models.squeezenet import SqueezeNet -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/prototype/models/vgg.py b/torchvision/prototype/models/vgg.py index 233c35418ed..90a6b06ad39 100644 --- a/torchvision/prototype/models/vgg.py +++ b/torchvision/prototype/models/vgg.py @@ -5,8 +5,8 @@ from torchvision.transforms.functional import 
InterpolationMode from ...models.vgg import VGG, make_layers, cfgs -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/prototype/models/video/resnet.py b/torchvision/prototype/models/video/resnet.py index 790d254d266..21bbb4d740d 100644 --- a/torchvision/prototype/models/video/resnet.py +++ b/torchvision/prototype/models/video/resnet.py @@ -15,8 +15,8 @@ R2Plus1dStem, VideoResNet, ) -from .._api import WeightsEnum, Weights -from .._meta import _KINETICS400_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _KINETICS400_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py index 468903b6b94..3fc34b72be5 100644 --- a/torchvision/prototype/models/vision_transformer.py +++ b/torchvision/prototype/models/vision_transformer.py @@ -9,8 +9,8 @@ from torchvision.transforms.functional import InterpolationMode from ...models.vision_transformer import VisionTransformer, interpolate_embeddings # noqa: F401 -from ._api import WeightsEnum, Weights -from ._meta import _IMAGENET_CATEGORIES +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 16369428e47..af88ca17b74 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -1,5 +1,3 @@ -from torchvision.transforms import InterpolationMode, AutoAugmentPolicy # usort: skip - from . import functional # usort: skip from ._transform import Transform # usort: skip @@ -10,11 +8,4 @@ from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop, FiveCrop, TenCrop, BatchMultiCrop from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace from ._misc import Identity, Normalize, ToDtype, Lambda -from ._presets import ( - ObjectDetectionEval, - ImageClassificationEval, - SemanticSegmentationEval, - VideoClassificationEval, - OpticalFlowEval, -) from ._type_conversion import DecodeImage, LabelToOneHot diff --git a/torchvision/transforms/__init__.py b/torchvision/transforms/__init__.py index 77680a14f0d..94ec34ebe98 100644 --- a/torchvision/transforms/__init__.py +++ b/torchvision/transforms/__init__.py @@ -1,2 +1,9 @@ from .transforms import * from .autoaugment import * +from ._presets import ( + ObjectDetectionEval, + ImageClassificationEval, + SemanticSegmentationEval, + VideoClassificationEval, + OpticalFlowEval, +) diff --git a/torchvision/prototype/transforms/_presets.py b/torchvision/transforms/_presets.py similarity index 98% rename from torchvision/prototype/transforms/_presets.py rename to torchvision/transforms/_presets.py index 3ab045b3ddb..a6b85d05597 100644 --- a/torchvision/prototype/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -3,7 +3,7 @@ import torch from torch import Tensor, nn -from ...transforms import functional as F, InterpolationMode +from . 
import functional as F, InterpolationMode __all__ = [ From 78f90ceaa85f4723041c509d4454f69b142471bb Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 11 Mar 2022 15:26:00 +0000 Subject: [PATCH 02/45] Porting googlenet --- torchvision/models/alexnet.py | 4 +- torchvision/models/googlenet.py | 74 +++++++++++++++-------- torchvision/prototype/models/googlenet.py | 63 ------------------- 3 files changed, 52 insertions(+), 89 deletions(-) delete mode 100644 torchvision/prototype/models/googlenet.py diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index ba8d6eef37a..0dbc62fbdee 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -9,8 +9,8 @@ from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param -from torchvision.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode +from ..transforms import ImageClassificationEval, InterpolationMode + __all__ = ["AlexNet", "AlexNet_Weights", "alexnet"] diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index 9e4c3498aab..d700e33a0e0 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -1,4 +1,5 @@ import warnings +from functools import partial from collections import namedtuple from typing import Optional, Tuple, List, Callable, Any @@ -7,15 +8,16 @@ import torch.nn.functional as F from torch import Tensor -from .._internally_replaced_utils import load_state_dict_from_url from ..utils import _log_api_usage_once +from ..transforms import ImageClassificationEval, InterpolationMode -__all__ = ["GoogLeNet", "googlenet", "GoogLeNetOutputs", "_GoogLeNetOutputs"] +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param + + +__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"] -model_urls = { - # GoogLeNet ported from TensorFlow - "googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth", -} GoogLeNetOutputs = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"]) GoogLeNetOutputs.__annotations__ = {"logits": Tensor, "aux_logits2": Optional[Tensor], "aux_logits1": Optional[Tensor]} @@ -273,38 +275,62 @@ def forward(self, x: Tensor) -> Tensor: return F.relu(x, inplace=True) -def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> GoogLeNet: +class GoogLeNet_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/googlenet-1378be20.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + "task": "image_classification", + "architecture": "GoogLeNet", + "publication_year": 2014, + "num_params": 6624904, + "size": (224, 224), + "min_size": (15, 15), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#googlenet", + "acc@1": 69.778, + "acc@5": 89.530, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.IMAGENET1K_V1)) +def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: r"""GoogLeNet (Inception v1) model architecture from `"Going Deeper with Convolutions" `_. The required minimum input size of the model is 15x15. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (GoogLeNet_Weights): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr aux_logits (bool): If True, adds two auxiliary branches that can improve training. Default: *False* when pretrained is True otherwise *True* transform_input (bool): If True, preprocesses the input according to the method with which it was trained on ImageNet. Default: True if ``pretrained=True``, else False. """ - if pretrained: + weights = GoogLeNet_Weights.verify(weights) + + original_aux_logits = kwargs.get("aux_logits", False) + if weights is not None: if "transform_input" not in kwargs: - kwargs["transform_input"] = True - if "aux_logits" not in kwargs: - kwargs["aux_logits"] = False - if kwargs["aux_logits"]: - warnings.warn( - "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" - ) - original_aux_logits = kwargs["aux_logits"] - kwargs["aux_logits"] = True - kwargs["init_weights"] = False - model = GoogLeNet(**kwargs) - state_dict = load_state_dict_from_url(model_urls["googlenet"], progress=progress) - model.load_state_dict(state_dict) + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "init_weights", False) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = GoogLeNet(**kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) if not original_aux_logits: model.aux_logits = False model.aux1 = None # type: ignore[assignment] model.aux2 = None # type: ignore[assignment] - return model + else: + warnings.warn( + "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" + ) - return GoogLeNet(**kwargs) + return model \ No newline at end of file diff --git a/torchvision/prototype/models/googlenet.py b/torchvision/prototype/models/googlenet.py deleted file mode 100644 index b50d70e1694..00000000000 --- a/torchvision/prototype/models/googlenet.py +++ /dev/null @@ -1,63 +0,0 @@ -import warnings -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"] - - -class GoogLeNet_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/googlenet-1378be20.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - "task": "image_classification", - "architecture": "GoogLeNet", - "publication_year": 2014, - "num_params": 6624904, - "size": (224, 224), - "min_size": (15, 15), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#googlenet", - "acc@1": 69.778, - "acc@5": 89.530, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.IMAGENET1K_V1)) -def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, 
progress: bool = True, **kwargs: Any) -> GoogLeNet: - weights = GoogLeNet_Weights.verify(weights) - - original_aux_logits = kwargs.get("aux_logits", False) - if weights is not None: - if "transform_input" not in kwargs: - _ovewrite_named_param(kwargs, "transform_input", True) - _ovewrite_named_param(kwargs, "aux_logits", True) - _ovewrite_named_param(kwargs, "init_weights", False) - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = GoogLeNet(**kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - if not original_aux_logits: - model.aux_logits = False - model.aux1 = None # type: ignore[assignment] - model.aux2 = None # type: ignore[assignment] - else: - warnings.warn( - "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" - ) - - return model From 548d9ca5dfd8de7c3d9650f9b49a3d3d5376f8be Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 11 Mar 2022 15:56:48 +0000 Subject: [PATCH 03/45] Porting inception --- torchvision/models/inception.py | 66 +++++++++++++++-------- torchvision/prototype/models/inception.py | 57 -------------------- 2 files changed, 45 insertions(+), 78 deletions(-) delete mode 100644 torchvision/prototype/models/inception.py diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index c489925cb45..ff6cf95dd31 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -1,22 +1,23 @@ import warnings from collections import namedtuple +from functools import partial from typing import Callable, Any, Optional, Tuple, List import torch import torch.nn.functional as F from torch import nn, Tensor -from .._internally_replaced_utils import load_state_dict_from_url from ..utils import _log_api_usage_once -__all__ = ["Inception3", "inception_v3", "InceptionOutputs", "_InceptionOutputs"] +from ..transforms import ImageClassificationEval, InterpolationMode +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -model_urls = { - # Inception v3 ported from TensorFlow - "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", -} +__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"] + InceptionOutputs = namedtuple("InceptionOutputs", ["logits", "aux_logits"]) InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Tensor]} @@ -407,7 +408,29 @@ def forward(self, x: Tensor) -> Tensor: return F.relu(x, inplace=True) -def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> Inception3: +class Inception_V3_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", + transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342), + meta={ + "task": "image_classification", + "architecture": "InceptionV3", + "publication_year": 2015, + "num_params": 27161264, + "size": (299, 299), + "min_size": (75, 75), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3", + "acc@1": 77.294, + "acc@5": 93.450, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.IMAGENET1K_V1)) +def inception_v3(*, weights: 
Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: r"""Inception v3 model architecture from `"Rethinking the Inception Architecture for Computer Vision" `_. The required minimum input size of the model is 75x75. @@ -417,28 +440,29 @@ def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) N x 3 x 299 x 299, so ensure your images are sized accordingly. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (Inception_V3_Weights): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr aux_logits (bool): If True, add an auxiliary branch that can improve training. Default: *True* transform_input (bool): If True, preprocesses the input according to the method with which it was trained on ImageNet. Default: True if ``pretrained=True``, else False. """ - if pretrained: + weights = Inception_V3_Weights.verify(weights) + + original_aux_logits = kwargs.get("aux_logits", True) + if weights is not None: if "transform_input" not in kwargs: - kwargs["transform_input"] = True - if "aux_logits" in kwargs: - original_aux_logits = kwargs["aux_logits"] - kwargs["aux_logits"] = True - else: - original_aux_logits = True - kwargs["init_weights"] = False # we are loading weights from a pretrained model - model = Inception3(**kwargs) - state_dict = load_state_dict_from_url(model_urls["inception_v3_google"], progress=progress) - model.load_state_dict(state_dict) + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "init_weights", False) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = Inception3(**kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) if not original_aux_logits: model.aux_logits = False model.AuxLogits = None - return model - return Inception3(**kwargs) + return model \ No newline at end of file diff --git a/torchvision/prototype/models/inception.py b/torchvision/prototype/models/inception.py deleted file mode 100644 index f8e82e9d78c..00000000000 --- a/torchvision/prototype/models/inception.py +++ /dev/null @@ -1,57 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"] - - -class Inception_V3_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", - transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342), - meta={ - "task": "image_classification", - "architecture": "InceptionV3", - "publication_year": 2015, - "num_params": 27161264, - "size": (299, 299), - "min_size": (75, 75), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3", - "acc@1": 77.294, - "acc@5": 93.450, - }, - ) - DEFAULT = IMAGENET1K_V1 - - 
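To show how the ported weights enum is meant to be consumed (a sketch; it assumes the enum forwards the `transforms` and `meta` fields of its underlying Weights value, as the _api module does):

    from torchvision.models.inception import inception_v3, Inception_V3_Weights

    weights = Inception_V3_Weights.DEFAULT        # currently resolves to IMAGENET1K_V1
    model = inception_v3(weights=weights).eval()
    preprocess = weights.transforms()             # ImageClassificationEval(crop_size=299, resize_size=342)
    categories = weights.meta["categories"]       # ImageNet class names used to size the classifier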
-@handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.IMAGENET1K_V1)) -def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: - weights = Inception_V3_Weights.verify(weights) - - original_aux_logits = kwargs.get("aux_logits", True) - if weights is not None: - if "transform_input" not in kwargs: - _ovewrite_named_param(kwargs, "transform_input", True) - _ovewrite_named_param(kwargs, "aux_logits", True) - _ovewrite_named_param(kwargs, "init_weights", False) - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = Inception3(**kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - if not original_aux_logits: - model.aux_logits = False - model.AuxLogits = None - - return model From 46ff3487f40e45e65ee631a94b6f55bcbb7b2519 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 11 Mar 2022 16:04:24 +0000 Subject: [PATCH 04/45] Porting mnasnet --- torchvision/models/alexnet.py | 2 +- torchvision/models/convnext.py | 8 +- torchvision/models/efficientnet.py | 22 ++-- torchvision/models/googlenet.py | 2 +- torchvision/models/inception.py | 2 +- torchvision/models/mnasnet.py | 140 +++++++++++++++++------- torchvision/prototype/models/mnasnet.py | 113 ------------------- 7 files changed, 120 insertions(+), 169 deletions(-) delete mode 100644 torchvision/prototype/models/mnasnet.py diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index 0dbc62fbdee..a86f091dfa2 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -81,7 +81,7 @@ def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, The required minimum input size of the model is 63x63. Args: - weights (AlexNet_Weights): The pretrained weights for the model + weights (AlexNet_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ weights = AlexNet_Weights.verify(weights) diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 40811841b53..2717f9bf8e1 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -278,7 +278,7 @@ def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: r"""ConvNeXt Tiny model architecture from the `"A ConvNet for the 2020s" `_ paper. Args: - weights (ConvNeXt_Tiny_Weights): The pretrained weights for the model + weights (ConvNeXt_Tiny_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ weights = ConvNeXt_Tiny_Weights.verify(weights) @@ -300,7 +300,7 @@ def convnext_small( r"""ConvNeXt Small model architecture from the `"A ConvNet for the 2020s" `_ paper. Args: - weights (ConvNeXt_Small_Weights): The pretrained weights for the model + weights (ConvNeXt_Small_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ weights = ConvNeXt_Small_Weights.verify(weights) @@ -320,7 +320,7 @@ def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: r"""ConvNeXt Base model architecture from the `"A ConvNet for the 2020s" `_ paper. 
Args: - weights (ConvNeXt_Base_Weights): The pretrained weights for the model + weights (ConvNeXt_Base_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ weights = ConvNeXt_Base_Weights.verify(weights) @@ -342,7 +342,7 @@ def convnext_large( r"""ConvNeXt Large model architecture from the `"A ConvNet for the 2020s" `_ paper. Args: - weights (ConvNeXt_Large_Weights): The pretrained weights for the model + weights (ConvNeXt_Large_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ weights = ConvNeXt_Large_Weights.verify(weights) diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index 21e19c7d01a..f76fa4f5477 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -677,7 +677,7 @@ def efficientnet_b0( `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - weights (EfficientNet_B0_Weights): The pretrained weights for the model + weights (EfficientNet_B0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ weights = EfficientNet_B0_Weights.verify(weights) @@ -695,7 +695,7 @@ def efficientnet_b1( `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - weights (EfficientNet_B1_Weights): The pretrained weights for the model + weights (EfficientNet_B1_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ weights = EfficientNet_B1_Weights.verify(weights) @@ -713,7 +713,7 @@ def efficientnet_b2( `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - weights (EfficientNet_B2_Weights): The pretrained weights for the model + weights (EfficientNet_B2_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ weights = EfficientNet_B2_Weights.verify(weights) @@ -731,7 +731,7 @@ def efficientnet_b3( `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - weights (EfficientNet_B3_Weights): The pretrained weights for the model + weights (EfficientNet_B3_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ weights = EfficientNet_B3_Weights.verify(weights) @@ -749,7 +749,7 @@ def efficientnet_b4( `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - weights (EfficientNet_B4_Weights): The pretrained weights for the model + weights (EfficientNet_B4_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ weights = EfficientNet_B4_Weights.verify(weights) @@ -767,7 +767,7 @@ def efficientnet_b5( `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - weights (EfficientNet_B5_Weights): The pretrained weights for the model + weights (EfficientNet_B5_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ weights = EfficientNet_B5_Weights.verify(weights) @@ -793,7 +793,7 @@ def efficientnet_b6( `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. 
Args: - weights (EfficientNet_B6_Weights): The pretrained weights for the model + weights (EfficientNet_B6_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ weights = EfficientNet_B6_Weights.verify(weights) @@ -819,7 +819,7 @@ def efficientnet_b7( `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. Args: - weights (EfficientNet_B7_Weights): The pretrained weights for the model + weights (EfficientNet_B7_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ weights = EfficientNet_B7_Weights.verify(weights) @@ -845,7 +845,7 @@ def efficientnet_v2_s( `"EfficientNetV2: Smaller Models and Faster Training" `_. Args: - weights (EfficientNet_V2_S_Weights): The pretrained weights for the model + weights (EfficientNet_V2_S_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ weights = EfficientNet_V2_S_Weights.verify(weights) @@ -871,7 +871,7 @@ def efficientnet_v2_m( `"EfficientNetV2: Smaller Models and Faster Training" `_. Args: - weights (EfficientNet_V2_M_Weights): The pretrained weights for the model + weights (EfficientNet_V2_M_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ weights = EfficientNet_V2_M_Weights.verify(weights) @@ -897,7 +897,7 @@ def efficientnet_v2_l( `"EfficientNetV2: Smaller Models and Faster Training" `_. Args: - weights (EfficientNet_V2_L_Weights): The pretrained weights for the model + weights (EfficientNet_V2_L_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ weights = EfficientNet_V2_L_Weights.verify(weights) diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index d700e33a0e0..daad4ebb8df 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -303,7 +303,7 @@ def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = T The required minimum input size of the model is 15x15. Args: - weights (GoogLeNet_Weights): The pretrained weights for the model + weights (GoogLeNet_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr aux_logits (bool): If True, adds two auxiliary branches that can improve training. Default: *False* when pretrained is True otherwise *True* diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index ff6cf95dd31..bdac024406c 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -440,7 +440,7 @@ def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bo N x 3 x 299 x 299, so ensure your images are sized accordingly. Args: - weights (Inception_V3_Weights): The pretrained weights for the model + weights (Inception_V3_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr aux_logits (bool): If True, add an auxiliary branch that can improve training. 
Default: *True* diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index c3d4013f30c..1634ea2e80b 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -1,21 +1,30 @@ +from functools import partial import warnings -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import torch import torch.nn as nn from torch import Tensor -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["MNASNet", "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3"] -_MODEL_URLS = { - "mnasnet0_5": "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", - "mnasnet0_75": None, - "mnasnet1_0": "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", - "mnasnet1_3": None, -} +__all__ = [ + "MNASNet", + "MNASNet0_5_Weights", + "MNASNet0_75_Weights", + "MNASNet1_0_Weights", + "MNASNet1_3_Weights", + "mnasnet0_5", + "mnasnet0_75", + "mnasnet1_0", + "mnasnet1_3", +] + # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is # 1.0 - tensorflow. @@ -196,68 +205,123 @@ def _load_from_state_dict( ) -def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None: - if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None: - raise ValueError(f"No checkpoint is available for model type {model_name}") - checkpoint_url = _MODEL_URLS[model_name] - model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress)) +_COMMON_META = { + "task": "image_classification", + "architecture": "MNASNet", + "publication_year": 2018, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/1e100/mnasnet_trainer", +} + + +class MNASNet0_5_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 2218512, + "acc@1": 67.734, + "acc@5": 87.490, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class MNASNet0_75_Weights(WeightsEnum): + # If a default model is added here the corresponding changes need to be done in mnasnet0_75 + pass + + +class MNASNet1_0_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 4383312, + "acc@1": 73.456, + "acc@5": 91.510, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class MNASNet1_3_Weights(WeightsEnum): + # If a default model is added here the corresponding changes need to be done in mnasnet1_3 + pass + + +def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = MNASNet(alpha, **kwargs) -def mnasnet0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: + if weights: + model.load_state_dict(weights.get_state_dict(progress=progress)) + + return model + + +@handle_legacy_interface(weights=("pretrained", 
MNASNet0_5_Weights.IMAGENET1K_V1)) +def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: r"""MNASNet with depth multiplier of 0.5 from `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MNASNet0_5_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - model = MNASNet(0.5, **kwargs) - if pretrained: - _load_pretrained("mnasnet0_5", model, progress) - return model + weights = MNASNet0_5_Weights.verify(weights) + + return _mnasnet(0.5, weights, progress, **kwargs) -def mnasnet0_75(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: +@handle_legacy_interface(weights=("pretrained", None)) +def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: r"""MNASNet with depth multiplier of 0.75 from `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MNASNet0_75_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - model = MNASNet(0.75, **kwargs) - if pretrained: - _load_pretrained("mnasnet0_75", model, progress) - return model + weights = MNASNet0_75_Weights.verify(weights) + + return _mnasnet(0.75, weights, progress, **kwargs) -def mnasnet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: +@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1)) +def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: r"""MNASNet with depth multiplier of 1.0 from `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MNASNet1_0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - model = MNASNet(1.0, **kwargs) - if pretrained: - _load_pretrained("mnasnet1_0", model, progress) - return model + weights = MNASNet1_0_Weights.verify(weights) + return _mnasnet(1.0, weights, progress, **kwargs) -def mnasnet1_3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: + +@handle_legacy_interface(weights=("pretrained", None)) +def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: r"""MNASNet with depth multiplier of 1.3 from `"MnasNet: Platform-Aware Neural Architecture Search for Mobile" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MNASNet1_3_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - model = MNASNet(1.3, **kwargs) - if pretrained: - _load_pretrained("mnasnet1_3", model, progress) - return model + weights = MNASNet1_3_Weights.verify(weights) + + return _mnasnet(1.3, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/mnasnet.py b/torchvision/prototype/models/mnasnet.py deleted file mode 100644 index 142ca35002d..00000000000 --- a/torchvision/prototype/models/mnasnet.py +++ /dev/null @@ -1,113 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.mnasnet import MNASNet -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "MNASNet", - "MNASNet0_5_Weights", - "MNASNet0_75_Weights", - "MNASNet1_0_Weights", - "MNASNet1_3_Weights", - "mnasnet0_5", - "mnasnet0_75", - "mnasnet1_0", - "mnasnet1_3", -] - - -_COMMON_META = { - "task": "image_classification", - "architecture": "MNASNet", - "publication_year": 2018, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/1e100/mnasnet_trainer", -} - - -class MNASNet0_5_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 2218512, - "acc@1": 67.734, - "acc@5": 87.490, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class MNASNet0_75_Weights(WeightsEnum): - # If a default model is added here the corresponding changes need to be done in mnasnet0_75 - pass - - -class MNASNet1_0_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 4383312, - "acc@1": 73.456, - "acc@5": 91.510, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class MNASNet1_3_Weights(WeightsEnum): - # If a default model is added here the corresponding changes need to be done in mnasnet1_3 - pass - - -def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = MNASNet(alpha, **kwargs) - - if weights: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1)) -def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: - weights = MNASNet0_5_Weights.verify(weights) - - return _mnasnet(0.5, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", None)) -def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: - weights = MNASNet0_75_Weights.verify(weights) - - return _mnasnet(0.75, weights, progress, **kwargs) - - 
-@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1)) -def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: - weights = MNASNet1_0_Weights.verify(weights) - - return _mnasnet(1.0, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", None)) -def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: - weights = MNASNet1_3_Weights.verify(weights) - - return _mnasnet(1.3, weights, progress, **kwargs) From 58be05b25930164f550eccc09f879aee72c49762 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 11 Mar 2022 16:08:30 +0000 Subject: [PATCH 05/45] Porting mobilenetv2 --- torchvision/models/mobilenetv2.py | 68 +++++++++++++++++---- torchvision/prototype/models/mobilenetv2.py | 66 -------------------- 2 files changed, 56 insertions(+), 78 deletions(-) delete mode 100644 torchvision/prototype/models/mobilenetv2.py diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index 930f68d13e9..7598338ebd9 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -1,3 +1,4 @@ +from functools import partial import warnings from typing import Callable, Any, Optional, List @@ -5,18 +6,16 @@ from torch import Tensor from torch import nn -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import Conv2dNormActivation +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once -from ._utils import _make_divisible +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible -__all__ = ["MobileNetV2", "mobilenet_v2"] - -model_urls = { - "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", -} +__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"] # necessary for backwards compatibility @@ -195,17 +194,62 @@ def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2: +_COMMON_META = { + "task": "image_classification", + "architecture": "MobileNetV2", + "publication_year": 2018, + "num_params": 3504872, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class MobileNet_V2_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2", + "acc@1": 71.878, + "acc@5": 90.286, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", + "acc@1": 72.154, + "acc@5": 90.822, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +@handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1)) +def mobilenet_v2( + *, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any +) -> MobileNetV2: """ Constructs a MobileNetV2 architecture from `"MobileNetV2: Inverted Residuals 
and Linear Bottlenecks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MobileNet_V2_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = MobileNet_V2_Weights.verify(weights) + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = MobileNetV2(**kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls["mobilenet_v2"], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/prototype/models/mobilenetv2.py b/torchvision/prototype/models/mobilenetv2.py deleted file mode 100644 index babc455981c..00000000000 --- a/torchvision/prototype/models/mobilenetv2.py +++ /dev/null @@ -1,66 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.mobilenetv2 import MobileNetV2 -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"] - - -_COMMON_META = { - "task": "image_classification", - "architecture": "MobileNetV2", - "publication_year": 2018, - "num_params": 3504872, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -class MobileNet_V2_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2", - "acc@1": 71.878, - "acc@5": 90.286, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", - "acc@1": 72.154, - "acc@5": 90.822, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -@handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1)) -def mobilenet_v2( - *, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any -) -> MobileNetV2: - weights = MobileNet_V2_Weights.verify(weights) - - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = MobileNetV2(**kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model From 99930d5ddb4cd72db48cefad866e3b237889f1ea Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 11 Mar 2022 16:14:54 +0000 Subject: [PATCH 06/45] Porting mobilenetv3 --- torchvision/models/mobilenetv3.py | 116 +++++++++++++++----- torchvision/prototype/models/mobilenetv3.py | 109 ------------------ 2 files changed, 90 insertions(+), 135 deletions(-) delete mode 100644 torchvision/prototype/models/mobilenetv3.py diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 530467d6d53..031b9d761f0 100644 --- 
a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -5,19 +5,22 @@ import torch from torch import nn, Tensor -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once -from ._utils import _make_divisible +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible -__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"] - -model_urls = { - "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", - "mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", -} +__all__ = [ + "MobileNetV3", + "MobileNet_V3_Large_Weights", + "MobileNet_V3_Small_Weights", + "mobilenet_v3_large", + "mobilenet_v3_small", +] class SqueezeExcitation(SElayer): @@ -284,45 +287,106 @@ def _mobilenet_v3_conf( def _mobilenet_v3( - arch: str, inverted_residual_setting: List[InvertedResidualConfig], last_channel: int, - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, -): +) -> MobileNetV3: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) - if pretrained: - if model_urls.get(arch, None) is None: - raise ValueError(f"No checkpoint is available for model type {arch}") - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model -def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: +_COMMON_META = { + "task": "image_classification", + "architecture": "MobileNetV3", + "publication_year": 2019, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class MobileNet_V3_Large_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 5483032, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small", + "acc@1": 74.042, + "acc@5": 91.340, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 5483032, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", + "acc@1": 75.274, + "acc@5": 92.566, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class MobileNet_V3_Small_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 2542856, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small", + "acc@1": 67.668, + "acc@5": 87.402, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", 
MobileNet_V3_Large_Weights.IMAGENET1K_V1)) +def mobilenet_v3_large( + *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any +) -> MobileNetV3: """ Constructs a large MobileNetV3 architecture from `"Searching for MobileNetV3" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MobileNet_V3_Large_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "mobilenet_v3_large" - inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) - return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) + weights = MobileNet_V3_Large_Weights.verify(weights) + inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) + return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) -def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: + +@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1)) +def mobilenet_v3_small( + *, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any +) -> MobileNetV3: """ Constructs a small MobileNetV3 architecture from `"Searching for MobileNetV3" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (MobileNet_V3_Small_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - arch = "mobilenet_v3_small" - inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) - return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) + weights = MobileNet_V3_Small_Weights.verify(weights) + + inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs) + return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/mobilenetv3.py b/torchvision/prototype/models/mobilenetv3.py deleted file mode 100644 index f5858f44aea..00000000000 --- a/torchvision/prototype/models/mobilenetv3.py +++ /dev/null @@ -1,109 +0,0 @@ -from functools import partial -from typing import Any, Optional, List - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "MobileNetV3", - "MobileNet_V3_Large_Weights", - "MobileNet_V3_Small_Weights", - "mobilenet_v3_large", - "mobilenet_v3_small", -] - - -def _mobilenet_v3( - inverted_residual_setting: List[InvertedResidualConfig], - last_channel: int, - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> MobileNetV3: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "architecture": "MobileNetV3", - 
"publication_year": 2019, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -class MobileNet_V3_Large_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 5483032, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small", - "acc@1": 74.042, - "acc@5": 91.340, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 5483032, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", - "acc@1": 75.274, - "acc@5": 92.566, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class MobileNet_V3_Small_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 2542856, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small", - "acc@1": 67.668, - "acc@5": 87.402, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1)) -def mobilenet_v3_large( - *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any -) -> MobileNetV3: - weights = MobileNet_V3_Large_Weights.verify(weights) - - inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) - return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1)) -def mobilenet_v3_small( - *, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any -) -> MobileNetV3: - weights = MobileNet_V3_Small_Weights.verify(weights) - - inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs) - return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) From 9b4a3efc22c84b169065d89923f02e16f1e2c54d Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 11 Mar 2022 16:39:45 +0000 Subject: [PATCH 07/45] Porting regnet --- torchvision/models/regnet.py | 587 +++++++++++++++++++++---- torchvision/prototype/models/regnet.py | 575 ------------------------ 2 files changed, 509 insertions(+), 653 deletions(-) delete mode 100644 torchvision/prototype/models/regnet.py diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 74abd20b237..21727ecff71 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -1,8 +1,3 @@ -# Modified from -# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/anynet.py -# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py - - import math from collections import OrderedDict from functools import partial @@ -11,14 +6,32 @@ import torch from torch import nn, Tensor -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import Conv2dNormActivation, SqueezeExcitation +from ..transforms import ImageClassificationEval, InterpolationMode 
from ..utils import _log_api_usage_once -from ._utils import _make_divisible + +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible __all__ = [ "RegNet", + "RegNet_Y_400MF_Weights", + "RegNet_Y_800MF_Weights", + "RegNet_Y_1_6GF_Weights", + "RegNet_Y_3_2GF_Weights", + "RegNet_Y_8GF_Weights", + "RegNet_Y_16GF_Weights", + "RegNet_Y_32GF_Weights", + "RegNet_Y_128GF_Weights", + "RegNet_X_400MF_Weights", + "RegNet_X_800MF_Weights", + "RegNet_X_1_6GF_Weights", + "RegNet_X_3_2GF_Weights", + "RegNet_X_8GF_Weights", + "RegNet_X_16GF_Weights", + "RegNet_X_32GF_Weights", "regnet_y_400mf", "regnet_y_800mf", "regnet_y_1_6gf", @@ -37,24 +50,6 @@ ] -model_urls = { - "regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", - "regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", - "regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", - "regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", - "regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", - "regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", - "regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", - "regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", - "regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", - "regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", - "regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", - "regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", - "regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", - "regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", -} - - class SimpleStemIN(Conv2dNormActivation): """Simple stem for ImageNet: 3x3, BN, ReLU.""" @@ -390,219 +385,655 @@ def forward(self, x: Tensor) -> Tensor: return x -def _regnet(arch: str, block_params: BlockParams, pretrained: bool, progress: bool, **kwargs: Any) -> RegNet: +def _regnet( + block_params: BlockParams, + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> RegNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1)) model = RegNet(block_params, norm_layer=norm_layer, **kwargs) - if pretrained: - if arch not in model_urls: - raise ValueError(f"No checkpoint is available for model type {arch}") - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model -def regnet_y_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +_COMMON_META = { + "task": "image_classification", + "architecture": "RegNet", + "publication_year": 2020, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class RegNet_Y_400MF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + 
"num_params": 4344144, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "acc@1": 74.046, + "acc@5": 91.716, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 4344144, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 75.804, + "acc@5": 92.742, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_800MF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 6432512, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "acc@1": 76.420, + "acc@5": 93.136, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 6432512, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 78.828, + "acc@5": 94.502, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_1_6GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 11202430, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "acc@1": 77.950, + "acc@5": 93.966, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 11202430, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 80.876, + "acc@5": 95.444, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_3_2GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 19436338, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "acc@1": 78.948, + "acc@5": 94.576, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 19436338, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 81.982, + "acc@5": 95.972, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_8GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 39381472, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "acc@1": 80.032, + "acc@5": 95.048, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 39381472, + "recipe": 
"https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 82.828, + "acc@5": 96.330, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_16GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 83590140, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", + "acc@1": 80.424, + "acc@5": 95.240, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 83590140, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 82.886, + "acc@5": 96.328, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_32GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 145046770, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", + "acc@1": 80.878, + "acc@5": 95.340, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 145046770, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 83.368, + "acc@5": 96.498, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_Y_128GF_Weights(WeightsEnum): + # weights are not available yet. 
+ pass + + +class RegNet_X_400MF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 5495976, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "acc@1": 72.834, + "acc@5": 90.950, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 5495976, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "acc@1": 74.864, + "acc@5": 92.322, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_800MF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 7259656, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "acc@1": 75.212, + "acc@5": 92.348, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 7259656, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "acc@1": 77.522, + "acc@5": 93.826, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_1_6GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 9190136, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", + "acc@1": 77.040, + "acc@5": 93.440, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 9190136, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "acc@1": 79.668, + "acc@5": 94.922, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_3_2GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 15296552, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "acc@1": 78.364, + "acc@5": 93.992, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 15296552, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 81.196, + "acc@5": 95.430, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_8GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 39572648, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "acc@1": 79.344, 
+ "acc@5": 94.686, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 39572648, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 81.682, + "acc@5": 95.678, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_16GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 54278536, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", + "acc@1": 80.058, + "acc@5": 94.944, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 54278536, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 82.716, + "acc@5": 96.196, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class RegNet_X_32GF_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 107811560, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", + "acc@1": 80.622, + "acc@5": 95.248, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 107811560, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 83.014, + "acc@5": 96.288, + }, + ) + DEFAULT = IMAGENET1K_V2 + + + + + +@handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.IMAGENET1K_V1)) +def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_400MF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_400MF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_400MF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs) - return _regnet("regnet_y_400mf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_800mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.IMAGENET1K_V1)) +def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_800MF architecture from `"Designing Network Design Spaces" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_800MF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_800MF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs) - return _regnet("regnet_y_800mf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_1_6gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.IMAGENET1K_V1)) +def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_1.6GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_1_6GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_1_6GF_Weights.verify(weights) + params = BlockParams.from_init_params( depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs ) - return _regnet("regnet_y_1_6gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_3_2gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.IMAGENET1K_V1)) +def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_3.2GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_3_2GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_3_2GF_Weights.verify(weights) + params = BlockParams.from_init_params( depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs ) - return _regnet("regnet_y_3_2gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_8gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.IMAGENET1K_V1)) +def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_8GF architecture from `"Designing Network Design Spaces" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_8GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_8GF_Weights.verify(weights) + params = BlockParams.from_init_params( depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs ) - return _regnet("regnet_y_8gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_16gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.IMAGENET1K_V1)) +def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_16GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_16GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_16GF_Weights.verify(weights) + params = BlockParams.from_init_params( depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs ) - return _regnet("regnet_y_16gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.IMAGENET1K_V1)) +def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_32GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_Y_32GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_32GF_Weights.verify(weights) + params = BlockParams.from_init_params( depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs ) - return _regnet("regnet_y_32gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_y_128gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", None)) +def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetY_128GF architecture from `"Designing Network Design Spaces" `_. NOTE: Pretrained weights are not available for this model. 
+ + Args: + weights (RegNet_Y_128GF_Weights, optional): The pretrained weights for the model + progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_Y_128GF_Weights.verify(weights) + params = BlockParams.from_init_params( depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs ) - return _regnet("regnet_y_128gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.IMAGENET1K_V1)) +def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_400MF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_400MF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_X_400MF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs) - return _regnet("regnet_x_400mf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_800mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.IMAGENET1K_V1)) +def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_800MF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_800MF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_X_800MF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs) - return _regnet("regnet_x_800mf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_1_6gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.IMAGENET1K_V1)) +def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_1.6GF architecture from `"Designing Network Design Spaces" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_1_6GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_X_1_6GF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs) - return _regnet("regnet_x_1_6gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_3_2gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.IMAGENET1K_V1)) +def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_3.2GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_3_2GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_X_3_2GF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs) - return _regnet("regnet_x_3_2gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_8gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.IMAGENET1K_V1)) +def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_8GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_8GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_X_8GF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs) - return _regnet("regnet_x_8gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_16gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.IMAGENET1K_V1)) +def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_16GF architecture from `"Designing Network Design Spaces" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_16GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = RegNet_X_16GF_Weights.verify(weights) + params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs) - return _regnet("regnet_x_16gf", params, pretrained, progress, **kwargs) + return _regnet(params, weights, progress, **kwargs) -def regnet_x_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: +@handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.IMAGENET1K_V1)) +def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ Constructs a RegNetX_32GF architecture from `"Designing Network Design Spaces" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (RegNet_X_32GF_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) - return _regnet("regnet_x_32gf", params, pretrained, progress, **kwargs) + weights = RegNet_X_32GF_Weights.verify(weights) - -# TODO(kazhang): Add RegNetZ_500MF and RegNetZ_4GF + params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) + return _regnet(params, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/regnet.py b/torchvision/prototype/models/regnet.py deleted file mode 100644 index 5129299d138..00000000000 --- a/torchvision/prototype/models/regnet.py +++ /dev/null @@ -1,575 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torch import nn -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.regnet import RegNet, BlockParams -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "RegNet", - "RegNet_Y_400MF_Weights", - "RegNet_Y_800MF_Weights", - "RegNet_Y_1_6GF_Weights", - "RegNet_Y_3_2GF_Weights", - "RegNet_Y_8GF_Weights", - "RegNet_Y_16GF_Weights", - "RegNet_Y_32GF_Weights", - "RegNet_Y_128GF_Weights", - "RegNet_X_400MF_Weights", - "RegNet_X_800MF_Weights", - "RegNet_X_1_6GF_Weights", - "RegNet_X_3_2GF_Weights", - "RegNet_X_8GF_Weights", - "RegNet_X_16GF_Weights", - "RegNet_X_32GF_Weights", - "regnet_y_400mf", - "regnet_y_800mf", - "regnet_y_1_6gf", - "regnet_y_3_2gf", - "regnet_y_8gf", - "regnet_y_16gf", - "regnet_y_32gf", - "regnet_y_128gf", - "regnet_x_400mf", - "regnet_x_800mf", - "regnet_x_1_6gf", - "regnet_x_3_2gf", - "regnet_x_8gf", - "regnet_x_16gf", - "regnet_x_32gf", -] - -_COMMON_META = { - "task": "image_classification", - "architecture": "RegNet", - "publication_year": 2020, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -def _regnet( - block_params: BlockParams, - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> RegNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - norm_layer = kwargs.pop("norm_layer", 
partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1)) - model = RegNet(block_params, norm_layer=norm_layer, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -class RegNet_Y_400MF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 4344144, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", - "acc@1": 74.046, - "acc@5": 91.716, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 4344144, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 75.804, - "acc@5": 92.742, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_800MF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 6432512, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", - "acc@1": 76.420, - "acc@5": 93.136, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 6432512, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 78.828, - "acc@5": 94.502, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_1_6GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 11202430, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", - "acc@1": 77.950, - "acc@5": 93.966, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 11202430, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 80.876, - "acc@5": 95.444, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_3_2GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 19436338, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", - "acc@1": 78.948, - "acc@5": 94.576, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 19436338, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 81.982, - "acc@5": 95.972, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_8GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", - transforms=partial(ImageClassificationEval, 
crop_size=224), - meta={ - **_COMMON_META, - "num_params": 39381472, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", - "acc@1": 80.032, - "acc@5": 95.048, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 39381472, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 82.828, - "acc@5": 96.330, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_16GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 83590140, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", - "acc@1": 80.424, - "acc@5": 95.240, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 83590140, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 82.886, - "acc@5": 96.328, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_32GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 145046770, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", - "acc@1": 80.878, - "acc@5": 95.340, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 145046770, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 83.368, - "acc@5": 96.498, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_Y_128GF_Weights(WeightsEnum): - # weights are not available yet. 
- pass - - -class RegNet_X_400MF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 5495976, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", - "acc@1": 72.834, - "acc@5": 90.950, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 5495976, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", - "acc@1": 74.864, - "acc@5": 92.322, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_X_800MF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 7259656, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", - "acc@1": 75.212, - "acc@5": 92.348, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 7259656, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", - "acc@1": 77.522, - "acc@5": 93.826, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_X_1_6GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 9190136, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models", - "acc@1": 77.040, - "acc@5": 93.440, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 9190136, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", - "acc@1": 79.668, - "acc@5": 94.922, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_X_3_2GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 15296552, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", - "acc@1": 78.364, - "acc@5": 93.992, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 15296552, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 81.196, - "acc@5": 95.430, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_X_8GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 39572648, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", - "acc@1": 79.344, 
- "acc@5": 94.686, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 39572648, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 81.682, - "acc@5": 95.678, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_X_16GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 54278536, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models", - "acc@1": 80.058, - "acc@5": 94.944, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 54278536, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 82.716, - "acc@5": 96.196, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class RegNet_X_32GF_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 107811560, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models", - "acc@1": 80.622, - "acc@5": 95.248, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "num_params": 107811560, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 83.014, - "acc@5": 96.288, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.IMAGENET1K_V1)) -def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_400MF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.IMAGENET1K_V1)) -def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_800MF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.IMAGENET1K_V1)) -def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_1_6GF_Weights.verify(weights) - - params = BlockParams.from_init_params( - depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs - ) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.IMAGENET1K_V1)) -def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = 
RegNet_Y_3_2GF_Weights.verify(weights) - - params = BlockParams.from_init_params( - depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs - ) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.IMAGENET1K_V1)) -def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_8GF_Weights.verify(weights) - - params = BlockParams.from_init_params( - depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs - ) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.IMAGENET1K_V1)) -def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_16GF_Weights.verify(weights) - - params = BlockParams.from_init_params( - depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs - ) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.IMAGENET1K_V1)) -def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_32GF_Weights.verify(weights) - - params = BlockParams.from_init_params( - depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs - ) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", None)) -def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_Y_128GF_Weights.verify(weights) - - params = BlockParams.from_init_params( - depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs - ) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.IMAGENET1K_V1)) -def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_400MF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.IMAGENET1K_V1)) -def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_800MF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.IMAGENET1K_V1)) -def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_1_6GF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.IMAGENET1K_V1)) -def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_3_2GF_Weights.verify(weights) - - params = 
BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.IMAGENET1K_V1)) -def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_8GF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.IMAGENET1K_V1)) -def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_16GF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs) - return _regnet(params, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.IMAGENET1K_V1)) -def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: - weights = RegNet_X_32GF_Weights.verify(weights) - - params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) - return _regnet(params, weights, progress, **kwargs) From 3788fb88fd824966223f12bb3aef0f5567029611 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 11 Mar 2022 16:54:31 +0000 Subject: [PATCH 08/45] Porting resnet --- torchvision/models/resnet.py | 395 +++++++++++++++++++++---- torchvision/prototype/models/resnet.py | 381 ------------------------ 2 files changed, 343 insertions(+), 433 deletions(-) delete mode 100644 torchvision/prototype/models/resnet.py diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index b0bb8d13ade..6391b0a0705 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -1,15 +1,28 @@ +from functools import partial from typing import Type, Any, Callable, Union, List, Optional import torch import torch.nn as nn from torch import Tensor -from .._internally_replaced_utils import load_state_dict_from_url from ..utils import _log_api_usage_once +from ..transforms import ImageClassificationEval, InterpolationMode +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param __all__ = [ "ResNet", + "ResNet18_Weights", + "ResNet34_Weights", + "ResNet50_Weights", + "ResNet101_Weights", + "ResNet152_Weights", + "ResNeXt50_32X4D_Weights", + "ResNeXt101_32X8D_Weights", + "Wide_ResNet50_2_Weights", + "Wide_ResNet101_2_Weights", "resnet18", "resnet34", "resnet50", @@ -22,19 +35,6 @@ ] -model_urls = { - "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth", - "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth", - "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth", - "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth", - "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth", - "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", - "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", - "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", - "wide_resnet101_2": 
"https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", -} - - def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: """3x3 convolution with padding""" return nn.Conv2d( @@ -284,102 +284,386 @@ def forward(self, x: Tensor) -> Tensor: def _resnet( - arch: str, block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> ResNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = ResNet(block, layers, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model -def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: +_COMMON_META = { + "task": "image_classification", + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class ResNet18_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet18-f37072fd.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 11689512, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "acc@1": 69.758, + "acc@5": 89.078, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ResNet34_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet34-b627a593.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 21797672, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "acc@1": 73.314, + "acc@5": 91.420, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ResNet50_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet50-0676ba61.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 25557032, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "acc@1": 76.130, + "acc@5": 92.862, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 25557032, + "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621", + "acc@1": 80.858, + "acc@5": 95.434, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNet101_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet101-63fe2227.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 44549160, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "acc@1": 77.374, + "acc@5": 93.546, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", + 
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 44549160, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 81.886, + "acc@5": 95.780, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNet152_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet152-394f9c45.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 60192808, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "acc@1": 78.312, + "acc@5": 94.046, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnet152-f82ba261.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 60192808, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 82.284, + "acc@5": 96.002, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNeXt50_32X4D_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNeXt", + "publication_year": 2016, + "num_params": 25028904, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext", + "acc@1": 77.618, + "acc@5": 93.698, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNeXt", + "publication_year": 2016, + "num_params": 25028904, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 81.198, + "acc@5": 95.340, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNeXt101_32X8D_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNeXt", + "publication_year": 2016, + "num_params": 88791336, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext", + "acc@1": 79.312, + "acc@5": 94.526, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNeXt", + "publication_year": 2016, + "num_params": 88791336, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "acc@1": 82.834, + "acc@5": 96.228, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class Wide_ResNet50_2_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "WideResNet", + "publication_year": 2016, + "num_params": 68883240, + "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439", + "acc@1": 78.468, + "acc@5": 94.086, + }, + ) + IMAGENET1K_V2 = Weights( + 
url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "WideResNet", + "publication_year": 2016, + "num_params": 68883240, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "acc@1": 81.602, + "acc@5": 95.758, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class Wide_ResNet101_2_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "WideResNet", + "publication_year": 2016, + "num_params": 126886696, + "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439", + "acc@1": 78.848, + "acc@5": 94.284, + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "WideResNet", + "publication_year": 2016, + "num_params": 126886696, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "acc@1": 82.510, + "acc@5": 96.020, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1)) +def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet18_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) + weights = ResNet18_Weights.verify(weights) + + return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) -def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: +@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1)) +def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-34 model from `"Deep Residual Learning for Image Recognition" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet34_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) + weights = ResNet34_Weights.verify(weights) + return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) -def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + +@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1)) +def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet50_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) + weights = ResNet50_Weights.verify(weights) + + return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) -def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: +@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1)) +def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-101 model from `"Deep Residual Learning for Image Recognition" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet101_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) + weights = ResNet101_Weights.verify(weights) + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) -def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + +@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1)) +def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-152 model from `"Deep Residual Learning for Image Recognition" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNet152_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs) + weights = ResNet152_Weights.verify(weights) + + return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs) -def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: +@handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1)) +def resnext50_32x4d( + *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: r"""ResNeXt-50 32x4d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNeXt50_32X4D_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs["groups"] = 32 - kwargs["width_per_group"] = 4 - return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) + weights = ResNeXt50_32X4D_Weights.verify(weights) + _ovewrite_named_param(kwargs, "groups", 32) + _ovewrite_named_param(kwargs, "width_per_group", 4) + return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) -def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + +@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1)) +def resnext101_32x8d( + *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: r"""ResNeXt-101 32x8d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ResNeXt101_32X8D_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs["groups"] = 32 - kwargs["width_per_group"] = 8 - return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) + weights = ResNeXt101_32X8D_Weights.verify(weights) + _ovewrite_named_param(kwargs, "groups", 32) + _ovewrite_named_param(kwargs, "width_per_group", 8) + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) -def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + +@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1)) +def wide_resnet50_2( + *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: r"""Wide ResNet-50-2 model from `"Wide Residual Networks" `_. @@ -389,14 +673,19 @@ def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: A channels, and in Wide ResNet-50-2 has 2048-1024-2048. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (Wide_ResNet50_2_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs["width_per_group"] = 64 * 2 - return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) + weights = Wide_ResNet50_2_Weights.verify(weights) + + _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) + return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) -def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: +@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1)) +def wide_resnet101_2( + *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: r"""Wide ResNet-101-2 model from `"Wide Residual Networks" `_. @@ -406,8 +695,10 @@ def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
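
A minimal sketch of how the reworked ResNeXt/Wide-ResNet builders are meant to be used, again assuming the usual torchvision.models re-exports; the assert only illustrates that the classifier head follows the weights metadata:

    from torchvision.models import resnext50_32x4d, ResNeXt50_32X4D_Weights

    # DEFAULT aliases the best entry defined above (IMAGENET1K_V2 for this architecture).
    weights = ResNeXt50_32X4D_Weights.DEFAULT
    model = resnext50_32x4d(weights=weights)

    # groups/width_per_group are pinned by the builder via _ovewrite_named_param, and
    # num_classes is derived from len(weights.meta["categories"]) inside _resnet, so the
    # head size matches the checkpoint without being passed explicitly.
    assert model.fc.out_features == len(weights.meta["categories"])
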
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (Wide_ResNet101_2_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs["width_per_group"] = 64 * 2 - return _resnet("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) + weights = Wide_ResNet101_2_Weights.verify(weights) + + _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) diff --git a/torchvision/prototype/models/resnet.py b/torchvision/prototype/models/resnet.py deleted file mode 100644 index d96eac5d1bb..00000000000 --- a/torchvision/prototype/models/resnet.py +++ /dev/null @@ -1,381 +0,0 @@ -from functools import partial -from typing import Any, List, Optional, Type, Union - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.resnet import BasicBlock, Bottleneck, ResNet -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "ResNet", - "ResNet18_Weights", - "ResNet34_Weights", - "ResNet50_Weights", - "ResNet101_Weights", - "ResNet152_Weights", - "ResNeXt50_32X4D_Weights", - "ResNeXt101_32X8D_Weights", - "Wide_ResNet50_2_Weights", - "Wide_ResNet101_2_Weights", - "resnet18", - "resnet34", - "resnet50", - "resnet101", - "resnet152", - "resnext50_32x4d", - "resnext101_32x8d", - "wide_resnet50_2", - "wide_resnet101_2", -] - - -def _resnet( - block: Type[Union[BasicBlock, Bottleneck]], - layers: List[int], - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> ResNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = ResNet(block, layers, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -class ResNet18_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnet18-f37072fd.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 11689512, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", - "acc@1": 69.758, - "acc@5": 89.078, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ResNet34_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnet34-b627a593.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 21797672, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", - "acc@1": 73.314, - "acc@5": 91.420, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ResNet50_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnet50-0676ba61.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 25557032, - 
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", - "acc@1": 76.130, - "acc@5": 92.862, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 25557032, - "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621", - "acc@1": 80.858, - "acc@5": 95.434, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class ResNet101_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnet101-63fe2227.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 44549160, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", - "acc@1": 77.374, - "acc@5": 93.546, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 44549160, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 81.886, - "acc@5": 95.780, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class ResNet152_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnet152-394f9c45.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 60192808, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", - "acc@1": 78.312, - "acc@5": 94.046, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/resnet152-f82ba261.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 60192808, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 82.284, - "acc@5": 96.002, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class ResNeXt50_32X4D_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNeXt", - "publication_year": 2016, - "num_params": 25028904, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext", - "acc@1": 77.618, - "acc@5": 93.698, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNeXt", - "publication_year": 2016, - "num_params": 25028904, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 81.198, - "acc@5": 95.340, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class ResNeXt101_32X8D_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNeXt", - "publication_year": 2016, - "num_params": 88791336, - 
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext", - "acc@1": 79.312, - "acc@5": 94.526, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNeXt", - "publication_year": 2016, - "num_params": 88791336, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", - "acc@1": 82.834, - "acc@5": 96.228, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class Wide_ResNet50_2_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "WideResNet", - "publication_year": 2016, - "num_params": 68883240, - "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439", - "acc@1": 78.468, - "acc@5": 94.086, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "WideResNet", - "publication_year": 2016, - "num_params": 68883240, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", - "acc@1": 81.602, - "acc@5": 95.758, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -class Wide_ResNet101_2_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "WideResNet", - "publication_year": 2016, - "num_params": 126886696, - "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439", - "acc@1": 78.848, - "acc@5": 94.284, - }, - ) - IMAGENET1K_V2 = Weights( - url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "WideResNet", - "publication_year": 2016, - "num_params": 126886696, - "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", - "acc@1": 82.510, - "acc@5": 96.020, - }, - ) - DEFAULT = IMAGENET1K_V2 - - -@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1)) -def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - weights = ResNet18_Weights.verify(weights) - - return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1)) -def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - weights = ResNet34_Weights.verify(weights) - - return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1)) -def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - weights = ResNet50_Weights.verify(weights) - - return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1)) -def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - weights = 
ResNet101_Weights.verify(weights) - - return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1)) -def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: - weights = ResNet152_Weights.verify(weights) - - return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1)) -def resnext50_32x4d( - *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any -) -> ResNet: - weights = ResNeXt50_32X4D_Weights.verify(weights) - - _ovewrite_named_param(kwargs, "groups", 32) - _ovewrite_named_param(kwargs, "width_per_group", 4) - return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1)) -def resnext101_32x8d( - *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any -) -> ResNet: - weights = ResNeXt101_32X8D_Weights.verify(weights) - - _ovewrite_named_param(kwargs, "groups", 32) - _ovewrite_named_param(kwargs, "width_per_group", 8) - return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1)) -def wide_resnet50_2( - *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any -) -> ResNet: - weights = Wide_ResNet50_2_Weights.verify(weights) - - _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) - return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1)) -def wide_resnet101_2( - *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any -) -> ResNet: - weights = Wide_ResNet101_2_Weights.verify(weights) - - _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) - return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) From 0eabb0c7cafca92fd0c3958daea3843d921a3026 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 11 Mar 2022 17:07:52 +0000 Subject: [PATCH 09/45] Porting shufflenetv2 --- torchvision/models/shufflenetv2.py | 139 +++++++++++++++---- torchvision/prototype/models/shufflenetv2.py | 124 ----------------- 2 files changed, 110 insertions(+), 153 deletions(-) delete mode 100644 torchvision/prototype/models/shufflenetv2.py diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index f3758c54aaf..8aa20f7c3fc 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -1,21 +1,30 @@ -from typing import Callable, Any, List +from functools import partial +from typing import Callable, Any, List, Optional import torch import torch.nn as nn from torch import Tensor -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["ShuffleNetV2", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", "shufflenet_v2_x1_5", "shufflenet_v2_x2_0"] -model_urls = { - "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", - 
"shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", - "shufflenetv2_x1.5": None, - "shufflenetv2_x2.0": None, -} +__all__ = [ + "ShuffleNetV2", + "ShuffleNet_V2_X0_5_Weights", + "ShuffleNet_V2_X1_0_Weights", + "ShuffleNet_V2_X1_5_Weights", + "ShuffleNet_V2_X2_0_Weights", + "shufflenet_v2_x0_5", + "shufflenet_v2_x1_0", + "shufflenet_v2_x1_5", + "shufflenet_v2_x2_0", +] + def channel_shuffle(x: Tensor, groups: int) -> Tensor: @@ -156,67 +165,139 @@ def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any) -> ShuffleNetV2: +def _shufflenetv2( + weights: Optional[WeightsEnum], + progress: bool, + *args: Any, + **kwargs: Any, +) -> ShuffleNetV2: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = ShuffleNetV2(*args, **kwargs) - if pretrained: - model_url = model_urls[arch] - if model_url is None: - raise ValueError(f"No checkpoint is available for model type {arch}") - else: - state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) return model -def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: +_COMMON_META = { + "task": "image_classification", + "architecture": "ShuffleNetV2", + "publication_year": 2018, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/barrh/Shufflenet-v2-Pytorch/tree/v0.1.0", +} + + +class ShuffleNet_V2_X0_5_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 1366792, + "acc@1": 69.362, + "acc@5": 88.316, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ShuffleNet_V2_X1_0_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 2278604, + "acc@1": 60.552, + "acc@5": 81.746, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ShuffleNet_V2_X1_5_Weights(WeightsEnum): + pass + + +class ShuffleNet_V2_X2_0_Weights(WeightsEnum): + pass + + + +@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1)) +def shufflenet_v2_x0_5( + *, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: """ Constructs a ShuffleNetV2 with 0.5x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ShuffleNet_V2_X0_5_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2("shufflenetv2_x0.5", pretrained, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) + weights = ShuffleNet_V2_X0_5_Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) -def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: +@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1)) +def shufflenet_v2_x1_0( + *, weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: """ Constructs a ShuffleNetV2 with 1.0x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ShuffleNet_V2_X1_0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2("shufflenetv2_x1.0", pretrained, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) + weights = ShuffleNet_V2_X1_0_Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) -def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: +@handle_legacy_interface(weights=("pretrained", None)) +def shufflenet_v2_x1_5( + *, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: """ Constructs a ShuffleNetV2 with 1.5x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ShuffleNet_V2_X1_5_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2("shufflenetv2_x1.5", pretrained, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) + weights = ShuffleNet_V2_X1_5_Weights.verify(weights) + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) -def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: + +@handle_legacy_interface(weights=("pretrained", None)) +def shufflenet_v2_x2_0( + *, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any +) -> ShuffleNetV2: """ Constructs a ShuffleNetV2 with 2.0x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `_. 
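For orientation, a minimal usage sketch of the builders ported above. It assumes the patch has been applied, that the names are importable from torchvision.models.shufflenetv2 exactly as defined here, and that the ImageClassificationEval preset returned by weights.transforms() accepts plain CHW tensors; none of that is shown in this diff:

    import torch
    from torchvision.models.shufflenetv2 import shufflenet_v2_x1_0, ShuffleNet_V2_X1_0_Weights

    # New-style call: the enum value bundles the checkpoint URL, the eval preset and the meta.
    weights = ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1
    model = shufflenet_v2_x1_0(weights=weights).eval()

    # Preprocessing comes from the weights themselves, so it always matches the checkpoint.
    preprocess = weights.transforms()
    batch = preprocess(torch.rand(3, 256, 256)).unsqueeze(0)

    with torch.no_grad():
        probs = model(batch).softmax(dim=1)
    print(weights.meta["categories"][probs.argmax().item()])

    # The legacy call is still accepted through @handle_legacy_interface and resolves to the
    # same IMAGENET1K_V1 weights (with a deprecation warning rather than a hard failure).
    legacy = shufflenet_v2_x1_0(pretrained=True)
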
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ShuffleNet_V2_X2_0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2("shufflenetv2_x2.0", pretrained, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) + weights = ShuffleNet_V2_X2_0_Weights.verify(weights) + + return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py deleted file mode 100644 index 617603759bd..00000000000 --- a/torchvision/prototype/models/shufflenetv2.py +++ /dev/null @@ -1,124 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.shufflenetv2 import ShuffleNetV2 -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "ShuffleNetV2", - "ShuffleNet_V2_X0_5_Weights", - "ShuffleNet_V2_X1_0_Weights", - "ShuffleNet_V2_X1_5_Weights", - "ShuffleNet_V2_X2_0_Weights", - "shufflenet_v2_x0_5", - "shufflenet_v2_x1_0", - "shufflenet_v2_x1_5", - "shufflenet_v2_x2_0", -] - - -def _shufflenetv2( - weights: Optional[WeightsEnum], - progress: bool, - *args: Any, - **kwargs: Any, -) -> ShuffleNetV2: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = ShuffleNetV2(*args, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "architecture": "ShuffleNetV2", - "publication_year": 2018, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/barrh/Shufflenet-v2-Pytorch/tree/v0.1.0", -} - - -class ShuffleNet_V2_X0_5_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 1366792, - "acc@1": 69.362, - "acc@5": 88.316, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ShuffleNet_V2_X1_0_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 2278604, - "acc@1": 60.552, - "acc@5": 81.746, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ShuffleNet_V2_X1_5_Weights(WeightsEnum): - pass - - -class ShuffleNet_V2_X2_0_Weights(WeightsEnum): - pass - - -@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1)) -def shufflenet_v2_x0_5( - *, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - weights = ShuffleNet_V2_X0_5_Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) - - -@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1)) -def shufflenet_v2_x1_0( - *, weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, 
**kwargs: Any -) -> ShuffleNetV2: - weights = ShuffleNet_V2_X1_0_Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) - - -@handle_legacy_interface(weights=("pretrained", None)) -def shufflenet_v2_x1_5( - *, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - weights = ShuffleNet_V2_X1_5_Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) - - -@handle_legacy_interface(weights=("pretrained", None)) -def shufflenet_v2_x2_0( - *, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any -) -> ShuffleNetV2: - weights = ShuffleNet_V2_X2_0_Weights.verify(weights) - - return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) From 8c45c1a8e6187a14dac4e0becb07f9efb6f04291 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 10:10:58 +0000 Subject: [PATCH 10/45] Porting squeezenet --- torchvision/models/squeezenet.py | 94 +++++++++++++++++----- torchvision/prototype/models/squeezenet.py | 88 -------------------- 2 files changed, 76 insertions(+), 106 deletions(-) delete mode 100644 torchvision/prototype/models/squeezenet.py diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index 2c1a30f225d..5d1d46f2856 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -1,18 +1,19 @@ -from typing import Any +from functools import partial +from typing import Any, Optional import torch import torch.nn as nn import torch.nn.init as init -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once -__all__ = ["SqueezeNet", "squeezenet1_0", "squeezenet1_1"] +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param -model_urls = { - "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", - "squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", -} + +__all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"] class Fire(nn.Module): @@ -97,29 +98,85 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.flatten(x, 1) -def _squeezenet(version: str, pretrained: bool, progress: bool, **kwargs: Any) -> SqueezeNet: +def _squeezenet( + version: str, + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> SqueezeNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = SqueezeNet(version, **kwargs) - if pretrained: - arch = "squeezenet" + version - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model -def squeezenet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet: +_COMMON_META = { + "task": "image_classification", + "architecture": "SqueezeNet", + "publication_year": 2016, + "size": (224, 224), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717", +} + + +class 
SqueezeNet1_0_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "min_size": (21, 21), + "num_params": 1248424, + "acc@1": 58.092, + "acc@5": 80.420, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class SqueezeNet1_1_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "min_size": (17, 17), + "num_params": 1235496, + "acc@1": 58.178, + "acc@5": 80.624, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", SqueezeNet1_0_Weights.IMAGENET1K_V1)) +def squeezenet1_0( + *, weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any +) -> SqueezeNet: r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size" `_ paper. The required minimum input size of the model is 21x21. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (SqueezeNet1_0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _squeezenet("1_0", pretrained, progress, **kwargs) + weights = SqueezeNet1_0_Weights.verify(weights) + return _squeezenet("1_0", weights, progress, **kwargs) -def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet: +@handle_legacy_interface(weights=("pretrained", SqueezeNet1_1_Weights.IMAGENET1K_V1)) +def squeezenet1_1( + *, weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any +) -> SqueezeNet: r"""SqueezeNet 1.1 model from the `official SqueezeNet repo `_. SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters @@ -127,7 +184,8 @@ def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any The required minimum input size of the model is 17x17. 
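Because _squeezenet pins num_classes to the length of the categories stored in the weights meta, the pretrained checkpoints always come with the 1000-way ImageNet head; a different label space only makes sense without weights. A short sketch of both calls (the 10-class figure is just an example, not something this patch prescribes):

    from torchvision.models.squeezenet import squeezenet1_1, SqueezeNet1_1_Weights

    # num_classes is derived from weights.meta["categories"]; passing a conflicting
    # value alongside the weights would be rejected by _ovewrite_named_param.
    pretrained = squeezenet1_1(weights=SqueezeNet1_1_Weights.IMAGENET1K_V1)

    # For a custom label space, build without weights (and, if needed, copy the
    # pretrained feature weights over manually afterwards).
    scratch = squeezenet1_1(weights=None, num_classes=10)
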
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (SqueezeNet1_1_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _squeezenet("1_1", pretrained, progress, **kwargs) + weights = SqueezeNet1_1_Weights.verify(weights) + return _squeezenet("1_1", weights, progress, **kwargs) diff --git a/torchvision/prototype/models/squeezenet.py b/torchvision/prototype/models/squeezenet.py deleted file mode 100644 index b3c8817e7be..00000000000 --- a/torchvision/prototype/models/squeezenet.py +++ /dev/null @@ -1,88 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.squeezenet import SqueezeNet -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"] - - -_COMMON_META = { - "task": "image_classification", - "architecture": "SqueezeNet", - "publication_year": 2016, - "size": (224, 224), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717", -} - - -class SqueezeNet1_0_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "min_size": (21, 21), - "num_params": 1248424, - "acc@1": 58.092, - "acc@5": 80.420, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class SqueezeNet1_1_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "min_size": (17, 17), - "num_params": 1235496, - "acc@1": 58.178, - "acc@5": 80.624, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", SqueezeNet1_0_Weights.IMAGENET1K_V1)) -def squeezenet1_0( - *, weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any -) -> SqueezeNet: - weights = SqueezeNet1_0_Weights.verify(weights) - - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = SqueezeNet("1_0", **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface(weights=("pretrained", SqueezeNet1_1_Weights.IMAGENET1K_V1)) -def squeezenet1_1( - *, weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any -) -> SqueezeNet: - weights = SqueezeNet1_1_Weights.verify(weights) - - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = SqueezeNet("1_1", **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model From b7916d329470f36eb4bbf4495023fb4fb0877a49 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 10:18:30 +0000 Subject: [PATCH 11/45] Porting vgg --- torchvision/models/vgg.py | 258 +++++++++++++++++++++++----- torchvision/prototype/models/vgg.py | 240 
-------------------------- 2 files changed, 213 insertions(+), 285 deletions(-) delete mode 100644 torchvision/prototype/models/vgg.py diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index 07639017a31..5007250f519 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -1,37 +1,38 @@ -from typing import Union, List, Dict, Any, cast +from functools import partial +from typing import Union, List, Dict, Any, Optional, cast import torch import torch.nn as nn -from .._internally_replaced_utils import load_state_dict_from_url +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param + __all__ = [ "VGG", + "VGG11_Weights", + "VGG11_BN_Weights", + "VGG13_Weights", + "VGG13_BN_Weights", + "VGG16_Weights", + "VGG16_BN_Weights", + "VGG19_Weights", + "VGG19_BN_Weights", "vgg11", "vgg11_bn", "vgg13", "vgg13_bn", "vgg16", "vgg16_bn", - "vgg19_bn", "vgg19", + "vgg19_bn", ] -model_urls = { - "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth", - "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth", - "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth", - "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", - "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth", - "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", - "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", - "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", -} - - class VGG(nn.Module): def __init__( self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5 @@ -95,107 +96,274 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ } -def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG: - if pretrained: - kwargs["init_weights"] = False +def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) return model -def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: +_COMMON_META = { + "task": "image_classification", + "architecture": "VGG", + "publication_year": 2014, + "size": (224, 224), + "min_size": (32, 32), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg", +} + + +class VGG11_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg11-8a719046.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 132863336, + "acc@1": 69.020, + "acc@5": 88.628, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG11_BN_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth", + transforms=partial(ImageClassificationEval, 
crop_size=224), + meta={ + **_COMMON_META, + "num_params": 132868840, + "acc@1": 70.370, + "acc@5": 89.810, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG13_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg13-19584684.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 133047848, + "acc@1": 69.928, + "acc@5": 89.246, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG13_BN_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 133053736, + "acc@1": 71.586, + "acc@5": 90.374, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG16_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg16-397923af.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 138357544, + "acc@1": 71.592, + "acc@5": 90.382, + }, + ) + # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the + # same input standardization method as the paper. Only the `features` weights have proper values, those on the + # `classifier` module are filled with nans. + IMAGENET1K_FEATURES = Weights( + url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth", + transforms=partial( + ImageClassificationEval, + crop_size=224, + mean=(0.48235, 0.45882, 0.40784), + std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0), + ), + meta={ + **_COMMON_META, + "num_params": 138357544, + "categories": None, + "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd", + "acc@1": float("nan"), + "acc@5": float("nan"), + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG16_BN_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 138365992, + "acc@1": 73.360, + "acc@5": 91.516, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG19_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 143667240, + "acc@1": 72.376, + "acc@5": 90.876, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG19_BN_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 143678248, + "acc@1": 74.218, + "acc@5": 91.842, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1)) +def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 11-layer model (configuration "A") from `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG11_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg11", "A", False, pretrained, progress, **kwargs) + weights = VGG11_Weights.verify(weights) + return _vgg("A", False, weights, progress, **kwargs) -def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: + +@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1)) +def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 11-layer model (configuration "A") with batch normalization `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG11_BN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg11_bn", "A", True, pretrained, progress, **kwargs) + weights = VGG11_BN_Weights.verify(weights) + + return _vgg("A", True, weights, progress, **kwargs) -def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: +@handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1)) +def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 13-layer model (configuration "B") `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG13_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg13", "B", False, pretrained, progress, **kwargs) + weights = VGG13_Weights.verify(weights) + return _vgg("B", False, weights, progress, **kwargs) -def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: + +@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1)) +def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 13-layer model (configuration "B") with batch normalization `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG13_BN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg13_bn", "B", True, pretrained, progress, **kwargs) + weights = VGG13_BN_Weights.verify(weights) + + return _vgg("B", True, weights, progress, **kwargs) -def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: +@handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1)) +def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 16-layer model (configuration "D") `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG16_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg16", "D", False, pretrained, progress, **kwargs) + weights = VGG16_Weights.verify(weights) + return _vgg("D", False, weights, progress, **kwargs) -def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: + +@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1)) +def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 16-layer model (configuration "D") with batch normalization `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG16_BN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg16_bn", "D", True, pretrained, progress, **kwargs) + weights = VGG16_BN_Weights.verify(weights) + + return _vgg("D", True, weights, progress, **kwargs) -def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: +@handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1)) +def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 19-layer model (configuration "E") `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG19_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg19", "E", False, pretrained, progress, **kwargs) + weights = VGG19_Weights.verify(weights) + return _vgg("E", False, weights, progress, **kwargs) -def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: + +@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1)) +def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: r"""VGG 19-layer model (configuration 'E') with batch normalization `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_. The required minimum input size of the model is 32x32. 
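A note on the IMAGENET1K_FEATURES entry defined further up: it is meant purely as a backbone. Its classifier tensors are NaN, its meta sets "categories" to None (so deriving num_classes from it is not useful), and its preset applies the paper-style normalization (mean subtraction with a std of 1/255). A hedged sketch of using it as a frozen feature extractor; loading the raw state dict into a default-shaped VGG mirrors what the builder does internally:

    import torch
    from torchvision.models.vgg import vgg16, VGG16_Weights

    weights = VGG16_Weights.IMAGENET1K_FEATURES
    model = vgg16(weights=None)                                    # default 1000-class skeleton
    model.load_state_dict(weights.get_state_dict(progress=True))   # classifier weights are NaN by design
    backbone = model.features.eval()                               # keep only the conv trunk

    preprocess = weights.transforms()                              # amdegroot mean/std from the enum
    feats = backbone(preprocess(torch.rand(3, 300, 300)).unsqueeze(0))
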
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (VGG19_BN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg("vgg19_bn", "E", True, pretrained, progress, **kwargs) + weights = VGG19_BN_Weights.verify(weights) + + return _vgg("E", True, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/vgg.py b/torchvision/prototype/models/vgg.py deleted file mode 100644 index 90a6b06ad39..00000000000 --- a/torchvision/prototype/models/vgg.py +++ /dev/null @@ -1,240 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.vgg import VGG, make_layers, cfgs -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - - -__all__ = [ - "VGG", - "VGG11_Weights", - "VGG11_BN_Weights", - "VGG13_Weights", - "VGG13_BN_Weights", - "VGG16_Weights", - "VGG16_BN_Weights", - "VGG19_Weights", - "VGG19_BN_Weights", - "vgg11", - "vgg11_bn", - "vgg13", - "vgg13_bn", - "vgg16", - "vgg16_bn", - "vgg19", - "vgg19_bn", -] - - -def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - return model - - -_COMMON_META = { - "task": "image_classification", - "architecture": "VGG", - "publication_year": 2014, - "size": (224, 224), - "min_size": (32, 32), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg", -} - - -class VGG11_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg11-8a719046.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 132863336, - "acc@1": 69.020, - "acc@5": 88.628, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG11_BN_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 132868840, - "acc@1": 70.370, - "acc@5": 89.810, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG13_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg13-19584684.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 133047848, - "acc@1": 69.928, - "acc@5": 89.246, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG13_BN_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 133053736, - "acc@1": 71.586, - "acc@5": 90.374, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG16_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg16-397923af.pth", - 
transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 138357544, - "acc@1": 71.592, - "acc@5": 90.382, - }, - ) - # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the - # same input standardization method as the paper. Only the `features` weights have proper values, those on the - # `classifier` module are filled with nans. - IMAGENET1K_FEATURES = Weights( - url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth", - transforms=partial( - ImageClassificationEval, - crop_size=224, - mean=(0.48235, 0.45882, 0.40784), - std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0), - ), - meta={ - **_COMMON_META, - "num_params": 138357544, - "categories": None, - "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd", - "acc@1": float("nan"), - "acc@5": float("nan"), - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG16_BN_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 138365992, - "acc@1": 73.360, - "acc@5": 91.516, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG19_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 143667240, - "acc@1": 72.376, - "acc@5": 90.876, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class VGG19_BN_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 143678248, - "acc@1": 74.218, - "acc@5": 91.842, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -@handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1)) -def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG11_Weights.verify(weights) - - return _vgg("A", False, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1)) -def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG11_BN_Weights.verify(weights) - - return _vgg("A", True, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1)) -def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG13_Weights.verify(weights) - - return _vgg("B", False, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1)) -def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG13_BN_Weights.verify(weights) - - return _vgg("B", True, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1)) -def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG16_Weights.verify(weights) - - return _vgg("D", False, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1)) -def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = 
VGG16_BN_Weights.verify(weights) - - return _vgg("D", True, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1)) -def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG19_Weights.verify(weights) - - return _vgg("E", False, weights, progress, **kwargs) - - -@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1)) -def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: - weights = VGG19_BN_Weights.verify(weights) - - return _vgg("E", True, weights, progress, **kwargs) From a7fabb34d102fe233e729ac111e1ab5c37c2c7c7 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 10:24:38 +0000 Subject: [PATCH 12/45] Porting vit --- torchvision/models/vision_transformer.py | 148 ++++++++++--- .../prototype/models/vision_transformer.py | 198 ------------------ 2 files changed, 117 insertions(+), 229 deletions(-) delete mode 100644 torchvision/prototype/models/vision_transformer.py diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 29f756ccbe5..f1d916f88fa 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -6,25 +6,27 @@ import torch import torch.nn as nn -from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import Conv2dNormActivation +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once +from ._api import WeightsEnum, Weights +from ._meta import _IMAGENET_CATEGORIES +from ._utils import handle_legacy_interface, _ovewrite_named_param + + __all__ = [ "VisionTransformer", + "ViT_B_16_Weights", + "ViT_B_32_Weights", + "ViT_L_16_Weights", + "ViT_L_32_Weights", "vit_b_16", "vit_b_32", "vit_l_16", "vit_l_32", ] -model_urls = { - "vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth", - "vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", - "vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", - "vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth", -} - class ConvStemConfig(NamedTuple): out_channels: int @@ -274,18 +276,20 @@ def forward(self, x: torch.Tensor): def _vision_transformer( - arch: str, patch_size: int, num_layers: int, num_heads: int, hidden_dim: int, mlp_dim: int, - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> VisionTransformer: image_size = kwargs.pop("image_size", 224) + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = VisionTransformer( image_size=image_size, patch_size=patch_size, @@ -296,98 +300,180 @@ def _vision_transformer( **kwargs, ) - if pretrained: - if arch not in model_urls: - raise ValueError(f"No checkpoint is available for model type '{arch}'!") - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) + if weights: + model.load_state_dict(weights.get_state_dict(progress=progress)) return model -def vit_b_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: +_COMMON_META = { + "task": "image_classification", + "architecture": "ViT", + "publication_year": 2020, + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class ViT_B_16_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + 
url="https://download.pytorch.org/models/vit_b_16-c867db91.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 86567656, + "size": (224, 224), + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_16", + "acc@1": 81.072, + "acc@5": 95.318, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ViT_B_32_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 88224232, + "size": (224, 224), + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_32", + "acc@1": 75.912, + "acc@5": 92.466, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ViT_L_16_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=242), + meta={ + **_COMMON_META, + "num_params": 304326632, + "size": (224, 224), + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_16", + "acc@1": 79.662, + "acc@5": 94.638, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ViT_L_32_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vit_l_32-c7638314.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 306535400, + "size": (224, 224), + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_32", + "acc@1": 76.972, + "acc@5": 93.07, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1)) +def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_b_16 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ViT_B_16_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ViT_B_16_Weights.verify(weights) + return _vision_transformer( - arch="vit_b_16", patch_size=16, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, - pretrained=pretrained, + weights=weights, progress=progress, **kwargs, ) -def vit_b_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: +@handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.IMAGENET1K_V1)) +def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_b_32 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ViT_B_32_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ViT_B_32_Weights.verify(weights) + return _vision_transformer( - arch="vit_b_32", patch_size=32, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, - pretrained=pretrained, + weights=weights, progress=progress, **kwargs, ) -def vit_l_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: +@handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.IMAGENET1K_V1)) +def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_l_16 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ViT_L_16_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ViT_L_16_Weights.verify(weights) + return _vision_transformer( - arch="vit_l_16", patch_size=16, num_layers=24, num_heads=16, hidden_dim=1024, mlp_dim=4096, - pretrained=pretrained, + weights=weights, progress=progress, **kwargs, ) -def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer: +@handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.IMAGENET1K_V1)) +def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_l_32 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (ViT_L_32_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr """ + weights = ViT_L_32_Weights.verify(weights) + return _vision_transformer( - arch="vit_l_32", patch_size=32, num_layers=24, num_heads=16, hidden_dim=1024, mlp_dim=4096, - pretrained=pretrained, + weights=weights, progress=progress, **kwargs, ) diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py deleted file mode 100644 index 3fc34b72be5..00000000000 --- a/torchvision/prototype/models/vision_transformer.py +++ /dev/null @@ -1,198 +0,0 @@ -# References: -# https://github.com/google-research/vision_transformer -# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/vision_transformer.py - -from functools import partial -from typing import Any, Optional - -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ...models.vision_transformer import VisionTransformer, interpolate_embeddings # noqa: F401 -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from ._utils import handle_legacy_interface, _ovewrite_named_param - -__all__ = [ - "VisionTransformer", - "ViT_B_16_Weights", - "ViT_B_32_Weights", - "ViT_L_16_Weights", - "ViT_L_32_Weights", - "vit_b_16", - "vit_b_32", - "vit_l_16", - "vit_l_32", -] - - -_COMMON_META = { - "task": "image_classification", - "architecture": "ViT", - "publication_year": 2020, - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -class ViT_B_16_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vit_b_16-c867db91.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 86567656, - "size": (224, 224), - "min_size": (224, 224), - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_16", - "acc@1": 81.072, - "acc@5": 95.318, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ViT_B_32_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 88224232, - "size": (224, 224), - "min_size": (224, 224), - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_32", - "acc@1": 75.912, - "acc@5": 92.466, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ViT_L_16_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=242), - meta={ - **_COMMON_META, - "num_params": 304326632, - "size": (224, 224), - "min_size": (224, 224), - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_16", - "acc@1": 79.662, - "acc@5": 94.638, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -class ViT_L_32_Weights(WeightsEnum): - IMAGENET1K_V1 = Weights( - url="https://download.pytorch.org/models/vit_l_32-c7638314.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 306535400, - "size": (224, 224), - "min_size": (224, 224), - "recipe": 
"https://github.com/pytorch/vision/tree/main/references/classification#vit_l_32", - "acc@1": 76.972, - "acc@5": 93.07, - }, - ) - DEFAULT = IMAGENET1K_V1 - - -def _vision_transformer( - patch_size: int, - num_layers: int, - num_heads: int, - hidden_dim: int, - mlp_dim: int, - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> VisionTransformer: - image_size = kwargs.pop("image_size", 224) - - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = VisionTransformer( - image_size=image_size, - patch_size=patch_size, - num_layers=num_layers, - num_heads=num_heads, - hidden_dim=hidden_dim, - mlp_dim=mlp_dim, - **kwargs, - ) - - if weights: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1)) -def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: - weights = ViT_B_16_Weights.verify(weights) - - return _vision_transformer( - patch_size=16, - num_layers=12, - num_heads=12, - hidden_dim=768, - mlp_dim=3072, - weights=weights, - progress=progress, - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.IMAGENET1K_V1)) -def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: - weights = ViT_B_32_Weights.verify(weights) - - return _vision_transformer( - patch_size=32, - num_layers=12, - num_heads=12, - hidden_dim=768, - mlp_dim=3072, - weights=weights, - progress=progress, - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.IMAGENET1K_V1)) -def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: - weights = ViT_L_16_Weights.verify(weights) - - return _vision_transformer( - patch_size=16, - num_layers=24, - num_heads=16, - hidden_dim=1024, - mlp_dim=4096, - weights=weights, - progress=progress, - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.IMAGENET1K_V1)) -def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: - weights = ViT_L_32_Weights.verify(weights) - - return _vision_transformer( - patch_size=32, - num_layers=24, - num_heads=16, - hidden_dim=1024, - mlp_dim=4096, - weights=weights, - progress=progress, - **kwargs, - ) From 76f017c74b5252f67f76cb631514a1fe8b8a51f4 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 10:25:43 +0000 Subject: [PATCH 13/45] Fix docstrings --- torchvision/models/densenet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index 4da690a979f..17a5301cade 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -340,7 +340,7 @@ def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool The required minimum input size of the model is 29x29. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (DenseNet121_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. 
@@ -357,7 +357,7 @@ def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool The required minimum input size of the model is 29x29. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (DenseNet161_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. @@ -374,7 +374,7 @@ def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool The required minimum input size of the model is 29x29. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (DenseNet169_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. @@ -391,7 +391,7 @@ def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool The required minimum input size of the model is 29x29. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + weights (DenseNet201_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. From eb5254d0fae7bbed004dfb13a31efddea723c3cd Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 12:48:34 +0000 Subject: [PATCH 14/45] Fixing imports --- test/test_prototype_models.py | 2 +- torchvision/_utils.py | 14 +++- torchvision/models/_utils.py | 57 ++++++++++++++- torchvision/models/alexnet.py | 5 +- torchvision/models/convnext.py | 4 +- torchvision/models/densenet.py | 6 +- torchvision/models/efficientnet.py | 1 - torchvision/models/googlenet.py | 7 +- torchvision/models/inception.py | 6 +- torchvision/models/mnasnet.py | 2 +- torchvision/models/mobilenetv2.py | 3 +- torchvision/models/mobilenetv3.py | 1 - torchvision/models/quantization/googlenet.py | 9 ++- torchvision/models/quantization/inception.py | 8 ++- .../models/quantization/mobilenetv2.py | 8 ++- .../models/quantization/mobilenetv3.py | 8 ++- torchvision/models/quantization/resnet.py | 9 ++- .../models/quantization/shufflenetv2.py | 9 ++- torchvision/models/regnet.py | 4 -- torchvision/models/resnet.py | 2 +- torchvision/models/shufflenetv2.py | 3 - torchvision/models/squeezenet.py | 1 - torchvision/models/vgg.py | 1 - torchvision/models/vision_transformer.py | 1 - .../prototype/datasets/utils/_dataset.py | 3 +- torchvision/prototype/models/__init__.py | 5 ++ .../prototype/models/detection/faster_rcnn.py | 16 ++--- .../prototype/models/detection/fcos.py | 14 ++-- .../models/detection/keypoint_rcnn.py | 14 ++-- .../prototype/models/detection/mask_rcnn.py | 14 ++-- .../prototype/models/detection/retinanet.py | 14 ++-- torchvision/prototype/models/detection/ssd.py | 14 ++-- .../prototype/models/detection/ssdlite.py | 14 ++-- .../prototype/models/optical_flow/raft.py | 10 ++- .../models/quantization/googlenet.py | 14 ++-- .../models/quantization/inception.py | 14 ++-- .../models/quantization/mobilenetv2.py | 14 ++-- .../models/quantization/mobilenetv3.py | 14 ++-- .../prototype/models/quantization/resnet.py | 14 ++-- .../models/quantization/shufflenetv2.py | 14 ++-- 
.../models/segmentation/deeplabv3.py | 13 ++-- .../prototype/models/segmentation/fcn.py | 10 ++- .../prototype/models/segmentation/lraspp.py | 10 ++- torchvision/prototype/models/video/resnet.py | 12 ++-- .../prototype/transforms/_auto_augment.py | 5 +- torchvision/prototype/transforms/_geometry.py | 4 +- .../transforms/functional/_geometry.py | 3 +- torchvision/prototype/utils/_internal.py | 70 +------------------ 48 files changed, 245 insertions(+), 255 deletions(-) create mode 100644 torchvision/prototype/models/__init__.py diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 6f1abdfd466..f70dae39201 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -5,8 +5,8 @@ import test_models as TM import torch from common_utils import cpu_and_gpu, needs_cuda -from torchvision.prototype import models from torchvision.models._api import WeightsEnum, Weights +from torchvision.prototype import models from torchvision.prototype.models._utils import handle_legacy_interface run_if_test_with_prototype = pytest.mark.skipif( diff --git a/torchvision/_utils.py b/torchvision/_utils.py index da0eb923f75..8e8fe1b8a83 100644 --- a/torchvision/_utils.py +++ b/torchvision/_utils.py @@ -1,5 +1,5 @@ import enum -from typing import TypeVar, Type +from typing import Sequence, TypeVar, Type T = TypeVar("T", bound=enum.Enum) @@ -18,3 +18,15 @@ def from_str(self: Type[T], member: str) -> T: # type: ignore[misc] class StrEnum(enum.Enum, metaclass=StrEnumMeta): pass + + +def sequence_to_str(seq: Sequence, separate_last: str = "") -> str: + if not seq: + return "" + if len(seq) == 1: + return f"'{seq[0]}'" + + head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'" + tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'" + + return head + tail diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index 1deff604b4a..9e3a81411a1 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -1,11 +1,12 @@ import functools +import inspect import warnings from collections import OrderedDict from typing import Any, Dict, Optional, TypeVar, Callable, Tuple, Union from torch import nn -from torchvision.prototype.utils._internal import kwonly_to_pos_or_kw +from .._utils import sequence_to_str from ._api import WeightsEnum @@ -88,6 +89,60 @@ def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> return new_v +D = TypeVar("D") + + +def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]: + """Decorates a function that uses keyword only parameters to also allow them being passed as positionals. + + For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``: + + .. code:: + + def old_fn(foo, bar, baz=None): + ... + + def new_fn(foo, *, bar, baz=None): + ... + + Calling ``old_fn("foo", "bar, "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC + and at the same time warn the user of the deprecation, this decorator can be used: + + .. code:: + + @kwonly_to_pos_or_kw + def new_fn(foo, *, bar, baz=None): + ... 
+ + new_fn("foo", "bar, "baz") + """ + params = inspect.signature(fn).parameters + + try: + keyword_only_start_idx = next( + idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY + ) + except StopIteration: + raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None + + keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:] + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> D: + args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:] + if keyword_only_args: + keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args)) + warnings.warn( + f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional " + f"parameter(s) is deprecated. Please use keyword parameter(s) instead." + ) + kwargs.update(keyword_only_kwargs) + + return fn(*args, **kwargs) + + return wrapper + + W = TypeVar("W", bound=WeightsEnum) M = TypeVar("M", bound=nn.Module) V = TypeVar("V") diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index a86f091dfa2..4df533000f9 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -4,13 +4,12 @@ import torch import torch.nn as nn +from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param -from ..transforms import ImageClassificationEval, InterpolationMode - __all__ = ["AlexNet", "AlexNet_Weights", "alexnet"] @@ -94,4 +93,4 @@ def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) - return model \ No newline at end of file + return model diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 2717f9bf8e1..8d25e77eaa1 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -7,10 +7,8 @@ from ..ops.misc import Conv2dNormActivation from ..ops.stochastic_depth import StochasticDepth -from ..utils import _log_api_usage_once - from ..transforms import ImageClassificationEval, InterpolationMode - +from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index 17a5301cade..b0de4529902 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -1,6 +1,6 @@ import re -from functools import partial from collections import OrderedDict +from functools import partial from typing import Any, List, Optional, Tuple import torch @@ -11,7 +11,6 @@ from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once - from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param @@ -277,6 +276,7 @@ def _densenet( "recipe": "https://github.com/pytorch/vision/pull/116", } + class DenseNet121_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet121-a639ec97.pth", @@ -398,4 +398,4 @@ def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool """ weights = DenseNet201_Weights.verify(weights) - return 
_densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs) \ No newline at end of file + return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs) diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index f76fa4f5477..9665c169bbf 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -12,7 +12,6 @@ from ..ops.misc import Conv2dNormActivation, SqueezeExcitation from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once - from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index daad4ebb8df..2cac4a4fbbd 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -1,6 +1,6 @@ import warnings -from functools import partial from collections import namedtuple +from functools import partial from typing import Optional, Tuple, List, Callable, Any import torch @@ -8,9 +8,8 @@ import torch.nn.functional as F from torch import Tensor -from ..utils import _log_api_usage_once from ..transforms import ImageClassificationEval, InterpolationMode - +from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param @@ -333,4 +332,4 @@ def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = T "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" ) - return model \ No newline at end of file + return model diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index bdac024406c..1628542482b 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -7,10 +7,8 @@ import torch.nn.functional as F from torch import nn, Tensor -from ..utils import _log_api_usage_once - - from ..transforms import ImageClassificationEval, InterpolationMode +from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param @@ -465,4 +463,4 @@ def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bo model.aux_logits = False model.AuxLogits = None - return model \ No newline at end of file + return model diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 1634ea2e80b..b6ac15c68d2 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -1,5 +1,5 @@ -from functools import partial import warnings +from functools import partial from typing import Any, Dict, List, Optional import torch diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index 7598338ebd9..acd94af4d10 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -1,5 +1,5 @@ -from functools import partial import warnings +from functools import partial from typing import Callable, Any, Optional, List import torch @@ -9,7 +9,6 @@ from ..ops.misc import Conv2dNormActivation from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once - from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible 
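The kwonly_to_pos_or_kw helper moved into torchvision/models/_utils.py above is exercised indirectly by every ported builder; below is a minimal sketch of its behaviour in isolation (demo_builder is a hypothetical function used only for illustration and is not part of torchvision):

    import warnings

    from torchvision.models._utils import kwonly_to_pos_or_kw

    @kwonly_to_pos_or_kw
    def demo_builder(foo, *, bar=None, baz=None):  # hypothetical, for illustration only
        return foo, bar, baz

    # Keyword usage passes through silently.
    demo_builder("spam", bar=1, baz=2)

    # Positional usage is still accepted for backward compatibility, but raises a
    # UserWarning whose text is built with sequence_to_str, e.g. "Using 'bar' and 'baz'
    # as positional parameter(s) is deprecated. Please use keyword parameter(s) instead."
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        demo_builder("spam", 1, 2)
    assert "as positional parameter(s) is deprecated" in str(caught[-1].message)
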
diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 031b9d761f0..3a98456416d 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -8,7 +8,6 @@ from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once - from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index 98d9382214f..9d0812e3b86 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -5,7 +5,7 @@ import torch.nn as nn from torch import Tensor from torch.nn import functional as F -from torchvision.models.googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, model_urls +from torchvision.models.googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet from ..._internally_replaced_utils import load_state_dict_from_url from .utils import _fuse_modules, _replace_relu, quantize_model @@ -13,6 +13,13 @@ __all__ = ["QuantizableGoogLeNet", "googlenet"] + +model_urls = { + # GoogLeNet ported from TensorFlow + "googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth", +} + + quant_model_urls = { # fp32 GoogLeNet ported from TensorFlow, with weights quantized in PyTorch "googlenet_fbgemm": "https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index 27d021428b9..41b50fb6616 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -18,6 +18,12 @@ ] +model_urls = { + # Inception v3 ported from TensorFlow + "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", +} + + quant_model_urls = { # fp32 weights ported from TensorFlow, quantized in PyTorch "inception_v3_google_fbgemm": "https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth" @@ -225,7 +231,7 @@ def inception_v3( model.AuxLogits = None model_url = quant_model_urls["inception_v3_google_" + backend] else: - model_url = inception_module.model_urls["inception_v3_google"] + model_url = model_urls["inception_v3_google"] state_dict = load_state_dict_from_url(model_url, progress=progress) diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 8cd9f16d13e..06116b8e084 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -3,7 +3,7 @@ from torch import Tensor from torch import nn from torch.ao.quantization import QuantStub, DeQuantStub -from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls +from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2 from ..._internally_replaced_utils import load_state_dict_from_url from ...ops.misc import Conv2dNormActivation @@ -12,6 +12,12 @@ __all__ = ["QuantizableMobileNetV2", "mobilenet_v2"] + +model_urls = { + "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", +} + + quant_model_urls = { "mobilenet_v2_qnnpack": 
"https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth" } diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 4d7e2f7baad..3404a2b72b5 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -6,12 +6,18 @@ from ..._internally_replaced_utils import load_state_dict_from_url from ...ops.misc import Conv2dNormActivation, SqueezeExcitation -from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, model_urls, _mobilenet_v3_conf +from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, _mobilenet_v3_conf from .utils import _fuse_modules, _replace_relu __all__ = ["QuantizableMobileNetV3", "mobilenet_v3_large"] + +model_urls = { + "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", +} + + quant_model_urls = { "mobilenet_v3_large_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", } diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index f55aa0e103c..874ee75ba9c 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn from torch import Tensor -from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls +from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet from ..._internally_replaced_utils import load_state_dict_from_url from .utils import _fuse_modules, _replace_relu, quantize_model @@ -11,6 +11,13 @@ __all__ = ["QuantizableResNet", "resnet18", "resnet50", "resnext101_32x8d"] +model_urls = { + "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth", + "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth", + "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", +} + + quant_model_urls = { "resnet18_fbgemm": "https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", "resnet50_fbgemm": "https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index 9d25315ffa0..e196006a9c3 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -14,6 +14,13 @@ "shufflenet_v2_x1_0", ] + +model_urls = { + "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", + "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", +} + + quant_model_urls = { "shufflenetv2_x0.5_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", "shufflenetv2_x1.0_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", @@ -96,7 +103,7 @@ def _shufflenetv2( if quantize: model_url = quant_model_urls[arch + "_" + backend] else: - model_url = shufflenetv2.model_urls[arch] + model_url = model_urls[arch] state_dict = load_state_dict_from_url(model_url, progress=progress) diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 21727ecff71..1015c21b858 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -9,7 +9,6 @@ from ..ops.misc import Conv2dNormActivation, SqueezeExcitation from ..transforms import 
ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once - from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible @@ -783,9 +782,6 @@ class RegNet_X_32GF_Weights(WeightsEnum): DEFAULT = IMAGENET1K_V2 - - - @handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.IMAGENET1K_V1)) def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: """ diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 6391b0a0705..159749df006 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -5,8 +5,8 @@ import torch.nn as nn from torch import Tensor -from ..utils import _log_api_usage_once from ..transforms import ImageClassificationEval, InterpolationMode +from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 8aa20f7c3fc..e196dedcc3e 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -7,7 +7,6 @@ from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once - from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param @@ -26,7 +25,6 @@ ] - def channel_shuffle(x: Tensor, groups: int) -> Tensor: batchsize, num_channels, height, width = x.size() channels_per_group = num_channels // groups @@ -230,7 +228,6 @@ class ShuffleNet_V2_X2_0_Weights(WeightsEnum): pass - @handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1)) def shufflenet_v2_x0_5( *, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index 5d1d46f2856..d495b3148e5 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -7,7 +7,6 @@ from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once - from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index 5007250f519..5393827b293 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -6,7 +6,6 @@ from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once - from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index f1d916f88fa..b11b7377ed1 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -9,7 +9,6 @@ from ..ops.misc import Conv2dNormActivation from ..transforms import ImageClassificationEval, InterpolationMode from ..utils import _log_api_usage_once - from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param diff --git a/torchvision/prototype/datasets/utils/_dataset.py 
b/torchvision/prototype/datasets/utils/_dataset.py index 5ee7c5ccc60..b5c6d7acb60 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -7,7 +7,8 @@ from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection from torch.utils.data import IterDataPipe -from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion, sequence_to_str +from torchvision._utils import sequence_to_str +from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion from .._home import use_sharded_dataset from ._internal import BUILTIN_DIR, _make_sharded_datapipe diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py new file mode 100644 index 00000000000..01b83708ce9 --- /dev/null +++ b/torchvision/prototype/models/__init__.py @@ -0,0 +1,5 @@ +from . import detection +from . import optical_flow +from . import quantization +from . import segmentation +from . import video diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index 727da793605..5abc0eef1c4 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -1,10 +1,10 @@ from typing import Any, Optional, Union from torch import nn -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.faster_rcnn import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.detection.faster_rcnn import ( _mobilenet_extractor, _resnet_fpn_extractor, _validate_trainable_layers, @@ -13,11 +13,9 @@ misc_nn_ops, overwrite_eps, ) -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large -from ..resnet import ResNet50_Weights, resnet50 +from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from torchvision.models.resnet import ResNet50_Weights, resnet50 +from torchvision.transforms import ObjectDetectionEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/detection/fcos.py b/torchvision/prototype/models/detection/fcos.py index 153fa10c22c..930b26e46c8 100644 --- a/torchvision/prototype/models/detection/fcos.py +++ b/torchvision/prototype/models/detection/fcos.py @@ -1,20 +1,18 @@ from typing import Any, Optional from torch import nn -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.fcos import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.detection.fcos import ( _resnet_fpn_extractor, _validate_trainable_layers, FCOS, LastLevelP6P7, misc_nn_ops, ) -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import 
ResNet50_Weights, resnet50 +from torchvision.models.resnet import ResNet50_Weights, resnet50 +from torchvision.transforms import ObjectDetectionEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/detection/keypoint_rcnn.py b/torchvision/prototype/models/detection/keypoint_rcnn.py index 41142da2a34..a7780cc9f63 100644 --- a/torchvision/prototype/models/detection/keypoint_rcnn.py +++ b/torchvision/prototype/models/detection/keypoint_rcnn.py @@ -1,20 +1,18 @@ from typing import Any, Optional from torch import nn -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.keypoint_rcnn import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.detection.keypoint_rcnn import ( _resnet_fpn_extractor, _validate_trainable_layers, KeypointRCNN, misc_nn_ops, overwrite_eps, ) -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 +from torchvision.models.resnet import ResNet50_Weights, resnet50 +from torchvision.transforms import ObjectDetectionEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/detection/mask_rcnn.py b/torchvision/prototype/models/detection/mask_rcnn.py index df598553686..d52ebe61be1 100644 --- a/torchvision/prototype/models/detection/mask_rcnn.py +++ b/torchvision/prototype/models/detection/mask_rcnn.py @@ -1,20 +1,18 @@ from typing import Any, Optional from torch import nn -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.mask_rcnn import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.detection.mask_rcnn import ( _resnet_fpn_extractor, _validate_trainable_layers, MaskRCNN, misc_nn_ops, overwrite_eps, ) -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 +from torchvision.models.resnet import ResNet50_Weights, resnet50 +from torchvision.transforms import ObjectDetectionEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/detection/retinanet.py b/torchvision/prototype/models/detection/retinanet.py index 7a021ec27c0..c4249118b70 100644 --- a/torchvision/prototype/models/detection/retinanet.py +++ b/torchvision/prototype/models/detection/retinanet.py @@ -1,10 +1,10 @@ from typing import Any, Optional from torch import nn -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.retinanet import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.detection.retinanet 
import ( _resnet_fpn_extractor, _validate_trainable_layers, RetinaNet, @@ -12,10 +12,8 @@ misc_nn_ops, overwrite_eps, ) -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, resnet50 +from torchvision.models.resnet import ResNet50_Weights, resnet50 +from torchvision.transforms import ObjectDetectionEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/detection/ssd.py b/torchvision/prototype/models/detection/ssd.py index 5aae8c49055..a3c5b965deb 100644 --- a/torchvision/prototype/models/detection/ssd.py +++ b/torchvision/prototype/models/detection/ssd.py @@ -1,19 +1,17 @@ import warnings from typing import Any, Optional -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.ssd import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.detection.ssd import ( _validate_trainable_layers, _vgg_extractor, DefaultBoxGenerator, SSD, ) -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..vgg import VGG16_Weights, vgg16 +from torchvision.models.vgg import VGG16_Weights, vgg16 +from torchvision.transforms import ObjectDetectionEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/detection/ssdlite.py b/torchvision/prototype/models/detection/ssdlite.py index 7623b4ea861..d9f2ee58bc6 100644 --- a/torchvision/prototype/models/detection/ssdlite.py +++ b/torchvision/prototype/models/detection/ssdlite.py @@ -3,10 +3,10 @@ from typing import Any, Callable, Optional from torch import nn -from torchvision.prototype.transforms import ObjectDetectionEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.detection.ssdlite import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _COCO_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.detection.ssdlite import ( _mobilenet_extractor, _normal_init, _validate_trainable_layers, @@ -15,10 +15,8 @@ SSD, SSDLiteHead, ) -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _COCO_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from torchvision.transforms import ObjectDetectionEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py index 327764f5107..33e3243c2a0 100644 --- a/torchvision/prototype/models/optical_flow/raft.py +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -2,14 +2,12 @@ from torch.nn.modules.batchnorm import BatchNorm2d from torch.nn.modules.instancenorm import InstanceNorm2d +from torchvision.models._api import Weights +from torchvision.models._api import WeightsEnum +from torchvision.models._utils import handle_legacy_interface from 
torchvision.models.optical_flow import RAFT from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock -from torchvision.prototype.transforms import OpticalFlowEval -from torchvision.transforms.functional import InterpolationMode - -from torchvision.models._api import WeightsEnum -from torchvision.models._api import Weights -from .._utils import handle_legacy_interface +from torchvision.transforms import OpticalFlowEval, InterpolationMode __all__ = ( diff --git a/torchvision/prototype/models/quantization/googlenet.py b/torchvision/prototype/models/quantization/googlenet.py index 12a21f973f6..38b43f5d6a7 100644 --- a/torchvision/prototype/models/quantization/googlenet.py +++ b/torchvision/prototype/models/quantization/googlenet.py @@ -2,18 +2,16 @@ from functools import partial from typing import Any, Optional, Union -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.quantization.googlenet import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param +from torchvision.models.googlenet import GoogLeNet_Weights +from torchvision.models.quantization.googlenet import ( QuantizableGoogLeNet, _replace_relu, quantize_model, ) -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..googlenet import GoogLeNet_Weights +from torchvision.transforms import ImageClassificationEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/quantization/inception.py b/torchvision/prototype/models/quantization/inception.py index d3a5d59fd01..0659892613e 100644 --- a/torchvision/prototype/models/quantization/inception.py +++ b/torchvision/prototype/models/quantization/inception.py @@ -1,18 +1,16 @@ from functools import partial from typing import Any, Optional, Union -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.quantization.inception import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param +from torchvision.models.inception import Inception_V3_Weights +from torchvision.models.quantization.inception import ( QuantizableInception3, _replace_relu, quantize_model, ) -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..inception import Inception_V3_Weights +from torchvision.transforms import ImageClassificationEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/quantization/mobilenetv2.py b/torchvision/prototype/models/quantization/mobilenetv2.py index 271c7440a33..f55b3c8eca5 100644 --- a/torchvision/prototype/models/quantization/mobilenetv2.py +++ b/torchvision/prototype/models/quantization/mobilenetv2.py @@ -1,19 +1,17 @@ from functools import partial from typing import Any, Optional, Union -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from 
....models.quantization.mobilenetv2 import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param +from torchvision.models.mobilenetv2 import MobileNet_V2_Weights +from torchvision.models.quantization.mobilenetv2 import ( QuantizableInvertedResidual, QuantizableMobileNetV2, _replace_relu, quantize_model, ) -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..mobilenetv2 import MobileNet_V2_Weights +from torchvision.transforms import ImageClassificationEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/quantization/mobilenetv3.py b/torchvision/prototype/models/quantization/mobilenetv3.py index 482618b530c..b2a2422cabd 100644 --- a/torchvision/prototype/models/quantization/mobilenetv3.py +++ b/torchvision/prototype/models/quantization/mobilenetv3.py @@ -2,19 +2,17 @@ from typing import Any, List, Optional, Union import torch -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.quantization.mobilenetv3 import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param +from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, _mobilenet_v3_conf +from torchvision.models.quantization.mobilenetv3 import ( InvertedResidualConfig, QuantizableInvertedResidual, QuantizableMobileNetV3, _replace_relu, ) -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..mobilenetv3 import MobileNet_V3_Large_Weights, _mobilenet_v3_conf +from torchvision.transforms import ImageClassificationEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/quantization/resnet.py b/torchvision/prototype/models/quantization/resnet.py index 5bd94567d3a..289658aecda 100644 --- a/torchvision/prototype/models/quantization/resnet.py +++ b/torchvision/prototype/models/quantization/resnet.py @@ -1,20 +1,18 @@ from functools import partial from typing import Any, List, Optional, Type, Union -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.quantization.resnet import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param +from torchvision.models.quantization.resnet import ( QuantizableBasicBlock, QuantizableBottleneck, QuantizableResNet, _replace_relu, quantize_model, ) -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..resnet import ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights +from torchvision.models.resnet import ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights +from torchvision.transforms import ImageClassificationEval, InterpolationMode __all__ = [ diff --git 
a/torchvision/prototype/models/quantization/shufflenetv2.py b/torchvision/prototype/models/quantization/shufflenetv2.py index ffc516db418..3b1f41affdb 100644 --- a/torchvision/prototype/models/quantization/shufflenetv2.py +++ b/torchvision/prototype/models/quantization/shufflenetv2.py @@ -1,18 +1,16 @@ from functools import partial from typing import Any, List, Optional, Union -from torchvision.prototype.transforms import ImageClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.quantization.shufflenetv2 import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param +from torchvision.models.quantization.shufflenetv2 import ( QuantizableShuffleNetV2, _replace_relu, quantize_model, ) -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights +from torchvision.models.shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights +from torchvision.transforms import ImageClassificationEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py index dfd067974a3..2c8d7f6ad84 100644 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ b/torchvision/prototype/models/segmentation/deeplabv3.py @@ -1,16 +1,13 @@ from functools import partial from typing import Any, Optional -from torchvision.prototype.transforms import SemanticSegmentationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet from torchvision.models._api import WeightsEnum, Weights from torchvision.models._meta import _VOC_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large -from ..resnet import resnet50, resnet101 -from ..resnet import ResNet50_Weights, ResNet101_Weights +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from torchvision.models.resnet import resnet50, resnet101, ResNet50_Weights, ResNet101_Weights +from torchvision.models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet +from torchvision.transforms import SemanticSegmentationEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/models/segmentation/fcn.py b/torchvision/prototype/models/segmentation/fcn.py index 13c197888e7..e7b12621940 100644 --- a/torchvision/prototype/models/segmentation/fcn.py +++ b/torchvision/prototype/models/segmentation/fcn.py @@ -1,14 +1,12 @@ from functools import partial from typing import Any, Optional -from torchvision.prototype.transforms import SemanticSegmentationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.segmentation.fcn import FCN, _fcn_resnet from torchvision.models._api import WeightsEnum, Weights from torchvision.models._meta import _VOC_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..resnet import ResNet50_Weights, 
ResNet101_Weights, resnet50, resnet101 +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.resnet import ResNet50_Weights, ResNet101_Weights, resnet50, resnet101 +from torchvision.models.segmentation.fcn import FCN, _fcn_resnet +from torchvision.transforms import SemanticSegmentationEval, InterpolationMode __all__ = ["FCN", "FCN_ResNet50_Weights", "FCN_ResNet101_Weights", "fcn_resnet50", "fcn_resnet101"] diff --git a/torchvision/prototype/models/segmentation/lraspp.py b/torchvision/prototype/models/segmentation/lraspp.py index 75a827c3610..21c15373089 100644 --- a/torchvision/prototype/models/segmentation/lraspp.py +++ b/torchvision/prototype/models/segmentation/lraspp.py @@ -1,14 +1,12 @@ from functools import partial from typing import Any, Optional -from torchvision.prototype.transforms import SemanticSegmentationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3 from torchvision.models._api import WeightsEnum, Weights from torchvision.models._meta import _VOC_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_value_param -from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param +from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from torchvision.models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3 +from torchvision.transforms import SemanticSegmentationEval, InterpolationMode __all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"] diff --git a/torchvision/prototype/models/video/resnet.py b/torchvision/prototype/models/video/resnet.py index 21bbb4d740d..0f4c0dd1dc9 100644 --- a/torchvision/prototype/models/video/resnet.py +++ b/torchvision/prototype/models/video/resnet.py @@ -2,10 +2,10 @@ from typing import Any, Callable, List, Optional, Sequence, Type, Union from torch import nn -from torchvision.prototype.transforms import VideoClassificationEval -from torchvision.transforms.functional import InterpolationMode - -from ....models.video.resnet import ( +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _KINETICS400_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param +from torchvision.models.video.resnet import ( BasicBlock, BasicStem, Bottleneck, @@ -15,9 +15,7 @@ R2Plus1dStem, VideoResNet, ) -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _KINETICS400_CATEGORIES -from .._utils import handle_legacy_interface, _ovewrite_named_param +from torchvision.transforms import VideoClassificationEval, InterpolationMode __all__ = [ diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index c451feb9a32..7fc62423ab8 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -4,9 +4,10 @@ import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F +from torchvision.prototype.transforms import Transform, functional as F from torchvision.prototype.utils._internal import query_recursively -from torchvision.transforms.functional import pil_to_tensor, to_pil_image +from 
torchvision.transforms.autoaugment import AutoAugmentPolicy +from torchvision.transforms.functional import pil_to_tensor, to_pil_image, InterpolationMode from ._utils import get_image_dimensions, is_simple_tensor diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index e04e9f819f3..5d340d6a13b 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -6,8 +6,8 @@ import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F -from torchvision.transforms.functional import pil_to_tensor +from torchvision.prototype.transforms import Transform, functional as F +from torchvision.transforms.functional import pil_to_tensor, InterpolationMode from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 6c9309749af..4bbb991358e 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -4,9 +4,8 @@ import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import InterpolationMode from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP -from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix +from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix, InterpolationMode from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index 864bff9ce02..fe5284394cb 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -1,14 +1,11 @@ import collections.abc import difflib -import functools -import inspect import io import mmap import os import os.path import platform import textwrap -import warnings from typing import ( Any, BinaryIO, @@ -28,14 +25,14 @@ import numpy as np import torch +from torchvision._utils import sequence_to_str + __all__ = [ - "sequence_to_str", "add_suggestion", "FrozenMapping", "make_repr", "FrozenBunch", - "kwonly_to_pos_or_kw", "fromfile", "ReadOnlyTensorBuffer", "apply_recursively", @@ -43,18 +40,6 @@ ] -def sequence_to_str(seq: Sequence, separate_last: str = "") -> str: - if not seq: - return "" - if len(seq) == 1: - return f"'{seq[0]}'" - - head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'" - tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'" - - return head + tail - - def add_suggestion( msg: str, *, @@ -151,57 +136,6 @@ def __repr__(self) -> str: return make_repr(type(self).__name__, self.items()) -def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]: - """Decorates a function that uses keyword only parameters to also allow them being passed as positionals. - - For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``: - - .. code:: - - def old_fn(foo, bar, baz=None): - ... - - def new_fn(foo, *, bar, baz=None): - ... 
- - Calling ``old_fn("foo", "bar, "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC - and at the same time warn the user of the deprecation, this decorator can be used: - - .. code:: - - @kwonly_to_pos_or_kw - def new_fn(foo, *, bar, baz=None): - ... - - new_fn("foo", "bar, "baz") - """ - params = inspect.signature(fn).parameters - - try: - keyword_only_start_idx = next( - idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY - ) - except StopIteration: - raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None - - keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:] - - @functools.wraps(fn) - def wrapper(*args: Any, **kwargs: Any) -> D: - args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:] - if keyword_only_args: - keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args)) - warnings.warn( - f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional " - f"parameter(s) is deprecated. Please use keyword parameter(s) instead." - ) - kwargs.update(keyword_only_kwargs) - - return fn(*args, **kwargs) - - return wrapper - - def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray: # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable return bytearray(file.read(-1 if count == -1 else count * item_size)) From 56e0e0181e9c3b290ab5c0028629b7acc0d6046d Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 12:51:16 +0000 Subject: [PATCH 15/45] Adding missing import --- torchvision/models/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 16495e8552e..f53245e2e35 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -18,3 +18,4 @@ from . import quantization from . import segmentation from . 
import video +from ._api import get_weight From 34f45d05567df8e52b7d591e5b7e0c5bf4aefc8b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 15:30:33 +0000 Subject: [PATCH 16/45] Fix mobilenet imports --- torchvision/models/mobilenet.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchvision/models/mobilenet.py b/torchvision/models/mobilenet.py index 4108305d3f5..0a270d14d3a 100644 --- a/torchvision/models/mobilenet.py +++ b/torchvision/models/mobilenet.py @@ -1,4 +1,6 @@ -from .mobilenetv2 import MobileNetV2, mobilenet_v2, __all__ as mv2_all -from .mobilenetv3 import MobileNetV3, mobilenet_v3_large, mobilenet_v3_small, __all__ as mv3_all +from .mobilenetv2 import * # noqa: F401, F403 +from .mobilenetv3 import * # noqa: F401, F403 +from .mobilenetv2 import __all__ as mv2_all +from .mobilenetv3 import __all__ as mv3_all __all__ = mv2_all + mv3_all From 7cd616cd84741cea0bd03250386a815c8253bea4 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 15:35:11 +0000 Subject: [PATCH 17/45] Fix tests --- test/test_prototype_models.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index f70dae39201..852b05111d7 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -4,10 +4,11 @@ import pytest import test_models as TM import torch +import torchvision from common_utils import cpu_and_gpu, needs_cuda from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._utils import handle_legacy_interface from torchvision.prototype import models -from torchvision.prototype.models._utils import handle_legacy_interface run_if_test_with_prototype = pytest.mark.skipif( os.getenv("PYTORCH_TEST_WITH_PROTOTYPE") != "1", @@ -88,7 +89,7 @@ def test_naming_conventions(model_fn): @pytest.mark.parametrize( "model_fn", - TM.get_models_from_module(models) + TM.get_models_from_module(torchvision.models) + TM.get_models_from_module(models.detection) + TM.get_models_from_module(models.quantization) + TM.get_models_from_module(models.segmentation) @@ -142,13 +143,6 @@ def test_schema_meta_validation(model_fn): assert not bad_names -@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models)) -@pytest.mark.parametrize("dev", cpu_and_gpu()) -@run_if_test_with_prototype -def test_classification_model(model_fn, dev): - TM.test_classification_model(model_fn, dev) - - @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.detection)) @pytest.mark.parametrize("dev", cpu_and_gpu()) @run_if_test_with_prototype @@ -186,8 +180,7 @@ def test_raft(model_builder, scripted): @pytest.mark.parametrize( "model_fn", - TM.get_models_from_module(models) - + TM.get_models_from_module(models.detection) + TM.get_models_from_module(models.detection) + TM.get_models_from_module(models.quantization) + TM.get_models_from_module(models.segmentation) + TM.get_models_from_module(models.video) From 5ce348d6f4b3e64dc9ba3a7d93b176757536e243 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 15:55:52 +0000 Subject: [PATCH 18/45] Fix prototype tests --- test/test_prototype_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 852b05111d7..27a3598a5f3 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -55,8 +55,8 @@ def _build_model(fn, **kwargs): @pytest.mark.parametrize( "name, 
weight", [ - ("ResNet50_Weights.IMAGENET1K_V1", models.ResNet50_Weights.IMAGENET1K_V1), - ("ResNet50_Weights.DEFAULT", models.ResNet50_Weights.IMAGENET1K_V2), + ("ResNet50_Weights.IMAGENET1K_V1", torchvision.models.ResNet50_Weights.IMAGENET1K_V1), + ("ResNet50_Weights.DEFAULT", torchvision.models.ResNet50_Weights.IMAGENET1K_V2), ( "ResNet50_QuantizedWeights.DEFAULT", models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2, @@ -68,7 +68,7 @@ def _build_model(fn, **kwargs): ], ) def test_get_weight(name, weight): - assert models.get_weight(name) == weight + assert torchvision.models.get_weight(name) == weight @pytest.mark.parametrize( From a8650a4ecb6110a9876dd01e056a732ece8275d3 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 16:47:07 +0000 Subject: [PATCH 19/45] Exclude get_weight from models on test --- test/test_backbone_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index ed9b52d0499..e07194a5685 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -13,7 +13,11 @@ def get_available_models(): # TODO add a registration mechanism to torchvision.models - return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] + return [ + k + for k, v in models.__dict__.items() + if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight" + ] @pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50")) From 08c5f9b946a9f847325e5167ee051ed582d7d3a6 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 16:23:24 +0000 Subject: [PATCH 20/45] Fix init files --- torchvision/models/__init__.py | 15 +++++++-------- torchvision/models/quantization/__init__.py | 4 ++-- torchvision/models/quantization/mobilenet.py | 6 ++++-- torchvision/prototype/models/__init__.py | 1 - .../prototype/models/quantization/__init__.py | 5 ----- .../prototype/models/quantization/mobilenet.py | 6 ------ 6 files changed, 13 insertions(+), 24 deletions(-) delete mode 100644 torchvision/prototype/models/quantization/__init__.py delete mode 100644 torchvision/prototype/models/quantization/mobilenet.py diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index f53245e2e35..83e49908348 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -1,19 +1,18 @@ from .alexnet import * from .convnext import * -from .resnet import * -from .vgg import * -from .squeezenet import * -from .inception import * from .densenet import * +from .efficientnet import * from .googlenet import * -from .mobilenet import * +from .inception import * from .mnasnet import * -from .shufflenetv2 import * -from .efficientnet import * +from .mobilenet import * from .regnet import * +from .resnet import * +from .shufflenetv2 import * +from .squeezenet import * +from .vgg import * from .vision_transformer import * from . import detection -from . import feature_extraction from . import optical_flow from . import quantization from . 
import segmentation diff --git a/torchvision/models/quantization/__init__.py b/torchvision/models/quantization/__init__.py index deae997a219..da8bbba3567 100644 --- a/torchvision/models/quantization/__init__.py +++ b/torchvision/models/quantization/__init__.py @@ -1,5 +1,5 @@ -from .mobilenet import * -from .resnet import * from .googlenet import * from .inception import * +from .mobilenet import * +from .resnet import * from .shufflenetv2 import * diff --git a/torchvision/models/quantization/mobilenet.py b/torchvision/models/quantization/mobilenet.py index 8f2c42db640..0a270d14d3a 100644 --- a/torchvision/models/quantization/mobilenet.py +++ b/torchvision/models/quantization/mobilenet.py @@ -1,4 +1,6 @@ -from .mobilenetv2 import QuantizableMobileNetV2, mobilenet_v2, __all__ as mv2_all -from .mobilenetv3 import QuantizableMobileNetV3, mobilenet_v3_large, __all__ as mv3_all +from .mobilenetv2 import * # noqa: F401, F403 +from .mobilenetv3 import * # noqa: F401, F403 +from .mobilenetv2 import __all__ as mv2_all +from .mobilenetv3 import __all__ as mv3_all __all__ = mv2_all + mv3_all diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py index 01b83708ce9..5988c160aad 100644 --- a/torchvision/prototype/models/__init__.py +++ b/torchvision/prototype/models/__init__.py @@ -1,5 +1,4 @@ from . import detection from . import optical_flow -from . import quantization from . import segmentation from . import video diff --git a/torchvision/prototype/models/quantization/__init__.py b/torchvision/prototype/models/quantization/__init__.py deleted file mode 100644 index da8bbba3567..00000000000 --- a/torchvision/prototype/models/quantization/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .googlenet import * -from .inception import * -from .mobilenet import * -from .resnet import * -from .shufflenetv2 import * diff --git a/torchvision/prototype/models/quantization/mobilenet.py b/torchvision/prototype/models/quantization/mobilenet.py deleted file mode 100644 index 0a270d14d3a..00000000000 --- a/torchvision/prototype/models/quantization/mobilenet.py +++ /dev/null @@ -1,6 +0,0 @@ -from .mobilenetv2 import * # noqa: F401, F403 -from .mobilenetv3 import * # noqa: F401, F403 -from .mobilenetv2 import __all__ as mv2_all -from .mobilenetv3 import __all__ as mv3_all - -__all__ = mv2_all + mv3_all From bc6d94ff155b259197ac8d631af3a7406fa022dc Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 16:35:21 +0000 Subject: [PATCH 21/45] Porting googlenet --- torchvision/models/quantization/googlenet.py | 113 ++++++++++-------- .../models/quantization/googlenet.py | 92 -------------- 2 files changed, 65 insertions(+), 140 deletions(-) delete mode 100644 torchvision/prototype/models/quantization/googlenet.py diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index 9d0812e3b86..5a5e9eca75c 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -1,29 +1,25 @@ import warnings -from typing import Any, Optional +from functools import partial +from typing import Any, Optional, Union import torch import torch.nn as nn from torch import Tensor from torch.nn import functional as F -from torchvision.models.googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet -from ..._internally_replaced_utils import load_state_dict_from_url +from ...transforms import ImageClassificationEval, InterpolationMode +from .._api import WeightsEnum, Weights 
+from .._meta import _IMAGENET_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param +from ..googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, GoogLeNet_Weights from .utils import _fuse_modules, _replace_relu, quantize_model -__all__ = ["QuantizableGoogLeNet", "googlenet"] - - -model_urls = { - # GoogLeNet ported from TensorFlow - "googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth", -} - - -quant_model_urls = { - # fp32 GoogLeNet ported from TensorFlow, with weights quantized in PyTorch - "googlenet_fbgemm": "https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", -} +__all__ = [ + "QuantizableGoogLeNet", + "GoogLeNet_QuantizedWeights", + "googlenet", +] class QuantizableBasicConv2d(BasicConv2d): @@ -110,8 +106,41 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: m.fuse_model(is_qat) +class GoogLeNet_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + "task": "image_classification", + "architecture": "GoogLeNet", + "publication_year": 2014, + "num_params": 6624904, + "size": (224, 224), + "min_size": (15, 15), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "fbgemm", + "quantization": "ptq", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", + "unquantized": GoogLeNet_Weights.IMAGENET1K_V1, + "acc@1": 69.826, + "acc@5": 89.404, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else GoogLeNet_Weights.IMAGENET1K_V1, + ) +) def googlenet( - pretrained: bool = False, + *, + weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -124,49 +153,37 @@ def googlenet( GPU inference is not yet supported Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + pretrained (GoogLeNet_QuantizedWeights or GoogLeNet_Weights, optional):The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model - aux_logits (bool): If True, adds two auxiliary branches that can improve training. - Default: *False* when pretrained is True otherwise *True* - transform_input (bool): If True, preprocesses the input according to the method with which it - was trained on ImageNet. Default: True if ``pretrained=True``, else False. 
""" - if pretrained: + weights = (GoogLeNet_QuantizedWeights if quantize else GoogLeNet_Weights).verify(weights) + + original_aux_logits = kwargs.get("aux_logits", False) + if weights is not None: if "transform_input" not in kwargs: - kwargs["transform_input"] = True - if "aux_logits" not in kwargs: - kwargs["aux_logits"] = False - if kwargs["aux_logits"]: - warnings.warn( - "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" - ) - original_aux_logits = kwargs["aux_logits"] - kwargs["aux_logits"] = True - kwargs["init_weights"] = False + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "init_weights", False) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") model = QuantizableGoogLeNet(**kwargs) _replace_relu(model) - if quantize: - # TODO use pretrained as a string to specify the backend - backend = "fbgemm" quantize_model(model, backend) - else: - assert pretrained in [True, False] - - if pretrained: - if quantize: - model_url = quant_model_urls["googlenet_" + backend] - else: - model_url = model_urls["googlenet"] - - state_dict = load_state_dict_from_url(model_url, progress=progress) - - model.load_state_dict(state_dict) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) if not original_aux_logits: model.aux_logits = False model.aux1 = None # type: ignore[assignment] model.aux2 = None # type: ignore[assignment] + else: + warnings.warn( + "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" + ) + return model diff --git a/torchvision/prototype/models/quantization/googlenet.py b/torchvision/prototype/models/quantization/googlenet.py deleted file mode 100644 index 38b43f5d6a7..00000000000 --- a/torchvision/prototype/models/quantization/googlenet.py +++ /dev/null @@ -1,92 +0,0 @@ -import warnings -from functools import partial -from typing import Any, Optional, Union - -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param -from torchvision.models.googlenet import GoogLeNet_Weights -from torchvision.models.quantization.googlenet import ( - QuantizableGoogLeNet, - _replace_relu, - quantize_model, -) -from torchvision.transforms import ImageClassificationEval, InterpolationMode - - -__all__ = [ - "QuantizableGoogLeNet", - "GoogLeNet_QuantizedWeights", - "googlenet", -] - - -class GoogLeNet_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - "task": "image_classification", - "architecture": "GoogLeNet", - "publication_year": 2014, - "num_params": 6624904, - "size": (224, 224), - "min_size": (15, 15), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "backend": "fbgemm", - "quantization": "ptq", - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", - "unquantized": GoogLeNet_Weights.IMAGENET1K_V1, - "acc@1": 69.826, - "acc@5": 89.404, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V1 - - 
-@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else GoogLeNet_Weights.IMAGENET1K_V1, - ) -) -def googlenet( - *, - weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableGoogLeNet: - weights = (GoogLeNet_QuantizedWeights if quantize else GoogLeNet_Weights).verify(weights) - - original_aux_logits = kwargs.get("aux_logits", False) - if weights is not None: - if "transform_input" not in kwargs: - _ovewrite_named_param(kwargs, "transform_input", True) - _ovewrite_named_param(kwargs, "aux_logits", True) - _ovewrite_named_param(kwargs, "init_weights", False) - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - if "backend" in weights.meta: - _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) - backend = kwargs.pop("backend", "fbgemm") - - model = QuantizableGoogLeNet(**kwargs) - _replace_relu(model) - if quantize: - quantize_model(model, backend) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - if not original_aux_logits: - model.aux_logits = False - model.aux1 = None # type: ignore[assignment] - model.aux2 = None # type: ignore[assignment] - else: - warnings.warn( - "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" - ) - - return model From c106748cde30a2d1a4e65c8fcc31850f6107d4b4 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 16:40:46 +0000 Subject: [PATCH 22/45] Porting inception --- torchvision/models/quantization/googlenet.py | 3 +- torchvision/models/quantization/inception.py | 111 ++++++++++-------- .../models/quantization/inception.py | 88 -------------- 3 files changed, 64 insertions(+), 138 deletions(-) delete mode 100644 torchvision/prototype/models/quantization/inception.py diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index 5a5e9eca75c..cbc19977054 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -153,7 +153,8 @@ def googlenet( GPU inference is not yet supported Args: - pretrained (GoogLeNet_QuantizedWeights or GoogLeNet_Weights, optional):The pretrained weights for the model + pretrained (GoogLeNet_QuantizedWeights or GoogLeNet_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index 41b50fb6616..d3deb489a64 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -1,35 +1,28 @@ import warnings -from typing import Any, List, Optional +from functools import partial +from typing import Any, List, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from torchvision.models import inception as inception_module -from torchvision.models.inception import InceptionOutputs +from torchvision.models.inception import InceptionOutputs, Inception_V3_Weights -from ..._internally_replaced_utils import load_state_dict_from_url +from ...transforms import ImageClassificationEval, InterpolationMode +from .._api import WeightsEnum, Weights 
+from .._meta import _IMAGENET_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param from .utils import _fuse_modules, _replace_relu, quantize_model __all__ = [ "QuantizableInception3", + "Inception_V3_QuantizedWeights", "inception_v3", ] -model_urls = { - # Inception v3 ported from TensorFlow - "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", -} - - -quant_model_urls = { - # fp32 weights ported from TensorFlow, quantized in PyTorch - "inception_v3_google_fbgemm": "https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth" -} - - class QuantizableBasicConv2d(inception_module.BasicConv2d): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -179,8 +172,41 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: m.fuse_model(is_qat) +class Inception_V3_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth", + transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342), + meta={ + "task": "image_classification", + "architecture": "InceptionV3", + "publication_year": 2015, + "num_params": 27161264, + "size": (299, 299), + "min_size": (75, 75), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "fbgemm", + "quantization": "ptq", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", + "unquantized": Inception_V3_Weights.IMAGENET1K_V1, + "acc@1": 77.176, + "acc@5": 93.354, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: Inception_V3_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else Inception_V3_Weights.IMAGENET1K_V1, + ) +) def inception_v3( - pretrained: bool = False, + *, + weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -197,48 +223,35 @@ def inception_v3( GPU inference is not yet supported Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + pretrained (Inception_V3_QuantizedWeights or Inception_V3_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model - aux_logits (bool): If True, add an auxiliary branch that can improve training. - Default: *True* - transform_input (bool): If True, preprocesses the input according to the method with which it - was trained on ImageNet. Default: True if ``pretrained=True``, else False. 
""" - if pretrained: + weights = (Inception_V3_QuantizedWeights if quantize else Inception_V3_Weights).verify(weights) + + original_aux_logits = kwargs.get("aux_logits", False) + if weights is not None: if "transform_input" not in kwargs: - kwargs["transform_input"] = True - if "aux_logits" in kwargs: - original_aux_logits = kwargs["aux_logits"] - kwargs["aux_logits"] = True - else: - original_aux_logits = False + _ovewrite_named_param(kwargs, "transform_input", True) + _ovewrite_named_param(kwargs, "aux_logits", True) + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") model = QuantizableInception3(**kwargs) _replace_relu(model) - if quantize: - # TODO use pretrained as a string to specify the backend - backend = "fbgemm" quantize_model(model, backend) - else: - assert pretrained in [True, False] - - if pretrained: - if quantize: - if not original_aux_logits: - model.aux_logits = False - model.AuxLogits = None - model_url = quant_model_urls["inception_v3_google_" + backend] - else: - model_url = model_urls["inception_v3_google"] - - state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict) + if weights is not None: + if quantize and not original_aux_logits: + model.aux_logits = False + model.AuxLogits = None + model.load_state_dict(weights.get_state_dict(progress=progress)) + if not quantize and not original_aux_logits: + model.aux_logits = False + model.AuxLogits = None - if not quantize: - if not original_aux_logits: - model.aux_logits = False - model.AuxLogits = None return model diff --git a/torchvision/prototype/models/quantization/inception.py b/torchvision/prototype/models/quantization/inception.py deleted file mode 100644 index 0659892613e..00000000000 --- a/torchvision/prototype/models/quantization/inception.py +++ /dev/null @@ -1,88 +0,0 @@ -from functools import partial -from typing import Any, Optional, Union - -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param -from torchvision.models.inception import Inception_V3_Weights -from torchvision.models.quantization.inception import ( - QuantizableInception3, - _replace_relu, - quantize_model, -) -from torchvision.transforms import ImageClassificationEval, InterpolationMode - - -__all__ = [ - "QuantizableInception3", - "Inception_V3_QuantizedWeights", - "inception_v3", -] - - -class Inception_V3_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth", - transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342), - meta={ - "task": "image_classification", - "architecture": "InceptionV3", - "publication_year": 2015, - "num_params": 27161264, - "size": (299, 299), - "min_size": (75, 75), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "backend": "fbgemm", - "quantization": "ptq", - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", - "unquantized": Inception_V3_Weights.IMAGENET1K_V1, - "acc@1": 77.176, - "acc@5": 93.354, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V1 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: 
Inception_V3_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else Inception_V3_Weights.IMAGENET1K_V1, - ) -) -def inception_v3( - *, - weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableInception3: - weights = (Inception_V3_QuantizedWeights if quantize else Inception_V3_Weights).verify(weights) - - original_aux_logits = kwargs.get("aux_logits", False) - if weights is not None: - if "transform_input" not in kwargs: - _ovewrite_named_param(kwargs, "transform_input", True) - _ovewrite_named_param(kwargs, "aux_logits", True) - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - if "backend" in weights.meta: - _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) - backend = kwargs.pop("backend", "fbgemm") - - model = QuantizableInception3(**kwargs) - _replace_relu(model) - if quantize: - quantize_model(model, backend) - - if weights is not None: - if quantize and not original_aux_logits: - model.aux_logits = False - model.AuxLogits = None - model.load_state_dict(weights.get_state_dict(progress=progress)) - if not quantize and not original_aux_logits: - model.aux_logits = False - model.AuxLogits = None - - return model From db33f53aa9f750983c7bdeae9812748efc1173da Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 16:53:57 +0000 Subject: [PATCH 23/45] porting mobilenetv2 --- .../models/quantization/mobilenetv2.py | 91 ++++++++++++------- .../models/quantization/mobilenetv2.py | 79 ---------------- 2 files changed, 60 insertions(+), 110 deletions(-) delete mode 100644 torchvision/prototype/models/quantization/mobilenetv2.py diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 06116b8e084..a260b13fd8d 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -1,26 +1,24 @@ -from typing import Any, Optional +from functools import partial +from typing import Any, Optional, Union from torch import Tensor from torch import nn from torch.ao.quantization import QuantStub, DeQuantStub -from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2 +from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, MobileNet_V2_Weights -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops.misc import Conv2dNormActivation +from ...transforms import ImageClassificationEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _IMAGENET_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param from .utils import _fuse_modules, _replace_relu, quantize_model -__all__ = ["QuantizableMobileNetV2", "mobilenet_v2"] - - -model_urls = { - "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", -} - - -quant_model_urls = { - "mobilenet_v2_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth" -} +__all__ = [ + "QuantizableMobileNetV2", + "MobileNet_V2_QuantizedWeights", + "mobilenet_v2", +] class QuantizableInvertedResidual(InvertedResidual): @@ -66,8 +64,41 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: m.fuse_model(is_qat) +class MobileNet_V2_QuantizedWeights(WeightsEnum): + IMAGENET1K_QNNPACK_V1 = Weights( + url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth", + 
transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + "task": "image_classification", + "architecture": "MobileNetV2", + "publication_year": 2018, + "num_params": 3504872, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "qnnpack", + "quantization": "qat", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2", + "unquantized": MobileNet_V2_Weights.IMAGENET1K_V1, + "acc@1": 71.658, + "acc@5": 90.150, + }, + ) + DEFAULT = IMAGENET1K_QNNPACK_V1 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1 + if kwargs.get("quantize", False) + else MobileNet_V2_Weights.IMAGENET1K_V1, + ) +) def mobilenet_v2( - pretrained: bool = False, + *, + weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -82,27 +113,25 @@ def mobilenet_v2( GPU inference is not yet supported Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet. - progress (bool): If True, displays a progress bar of the download to stderr - quantize(bool): If True, returns a quantized model, else returns a float model + pretrained (GoogLeNet_QuantizedWeights or GoogLeNet_Weights, optional): The pretrained + weights for the model + progress (bool): If True, displays a progress bar of the download to stderr + quantize(bool): If True, returns a quantized model, else returns a float model """ + weights = (MobileNet_V2_QuantizedWeights if quantize else MobileNet_V2_Weights).verify(weights) + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "qnnpack") + model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs) _replace_relu(model) - if quantize: - # TODO use pretrained as a string to specify the backend - backend = "qnnpack" quantize_model(model, backend) - else: - assert pretrained in [True, False] - - if pretrained: - if quantize: - model_url = quant_model_urls["mobilenet_v2_" + backend] - else: - model_url = model_urls["mobilenet_v2"] - state_dict = load_state_dict_from_url(model_url, progress=progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) - model.load_state_dict(state_dict) return model diff --git a/torchvision/prototype/models/quantization/mobilenetv2.py b/torchvision/prototype/models/quantization/mobilenetv2.py deleted file mode 100644 index f55b3c8eca5..00000000000 --- a/torchvision/prototype/models/quantization/mobilenetv2.py +++ /dev/null @@ -1,79 +0,0 @@ -from functools import partial -from typing import Any, Optional, Union - -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param -from torchvision.models.mobilenetv2 import MobileNet_V2_Weights -from torchvision.models.quantization.mobilenetv2 import ( - QuantizableInvertedResidual, - QuantizableMobileNetV2, - _replace_relu, - quantize_model, -) -from torchvision.transforms import ImageClassificationEval, InterpolationMode - - -__all__ = [ - "QuantizableMobileNetV2", - "MobileNet_V2_QuantizedWeights", - "mobilenet_v2", -] - - -class 
MobileNet_V2_QuantizedWeights(WeightsEnum): - IMAGENET1K_QNNPACK_V1 = Weights( - url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - "task": "image_classification", - "architecture": "MobileNetV2", - "publication_year": 2018, - "num_params": 3504872, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "backend": "qnnpack", - "quantization": "qat", - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2", - "unquantized": MobileNet_V2_Weights.IMAGENET1K_V1, - "acc@1": 71.658, - "acc@5": 90.150, - }, - ) - DEFAULT = IMAGENET1K_QNNPACK_V1 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1 - if kwargs.get("quantize", False) - else MobileNet_V2_Weights.IMAGENET1K_V1, - ) -) -def mobilenet_v2( - *, - weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableMobileNetV2: - weights = (MobileNet_V2_QuantizedWeights if quantize else MobileNet_V2_Weights).verify(weights) - - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - if "backend" in weights.meta: - _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) - backend = kwargs.pop("backend", "qnnpack") - - model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs) - _replace_relu(model) - if quantize: - quantize_model(model, backend) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model From f34279754b6aafd2fb219cd946f16ce95e4f2634 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 17:03:09 +0000 Subject: [PATCH 24/45] porting mobilenetv3 --- .../models/quantization/mobilenetv3.py | 100 +++++++++++------- .../models/quantization/mobilenetv3.py | 99 ----------------- 2 files changed, 63 insertions(+), 136 deletions(-) delete mode 100644 torchvision/prototype/models/quantization/mobilenetv3.py diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 3404a2b72b5..eafff9ec8f3 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -1,26 +1,24 @@ -from typing import Any, List, Optional +from functools import partial +from typing import Any, List, Optional, Union import torch from torch import nn, Tensor from torch.ao.quantization import QuantStub, DeQuantStub -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops.misc import Conv2dNormActivation, SqueezeExcitation -from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, _mobilenet_v3_conf +from ...transforms import ImageClassificationEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _IMAGENET_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param +from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, _mobilenet_v3_conf, MobileNet_V3_Large_Weights from .utils import _fuse_modules, _replace_relu -__all__ = ["QuantizableMobileNetV3", "mobilenet_v3_large"] - - -model_urls = { - "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", -} - - 
-quant_model_urls = { - "mobilenet_v3_large_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", -} +__all__ = [ + "QuantizableMobileNetV3", + "MobileNet_V3_Large_QuantizedWeights", + "mobilenet_v3_large", +] class QuantizableSqueezeExcitation(SqueezeExcitation): @@ -118,47 +116,73 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: m.fuse_model(is_qat) -def _load_weights(arch: str, model: QuantizableMobileNetV3, model_url: Optional[str], progress: bool) -> None: - if model_url is None: - raise ValueError(f"No checkpoint is available for {arch}") - state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict) - - def _mobilenet_v3_model( - arch: str, inverted_residual_setting: List[InvertedResidualConfig], last_channel: int, - pretrained: bool, + weights: Optional[WeightsEnum], progress: bool, quantize: bool, **kwargs: Any, ) -> QuantizableMobileNetV3: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "qnnpack") model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs) _replace_relu(model) if quantize: - backend = "qnnpack" - model.fuse_model(is_qat=True) model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend) torch.ao.quantization.prepare_qat(model, inplace=True) - if pretrained: - _load_weights(arch, model, quant_model_urls.get(arch + "_" + backend, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + if quantize: torch.ao.quantization.convert(model, inplace=True) model.eval() - else: - if pretrained: - _load_weights(arch, model, model_urls.get(arch, None), progress) return model +class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): + IMAGENET1K_QNNPACK_V1 = Weights( + url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + "task": "image_classification", + "architecture": "MobileNetV3", + "publication_year": 2019, + "num_params": 5483032, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "qnnpack", + "quantization": "qat", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3", + "unquantized": MobileNet_V3_Large_Weights.IMAGENET1K_V1, + "acc@1": 73.004, + "acc@5": 90.858, + }, + ) + DEFAULT = IMAGENET1K_QNNPACK_V1 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1 + if kwargs.get("quantize", False) + else MobileNet_V3_Large_Weights.IMAGENET1K_V1, + ) +) def mobilenet_v3_large( - pretrained: bool = False, + *, + weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -172,10 +196,12 @@ def mobilenet_v3_large( GPU inference is not yet supported Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet. 
- progress (bool): If True, displays a progress bar of the download to stderr - quantize (bool): If True, returns a quantized model, else returns a float model + pretrained (MobileNet_V3_Large_QuantizedWeights or MobileNet_V3_Large_Weights, optional): The pretrained + weights for the model + progress (bool): If True, displays a progress bar of the download to stderr + quantize (bool): If True, returns a quantized model, else returns a float model """ - arch = "mobilenet_v3_large" - inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) - return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, quantize, **kwargs) + weights = (MobileNet_V3_Large_QuantizedWeights if quantize else MobileNet_V3_Large_Weights).verify(weights) + + inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) + return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs) diff --git a/torchvision/prototype/models/quantization/mobilenetv3.py b/torchvision/prototype/models/quantization/mobilenetv3.py deleted file mode 100644 index b2a2422cabd..00000000000 --- a/torchvision/prototype/models/quantization/mobilenetv3.py +++ /dev/null @@ -1,99 +0,0 @@ -from functools import partial -from typing import Any, List, Optional, Union - -import torch -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param -from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, _mobilenet_v3_conf -from torchvision.models.quantization.mobilenetv3 import ( - InvertedResidualConfig, - QuantizableInvertedResidual, - QuantizableMobileNetV3, - _replace_relu, -) -from torchvision.transforms import ImageClassificationEval, InterpolationMode - - -__all__ = [ - "QuantizableMobileNetV3", - "MobileNet_V3_Large_QuantizedWeights", - "mobilenet_v3_large", -] - - -def _mobilenet_v3_model( - inverted_residual_setting: List[InvertedResidualConfig], - last_channel: int, - weights: Optional[WeightsEnum], - progress: bool, - quantize: bool, - **kwargs: Any, -) -> QuantizableMobileNetV3: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - if "backend" in weights.meta: - _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) - backend = kwargs.pop("backend", "qnnpack") - - model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs) - _replace_relu(model) - - if quantize: - model.fuse_model(is_qat=True) - model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend) - torch.ao.quantization.prepare_qat(model, inplace=True) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - if quantize: - torch.ao.quantization.convert(model, inplace=True) - model.eval() - - return model - - -class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): - IMAGENET1K_QNNPACK_V1 = Weights( - url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - "task": "image_classification", - "architecture": "MobileNetV3", - "publication_year": 2019, - "num_params": 5483032, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "backend": "qnnpack", - 
"quantization": "qat", - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3", - "unquantized": MobileNet_V3_Large_Weights.IMAGENET1K_V1, - "acc@1": 73.004, - "acc@5": 90.858, - }, - ) - DEFAULT = IMAGENET1K_QNNPACK_V1 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1 - if kwargs.get("quantize", False) - else MobileNet_V3_Large_Weights.IMAGENET1K_V1, - ) -) -def mobilenet_v3_large( - *, - weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableMobileNetV3: - weights = (MobileNet_V3_Large_QuantizedWeights if quantize else MobileNet_V3_Large_Weights).verify(weights) - - inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) - return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs) From 3db3e77972c9e6534ec60b770f9934ea9210e2ad Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 18:27:01 +0000 Subject: [PATCH 25/45] porting resnet --- torchvision/models/quantization/resnet.py | 200 +++++++++++++---- .../prototype/models/quantization/resnet.py | 202 ------------------ 2 files changed, 158 insertions(+), 244 deletions(-) delete mode 100644 torchvision/prototype/models/quantization/resnet.py diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index 874ee75ba9c..2894ee278e4 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -1,28 +1,27 @@ +from functools import partial from typing import Any, Type, Union, List, Optional import torch import torch.nn as nn from torch import Tensor -from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet +from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights -from ..._internally_replaced_utils import load_state_dict_from_url +from ...transforms import ImageClassificationEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _IMAGENET_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param from .utils import _fuse_modules, _replace_relu, quantize_model -__all__ = ["QuantizableResNet", "resnet18", "resnet50", "resnext101_32x8d"] - -model_urls = { - "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth", - "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth", - "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", -} - - -quant_model_urls = { - "resnet18_fbgemm": "https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", - "resnet50_fbgemm": "https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", - "resnext101_32x8d_fbgemm": "https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", -} +__all__ = [ + "QuantizableResNet", + "ResNet18_QuantizedWeights", + "ResNet50_QuantizedWeights", + "ResNeXt101_32X8D_QuantizedWeights", + "resnet18", + "resnet50", + "resnext101_32x8d", +] class QuantizableBasicBlock(BasicBlock): @@ -116,38 +115,129 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: def _resnet( - arch: str, block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]], layers: List[int], - pretrained: bool, + weights: 
Optional[WeightsEnum], progress: bool, quantize: bool, **kwargs: Any, ) -> QuantizableResNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") model = QuantizableResNet(block, layers, **kwargs) _replace_relu(model) if quantize: - # TODO use pretrained as a string to specify the backend - backend = "fbgemm" quantize_model(model, backend) - else: - assert pretrained in [True, False] - - if pretrained: - if quantize: - model_url = quant_model_urls[arch + "_" + backend] - else: - model_url = model_urls[arch] - state_dict = load_state_dict_from_url(model_url, progress=progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) - model.load_state_dict(state_dict) return model +_COMMON_META = { + "task": "image_classification", + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "fbgemm", + "quantization": "ptq", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", +} + +class ResNet18_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 11689512, + "unquantized": ResNet18_Weights.IMAGENET1K_V1, + "acc@1": 69.494, + "acc@5": 88.882, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +class ResNet50_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 25557032, + "unquantized": ResNet50_Weights.IMAGENET1K_V1, + "acc@1": 75.920, + "acc@5": 92.814, + }, + ) + IMAGENET1K_FBGEMM_V2 = Weights( + url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNet", + "publication_year": 2015, + "num_params": 25557032, + "unquantized": ResNet50_Weights.IMAGENET1K_V2, + "acc@1": 80.282, + "acc@5": 94.976, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V2 + + +class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "architecture": "ResNeXt", + "publication_year": 2016, + "num_params": 88791336, + "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1, + "acc@1": 78.986, + "acc@5": 94.480, + }, + ) + IMAGENET1K_FBGEMM_V2 = Weights( + url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", + transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "architecture": "ResNeXt", + "publication_year": 2016, + "num_params": 88791336, + "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2, + "acc@1": 82.574, + "acc@5": 96.132, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V2 + + +@handle_legacy_interface( + 
weights=( + "pretrained", + lambda kwargs: ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNet18_Weights.IMAGENET1K_V1, + ) +) def resnet18( - pretrained: bool = False, + *, + weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -156,33 +246,56 @@ def resnet18( `"Deep Residual Learning for Image Recognition" `_ Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + pretrained (ResNet18_QuantizedWeights or ResNet18_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - return _resnet("resnet18", QuantizableBasicBlock, [2, 2, 2, 2], pretrained, progress, quantize, **kwargs) + weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights) + + return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs) +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNet50_Weights.IMAGENET1K_V1, + ) +) def resnet50( - pretrained: bool = False, + *, + weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, ) -> QuantizableResNet: - r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_ Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + pretrained (ResNet50_QuantizedWeights or ResNet50_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - return _resnet("resnet50", QuantizableBottleneck, [3, 4, 6, 3], pretrained, progress, quantize, **kwargs) + weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights) + + return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs) +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ResNeXt101_32X8D_Weights.IMAGENET1K_V1, + ) +) def resnext101_32x8d( - pretrained: bool = False, + *, + weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -191,10 +304,13 @@ def resnext101_32x8d( `"Aggregated Residual Transformation for Deep Neural Networks" `_ Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + pretrained (ResNeXt101_32X8D_QuantizedWeights or ResNeXt101_32X8D_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - kwargs["groups"] = 32 - kwargs["width_per_group"] = 8 - return _resnet("resnext101_32x8d", QuantizableBottleneck, [3, 4, 23, 3], pretrained, progress, quantize, **kwargs) + weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights) + + _ovewrite_named_param(kwargs, "groups", 32) + _ovewrite_named_param(kwargs, "width_per_group", 8) + return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs) 
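For reference, a minimal usage sketch of the quantized ResNet builders after this port (the builder and enum names are taken from the hunks above; it assumes the series is applied and that checkpoints are fetched on first use):

    from torchvision.models.quantization import resnet50
    from torchvision.models.quantization.resnet import ResNet50_QuantizedWeights

    # New interface: pick the quantized weights enum explicitly and request quantization.
    model = resnet50(weights=ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2, quantize=True)

    # Legacy interface: still accepted; the handle_legacy_interface decorator above maps
    # pretrained=True to the IMAGENET1K_FBGEMM_V1 weights when quantize=True.
    legacy_model = resnet50(pretrained=True, quantize=True)
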
diff --git a/torchvision/prototype/models/quantization/resnet.py b/torchvision/prototype/models/quantization/resnet.py deleted file mode 100644 index 289658aecda..00000000000 --- a/torchvision/prototype/models/quantization/resnet.py +++ /dev/null @@ -1,202 +0,0 @@ -from functools import partial -from typing import Any, List, Optional, Type, Union - -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param -from torchvision.models.quantization.resnet import ( - QuantizableBasicBlock, - QuantizableBottleneck, - QuantizableResNet, - _replace_relu, - quantize_model, -) -from torchvision.models.resnet import ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights -from torchvision.transforms import ImageClassificationEval, InterpolationMode - - -__all__ = [ - "QuantizableResNet", - "ResNet18_QuantizedWeights", - "ResNet50_QuantizedWeights", - "ResNeXt101_32X8D_QuantizedWeights", - "resnet18", - "resnet50", - "resnext101_32x8d", -] - - -def _resnet( - block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]], - layers: List[int], - weights: Optional[WeightsEnum], - progress: bool, - quantize: bool, - **kwargs: Any, -) -> QuantizableResNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - if "backend" in weights.meta: - _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) - backend = kwargs.pop("backend", "fbgemm") - - model = QuantizableResNet(block, layers, **kwargs) - _replace_relu(model) - if quantize: - quantize_model(model, backend) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "backend": "fbgemm", - "quantization": "ptq", - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", -} - - -class ResNet18_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 11689512, - "unquantized": ResNet18_Weights.IMAGENET1K_V1, - "acc@1": 69.494, - "acc@5": 88.882, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V1 - - -class ResNet50_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 25557032, - "unquantized": ResNet50_Weights.IMAGENET1K_V1, - "acc@1": 75.920, - "acc@5": 92.814, - }, - ) - IMAGENET1K_FBGEMM_V2 = Weights( - url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNet", - "publication_year": 2015, - "num_params": 25557032, - "unquantized": ResNet50_Weights.IMAGENET1K_V2, - "acc@1": 80.282, - "acc@5": 94.976, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V2 - - -class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): - 
IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "architecture": "ResNeXt", - "publication_year": 2016, - "num_params": 88791336, - "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V1, - "acc@1": 78.986, - "acc@5": 94.480, - }, - ) - IMAGENET1K_FBGEMM_V2 = Weights( - url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), - meta={ - **_COMMON_META, - "architecture": "ResNeXt", - "publication_year": 2016, - "num_params": 88791336, - "unquantized": ResNeXt101_32X8D_Weights.IMAGENET1K_V2, - "acc@1": 82.574, - "acc@5": 96.132, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V2 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else ResNet18_Weights.IMAGENET1K_V1, - ) -) -def resnet18( - *, - weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableResNet: - weights = (ResNet18_QuantizedWeights if quantize else ResNet18_Weights).verify(weights) - - return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs) - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else ResNet50_Weights.IMAGENET1K_V1, - ) -) -def resnet50( - *, - weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableResNet: - weights = (ResNet50_QuantizedWeights if quantize else ResNet50_Weights).verify(weights) - - return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs) - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else ResNeXt101_32X8D_Weights.IMAGENET1K_V1, - ) -) -def resnext101_32x8d( - *, - weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableResNet: - weights = (ResNeXt101_32X8D_QuantizedWeights if quantize else ResNeXt101_32X8D_Weights).verify(weights) - - _ovewrite_named_param(kwargs, "groups", 32) - _ovewrite_named_param(kwargs, "width_per_group", 8) - return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs) From 829c5c1439bc726f3476c11056c5e0b4e6e7bfa4 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 18:30:47 +0000 Subject: [PATCH 26/45] porting shufflenetv2 --- .../models/quantization/shufflenetv2.py | 132 ++++++++++++----- .../models/quantization/shufflenetv2.py | 134 ------------------ 2 files changed, 94 insertions(+), 172 deletions(-) delete mode 100644 torchvision/prototype/models/quantization/shufflenetv2.py diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index e196006a9c3..5a12aa23207 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -1,32 +1,28 @@ -from typing import Any, Optional +from functools import partial +from typing import 
Any, List, Optional, Union import torch import torch.nn as nn from torch import Tensor from torchvision.models import shufflenetv2 -from ..._internally_replaced_utils import load_state_dict_from_url +from ...transforms import ImageClassificationEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _IMAGENET_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param +from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights from .utils import _fuse_modules, _replace_relu, quantize_model + __all__ = [ "QuantizableShuffleNetV2", + "ShuffleNet_V2_X0_5_QuantizedWeights", + "ShuffleNet_V2_X1_0_QuantizedWeights", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", ] -model_urls = { - "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", - "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", -} - - -quant_model_urls = { - "shufflenetv2_x0.5_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", - "shufflenetv2_x1.0_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", -} - - class QuantizableInvertedResidual(shufflenetv2.InvertedResidual): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -80,39 +76,86 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: def _shufflenetv2( - arch: str, - pretrained: bool, + stages_repeats: List[int], + stages_out_channels: List[int], + *, + weights: Optional[WeightsEnum], progress: bool, quantize: bool, - *args: Any, **kwargs: Any, ) -> QuantizableShuffleNetV2: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if "backend" in weights.meta: + _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) + backend = kwargs.pop("backend", "fbgemm") - model = QuantizableShuffleNetV2(*args, **kwargs) + model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs) _replace_relu(model) - if quantize: - # TODO use pretrained as a string to specify the backend - backend = "fbgemm" quantize_model(model, backend) - else: - assert pretrained in [True, False] - - if pretrained: - model_url: Optional[str] = None - if quantize: - model_url = quant_model_urls[arch + "_" + backend] - else: - model_url = model_urls[arch] - state_dict = load_state_dict_from_url(model_url, progress=progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) - model.load_state_dict(state_dict) return model +_COMMON_META = { + "task": "image_classification", + "architecture": "ShuffleNetV2", + "publication_year": 2018, + "size": (224, 224), + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "backend": "fbgemm", + "quantization": "ptq", + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", +} + + +class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): + IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 1366792, + "unquantized": ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1, + "acc@1": 57.972, + "acc@5": 79.780, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): + 
IMAGENET1K_FBGEMM_V1 = Weights( + url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", + transforms=partial(ImageClassificationEval, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 2278604, + "unquantized": ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1, + "acc@1": 68.360, + "acc@5": 87.582, + }, + ) + DEFAULT = IMAGENET1K_FBGEMM_V1 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1, + ) +) def shufflenet_v2_x0_5( - pretrained: bool = False, + *, + weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -123,17 +166,28 @@ def shufflenet_v2_x0_5( `_. Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + pretrained (ShuffleNet_V2_X0_5_QuantizedWeights or ShuffleNet_V2_X0_5_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ + weights = (ShuffleNet_V2_X0_5_QuantizedWeights if quantize else ShuffleNet_V2_X0_5_Weights).verify(weights) return _shufflenetv2( - "shufflenetv2_x0.5", pretrained, progress, quantize, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs + [4, 8, 4], [24, 48, 96, 192, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs ) +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1 + if kwargs.get("quantize", False) + else ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1, + ) +) def shufflenet_v2_x1_0( - pretrained: bool = False, + *, + weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -144,10 +198,12 @@ def shufflenet_v2_x1_0( `_. 
Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet + pretrained (ShuffleNet_V2_X1_0_QuantizedWeights or ShuffleNet_V2_X1_0_Weights, optional): The pretrained + weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ + weights = (ShuffleNet_V2_X1_0_QuantizedWeights if quantize else ShuffleNet_V2_X1_0_Weights).verify(weights) return _shufflenetv2( - "shufflenetv2_x1.0", pretrained, progress, quantize, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs + [4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs ) diff --git a/torchvision/prototype/models/quantization/shufflenetv2.py b/torchvision/prototype/models/quantization/shufflenetv2.py deleted file mode 100644 index 3b1f41affdb..00000000000 --- a/torchvision/prototype/models/quantization/shufflenetv2.py +++ /dev/null @@ -1,134 +0,0 @@ -from functools import partial -from typing import Any, List, Optional, Union - -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _IMAGENET_CATEGORIES -from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param -from torchvision.models.quantization.shufflenetv2 import ( - QuantizableShuffleNetV2, - _replace_relu, - quantize_model, -) -from torchvision.models.shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights -from torchvision.transforms import ImageClassificationEval, InterpolationMode - - -__all__ = [ - "QuantizableShuffleNetV2", - "ShuffleNet_V2_X0_5_QuantizedWeights", - "ShuffleNet_V2_X1_0_QuantizedWeights", - "shufflenet_v2_x0_5", - "shufflenet_v2_x1_0", -] - - -def _shufflenetv2( - stages_repeats: List[int], - stages_out_channels: List[int], - *, - weights: Optional[WeightsEnum], - progress: bool, - quantize: bool, - **kwargs: Any, -) -> QuantizableShuffleNetV2: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - if "backend" in weights.meta: - _ovewrite_named_param(kwargs, "backend", weights.meta["backend"]) - backend = kwargs.pop("backend", "fbgemm") - - model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs) - _replace_relu(model) - if quantize: - quantize_model(model, backend) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "image_classification", - "architecture": "ShuffleNetV2", - "publication_year": 2018, - "size": (224, 224), - "min_size": (1, 1), - "categories": _IMAGENET_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "backend": "fbgemm", - "quantization": "ptq", - "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", -} - - -class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", - transforms=partial(ImageClassificationEval, crop_size=224), - meta={ - **_COMMON_META, - "num_params": 1366792, - "unquantized": ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1, - "acc@1": 57.972, - "acc@5": 79.780, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V1 - - -class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): - IMAGENET1K_FBGEMM_V1 = Weights( - url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", - transforms=partial(ImageClassificationEval, 
crop_size=224), - meta={ - **_COMMON_META, - "num_params": 2278604, - "unquantized": ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1, - "acc@1": 68.360, - "acc@5": 87.582, - }, - ) - DEFAULT = IMAGENET1K_FBGEMM_V1 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1, - ) -) -def shufflenet_v2_x0_5( - *, - weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableShuffleNetV2: - weights = (ShuffleNet_V2_X0_5_QuantizedWeights if quantize else ShuffleNet_V2_X0_5_Weights).verify(weights) - return _shufflenetv2( - [4, 8, 4], [24, 48, 96, 192, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs - ) - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1 - if kwargs.get("quantize", False) - else ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1, - ) -) -def shufflenet_v2_x1_0( - *, - weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None, - progress: bool = True, - quantize: bool = False, - **kwargs: Any, -) -> QuantizableShuffleNetV2: - weights = (ShuffleNet_V2_X1_0_QuantizedWeights if quantize else ShuffleNet_V2_X1_0_Weights).verify(weights) - return _shufflenetv2( - [4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs - ) From 469eadd06d6214e1656e898ede08cf95f8a2f132 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 18:39:33 +0000 Subject: [PATCH 27/45] Fix test and linter --- test/test_prototype_models.py | 17 +++++------------ torchvision/models/quantization/mobilenetv3.py | 8 +++++++- torchvision/models/quantization/resnet.py | 10 +++++++++- 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 27a3598a5f3..4dde47d40c3 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -59,11 +59,11 @@ def _build_model(fn, **kwargs): ("ResNet50_Weights.DEFAULT", torchvision.models.ResNet50_Weights.IMAGENET1K_V2), ( "ResNet50_QuantizedWeights.DEFAULT", - models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2, + torchvision.models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2, ), ( "ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1", - models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1, + torchvision.models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1, ), ], ) @@ -73,9 +73,9 @@ def test_get_weight(name, weight): @pytest.mark.parametrize( "model_fn", - TM.get_models_from_module(models) + TM.get_models_from_module(torchvision.models) + TM.get_models_from_module(models.detection) - + TM.get_models_from_module(models.quantization) + + TM.get_models_from_module(torchvision.models.quantization) + TM.get_models_from_module(models.segmentation) + TM.get_models_from_module(models.video) + TM.get_models_from_module(models.optical_flow), @@ -91,7 +91,7 @@ def test_naming_conventions(model_fn): "model_fn", TM.get_models_from_module(torchvision.models) + TM.get_models_from_module(models.detection) - + TM.get_models_from_module(models.quantization) + + TM.get_models_from_module(torchvision.models.quantization) + TM.get_models_from_module(models.segmentation) + 
TM.get_models_from_module(models.video) + TM.get_models_from_module(models.optical_flow), @@ -150,12 +150,6 @@ def test_detection_model(model_fn, dev): TM.test_detection_model(model_fn, dev) -@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.quantization)) -@run_if_test_with_prototype -def test_quantized_classification_model(model_fn): - TM.test_quantized_classification_model(model_fn) - - @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation)) @pytest.mark.parametrize("dev", cpu_and_gpu()) @run_if_test_with_prototype @@ -181,7 +175,6 @@ def test_raft(model_builder, scripted): @pytest.mark.parametrize( "model_fn", TM.get_models_from_module(models.detection) - + TM.get_models_from_module(models.quantization) + TM.get_models_from_module(models.segmentation) + TM.get_models_from_module(models.video) + TM.get_models_from_module(models.optical_flow), diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index eafff9ec8f3..4069683970c 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -10,7 +10,13 @@ from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param -from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, _mobilenet_v3_conf, MobileNet_V3_Large_Weights +from ..mobilenetv3 import ( + InvertedResidual, + InvertedResidualConfig, + MobileNetV3, + _mobilenet_v3_conf, + MobileNet_V3_Large_Weights, +) from .utils import _fuse_modules, _replace_relu diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index 2894ee278e4..7da8ef05c66 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -4,7 +4,14 @@ import torch import torch.nn as nn from torch import Tensor -from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights +from torchvision.models.resnet import ( + Bottleneck, + BasicBlock, + ResNet, + ResNet18_Weights, + ResNet50_Weights, + ResNeXt101_32X8D_Weights, +) from ...transforms import ImageClassificationEval, InterpolationMode from .._api import WeightsEnum, Weights @@ -138,6 +145,7 @@ def _resnet( return model + _COMMON_META = { "task": "image_classification", "size": (224, 224), From 66d7642962232f0689bd46e234e5e43f5382d2b2 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 14 Mar 2022 18:53:03 +0000 Subject: [PATCH 28/45] Fixing docs. 
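The docstring fixes below track the interface change made in the earlier patches: the documented argument is now ``weights`` rather than the removed ``pretrained`` flag. A minimal usage sketch of that interface, assuming the quantized ShuffleNet builder and weight enum ported above are re-exported from ``torchvision.models.quantization`` (illustrative only, not part of the patch):

    from torchvision.models.quantization import (
        shufflenet_v2_x0_5,
        ShuffleNet_V2_X0_5_QuantizedWeights,
    )

    # New interface: pick an explicit weights enum member and request quantization.
    weights = ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1
    model = shufflenet_v2_x0_5(weights=weights, quantize=True).eval()

    # The evaluation preset and the metadata travel with the weights.
    preprocess = weights.transforms()   # ImageClassificationEval(crop_size=224)
    print(weights.meta["acc@1"], weights.meta["backend"])

    # Legacy calls still work: handle_legacy_interface maps pretrained=True to the
    # default weights for the requested mode and emits a deprecation warning.
    legacy = shufflenet_v2_x0_5(pretrained=True, quantize=True)
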
--- torchvision/models/quantization/googlenet.py | 2 +- torchvision/models/quantization/inception.py | 2 +- torchvision/models/quantization/mobilenetv2.py | 2 +- torchvision/models/quantization/mobilenetv3.py | 2 +- torchvision/models/quantization/resnet.py | 6 +++--- torchvision/models/quantization/shufflenetv2.py | 4 ++-- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index cbc19977054..befc2299c06 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -153,7 +153,7 @@ def googlenet( GPU inference is not yet supported Args: - pretrained (GoogLeNet_QuantizedWeights or GoogLeNet_Weights, optional): The pretrained + weights (GoogLeNet_QuantizedWeights or GoogLeNet_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index d3deb489a64..697d99d4027 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -223,7 +223,7 @@ def inception_v3( GPU inference is not yet supported Args: - pretrained (Inception_V3_QuantizedWeights or Inception_V3_Weights, optional): The pretrained + weights (Inception_V3_QuantizedWeights or Inception_V3_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index a260b13fd8d..40f5cb544fd 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -113,7 +113,7 @@ def mobilenet_v2( GPU inference is not yet supported Args: - pretrained (GoogLeNet_QuantizedWeights or GoogLeNet_Weights, optional): The pretrained + weights (GoogLeNet_QuantizedWeights or GoogLeNet_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize(bool): If True, returns a quantized model, else returns a float model diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 4069683970c..4b79b7f26ae 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -202,7 +202,7 @@ def mobilenet_v3_large( GPU inference is not yet supported Args: - pretrained (MobileNet_V3_Large_QuantizedWeights or MobileNet_V3_Large_Weights, optional): The pretrained + weights (MobileNet_V3_Large_QuantizedWeights or MobileNet_V3_Large_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, returns a quantized model, else returns a float model diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index 7da8ef05c66..666b1b23163 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -254,7 +254,7 @@ def resnet18( `"Deep Residual Learning for Image Recognition" `_ Args: - pretrained (ResNet18_QuantizedWeights or ResNet18_Weights, optional): The pretrained + weights 
(ResNet18_QuantizedWeights or ResNet18_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model @@ -283,7 +283,7 @@ def resnet50( `"Deep Residual Learning for Image Recognition" `_ Args: - pretrained (ResNet50_QuantizedWeights or ResNet50_Weights, optional): The pretrained + weights (ResNet50_QuantizedWeights or ResNet50_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model @@ -312,7 +312,7 @@ def resnext101_32x8d( `"Aggregated Residual Transformation for Deep Neural Networks" `_ Args: - pretrained (ResNeXt101_32X8D_QuantizedWeights or ResNeXt101_32X8D_Weights, optional): The pretrained + weights (ResNeXt101_32X8D_QuantizedWeights or ResNeXt101_32X8D_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index 5a12aa23207..c5bfe698636 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -166,7 +166,7 @@ def shufflenet_v2_x0_5( `_. Args: - pretrained (ShuffleNet_V2_X0_5_QuantizedWeights or ShuffleNet_V2_X0_5_Weights, optional): The pretrained + weights (ShuffleNet_V2_X0_5_QuantizedWeights or ShuffleNet_V2_X0_5_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model @@ -198,7 +198,7 @@ def shufflenet_v2_x1_0( `_. 
Args: - pretrained (ShuffleNet_V2_X1_0_QuantizedWeights or ShuffleNet_V2_X1_0_Weights, optional): The pretrained + weights (ShuffleNet_V2_X1_0_QuantizedWeights or ShuffleNet_V2_X1_0_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model From a39e60ed2f3e27a14ea05afef1bbbc5554fc80c2 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 15 Mar 2022 12:59:59 +0000 Subject: [PATCH 29/45] Porting Detection models (#5617) * fix inits * fix docs * Port faster_rcnn * Port fcos * Port keypoint_rcnn * Port mask_rcnn * Port retinanet * Port ssd * Port ssdlite * Fix linter * Fixing tests * Fixing tests * Fixing vgg test --- test/test_prototype_models.py | 14 +- torchvision/models/detection/__init__.py | 4 +- .../models/detection/backbone_utils.py | 2 +- torchvision/models/detection/faster_rcnn.py | 221 ++++++++++++----- torchvision/models/detection/fcos.py | 79 ++++-- torchvision/models/detection/keypoint_rcnn.py | 122 +++++++--- torchvision/models/detection/mask_rcnn.py | 85 +++++-- torchvision/models/detection/retinanet.py | 83 +++++-- torchvision/models/detection/roi_heads.py | 3 +- torchvision/models/detection/ssd.py | 99 ++++---- torchvision/models/detection/ssdlite.py | 92 ++++--- torchvision/models/vgg.py | 3 +- torchvision/prototype/models/__init__.py | 1 - .../prototype/models/detection/__init__.py | 7 - .../prototype/models/detection/faster_rcnn.py | 226 ------------------ .../prototype/models/detection/fcos.py | 78 ------ .../models/detection/keypoint_rcnn.py | 106 -------- .../prototype/models/detection/mask_rcnn.py | 79 ------ .../prototype/models/detection/retinanet.py | 82 ------- torchvision/prototype/models/detection/ssd.py | 91 ------- .../prototype/models/detection/ssdlite.py | 124 ---------- 21 files changed, 550 insertions(+), 1051 deletions(-) delete mode 100644 torchvision/prototype/models/detection/__init__.py delete mode 100644 torchvision/prototype/models/detection/faster_rcnn.py delete mode 100644 torchvision/prototype/models/detection/fcos.py delete mode 100644 torchvision/prototype/models/detection/keypoint_rcnn.py delete mode 100644 torchvision/prototype/models/detection/mask_rcnn.py delete mode 100644 torchvision/prototype/models/detection/retinanet.py delete mode 100644 torchvision/prototype/models/detection/ssd.py delete mode 100644 torchvision/prototype/models/detection/ssdlite.py diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 4dde47d40c3..06baea35fa8 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -74,7 +74,7 @@ def test_get_weight(name, weight): @pytest.mark.parametrize( "model_fn", TM.get_models_from_module(torchvision.models) - + TM.get_models_from_module(models.detection) + + TM.get_models_from_module(torchvision.models.detection) + TM.get_models_from_module(torchvision.models.quantization) + TM.get_models_from_module(models.segmentation) + TM.get_models_from_module(models.video) @@ -90,7 +90,7 @@ def test_naming_conventions(model_fn): @pytest.mark.parametrize( "model_fn", TM.get_models_from_module(torchvision.models) - + TM.get_models_from_module(models.detection) + + TM.get_models_from_module(torchvision.models.detection) + TM.get_models_from_module(torchvision.models.quantization) + TM.get_models_from_module(models.segmentation) + TM.get_models_from_module(models.video) @@ -143,13 +143,6 @@ def test_schema_meta_validation(model_fn): assert not 
bad_names -@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.detection)) -@pytest.mark.parametrize("dev", cpu_and_gpu()) -@run_if_test_with_prototype -def test_detection_model(model_fn, dev): - TM.test_detection_model(model_fn, dev) - - @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation)) @pytest.mark.parametrize("dev", cpu_and_gpu()) @run_if_test_with_prototype @@ -174,8 +167,7 @@ def test_raft(model_builder, scripted): @pytest.mark.parametrize( "model_fn", - TM.get_models_from_module(models.detection) - + TM.get_models_from_module(models.segmentation) + TM.get_models_from_module(models.segmentation) + TM.get_models_from_module(models.video) + TM.get_models_from_module(models.optical_flow), ) diff --git a/torchvision/models/detection/__init__.py b/torchvision/models/detection/__init__.py index be46f950a61..4146651c737 100644 --- a/torchvision/models/detection/__init__.py +++ b/torchvision/models/detection/__init__.py @@ -1,7 +1,7 @@ from .faster_rcnn import * -from .mask_rcnn import * +from .fcos import * from .keypoint_rcnn import * +from .mask_rcnn import * from .retinanet import * from .ssd import * from .ssdlite import * -from .fcos import * diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index 5ac5f179479..cac96b61f64 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -88,7 +88,7 @@ def resnet_fpn_backbone( pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet norm_layer (callable): it is recommended to use the default value. For details visit: (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267) - trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block. + trainable_layers (int): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``. By default all layers are returned. 
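The remaining hunks in this patch port the detection builders to the same pattern: a per-model weights enum plus a ``weights``/``weights_backbone`` pair on the constructor. A hedged sketch of the call patterns enabled by the Faster R-CNN builder defined just below (names are taken from the following hunks; treat this as an illustration rather than the reference training or evaluation recipe):

    from torchvision.models.resnet import ResNet50_Weights
    from torchvision.models.detection import (
        fasterrcnn_resnet50_fpn,
        FasterRCNN_ResNet50_FPN_Weights,
    )

    # Fully pretrained detector: num_classes is derived from weights.meta["categories"]
    # and weights_backbone is dropped because the checkpoint already contains it.
    model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1).eval()

    # Training setup: no detector weights, ImageNet backbone, explicit class count.
    trainable = fasterrcnn_resnet50_fpn(
        weights=None,
        weights_backbone=ResNet50_Weights.IMAGENET1K_V1,
        num_classes=91,
        trainable_backbone_layers=3,
    )

    # Legacy interface: still accepted via handle_legacy_interface, but deprecated.
    legacy = fasterrcnn_resnet50_fpn(pretrained=True)
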
diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index 790740fe9c5..18872adc029 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -1,11 +1,16 @@ +from typing import Any, Optional, Union + import torch.nn.functional as F from torch import nn from torchvision.ops import MultiScaleRoIAlign -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import misc as misc_nn_ops -from ..mobilenetv3 import mobilenet_v3_large -from ..resnet import resnet50 +from ...transforms import ObjectDetectionEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from ..resnet import ResNet50_Weights, resnet50 from ._utils import overwrite_eps from .anchor_utils import AnchorGenerator from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, _mobilenet_extractor @@ -17,9 +22,12 @@ __all__ = [ "FasterRCNN", + "FasterRCNN_ResNet50_FPN_Weights", + "FasterRCNN_MobileNet_V3_Large_FPN_Weights", + "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights", "fasterrcnn_resnet50_fpn", - "fasterrcnn_mobilenet_v3_large_320_fpn", "fasterrcnn_mobilenet_v3_large_fpn", + "fasterrcnn_mobilenet_v3_large_320_fpn", ] @@ -307,16 +315,70 @@ def forward(self, x): return scores, bbox_deltas -model_urls = { - "fasterrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", - "fasterrcnn_mobilenet_v3_large_320_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", - "fasterrcnn_mobilenet_v3_large_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", +_COMMON_META = { + "task": "image_object_detection", + "architecture": "FasterRCNN", + "publication_year": 2015, + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, } +class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", + transforms=ObjectDetectionEval, + meta={ + **_COMMON_META, + "num_params": 41755286, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn", + "map": 37.0, + }, + ) + DEFAULT = COCO_V1 + + +class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", + transforms=ObjectDetectionEval, + meta={ + **_COMMON_META, + "num_params": 19386354, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn", + "map": 32.8, + }, + ) + DEFAULT = COCO_V1 + + +class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", + transforms=ObjectDetectionEval, + meta={ + **_COMMON_META, + "num_params": 19386354, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn", + "map": 22.8, + }, + ) + DEFAULT = COCO_V1 + + +@handle_legacy_interface( + weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def 
fasterrcnn_resnet50_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + *, + weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: """ Constructs a Faster R-CNN model with a ResNet-50-FPN backbone. @@ -375,51 +437,60 @@ def fasterrcnn_resnet50_fpn( >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (FasterRCNN_ResNet50_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. """ - is_trained = pretrained or pretrained_backbone + weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) - model = FasterRCNN(backbone, num_classes, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls["fasterrcnn_resnet50_fpn_coco"], progress=progress) - model.load_state_dict(state_dict) - overwrite_eps(model, 0.0) + model = FasterRCNN(backbone, num_classes=num_classes, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1: + overwrite_eps(model, 0.0) + return model def _fasterrcnn_mobilenet_v3_large_fpn( - weights_name, - pretrained=False, - progress=True, - num_classes=91, - pretrained_backbone=True, - trainable_backbone_layers=None, - **kwargs, -): - is_trained = pretrained or pretrained_backbone + *, + weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]], + progress: bool, + num_classes: Optional[int], + weights_backbone: Optional[MobileNet_V3_Large_Weights], + trainable_backbone_layers: 
Optional[int], + **kwargs: Any, +) -> FasterRCNN: + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - if pretrained: - pretrained_backbone = False - - backbone = mobilenet_v3_large(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) + backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers) - anchor_sizes = ( ( 32, @@ -430,21 +501,29 @@ def _fasterrcnn_mobilenet_v3_large_fpn( ), ) * 3 aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) - model = FasterRCNN( backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs ) - if pretrained: - if model_urls.get(weights_name, None) is None: - raise ValueError(f"No checkpoint is available for model {weights_name}") - state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model +@handle_legacy_interface( + weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) def fasterrcnn_mobilenet_v3_large_320_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + *, + weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: """ Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tunned for mobile use-cases. It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See @@ -459,15 +538,17 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( >>> predictions = model(x) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (FasterRCNN_MobileNet_V3_Large_320_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. 
""" - weights_name = "fasterrcnn_mobilenet_v3_large_320_fpn_coco" + weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + defaults = { "min_size": 320, "max_size": 640, @@ -478,19 +559,28 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( kwargs = {**defaults, **kwargs} return _fasterrcnn_mobilenet_v3_large_fpn( - weights_name, - pretrained=pretrained, + weights=weights, progress=progress, num_classes=num_classes, - pretrained_backbone=pretrained_backbone, + weights_backbone=weights_backbone, trainable_backbone_layers=trainable_backbone_layers, **kwargs, ) +@handle_legacy_interface( + weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) def fasterrcnn_mobilenet_v3_large_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + *, + weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: """ Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See @@ -505,26 +595,27 @@ def fasterrcnn_mobilenet_v3_large_fpn( >>> predictions = model(x) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (FasterRCNN_MobileNet_V3_Large_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. 
""" - weights_name = "fasterrcnn_mobilenet_v3_large_fpn_coco" + weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + defaults = { "rpn_score_thresh": 0.05, } kwargs = {**defaults, **kwargs} return _fasterrcnn_mobilenet_v3_large_fpn( - weights_name, - pretrained=pretrained, + weights=weights, progress=progress, num_classes=num_classes, - pretrained_backbone=pretrained_backbone, + weights_backbone=weights_backbone, trainable_backbone_layers=trainable_backbone_layers, **kwargs, ) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index c4c2e6f5842..8d110d809f7 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -2,25 +2,32 @@ import warnings from collections import OrderedDict from functools import partial -from typing import Callable, Dict, List, Tuple, Optional +from typing import Any, Callable, Dict, List, Tuple, Optional import torch from torch import nn, Tensor -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import sigmoid_focal_loss, generalized_box_iou_loss from ...ops import boxes as box_ops from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 +from ...transforms import ObjectDetectionEval, InterpolationMode from ...utils import _log_api_usage_once -from ..resnet import resnet50 +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..resnet import ResNet50_Weights, resnet50 from . import _utils as det_utils from .anchor_utils import AnchorGenerator from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .transform import GeneralizedRCNNTransform -__all__ = ["FCOS", "fcos_resnet50_fpn"] +__all__ = [ + "FCOS", + "FCOS_ResNet50_FPN_Weights", + "fcos_resnet50_fpn", +] class FCOSHead(nn.Module): @@ -626,19 +633,37 @@ def forward( return self.eager_outputs(losses, detections) -model_urls = { - "fcos_resnet50_fpn_coco": "https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", -} +class FCOS_ResNet50_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", + transforms=ObjectDetectionEval, + meta={ + "task": "image_object_detection", + "architecture": "FCOS", + "publication_year": 2019, + "num_params": 32269600, + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn", + "map": 39.2, + }, + ) + DEFAULT = COCO_V1 +@handle_legacy_interface( + weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def fcos_resnet50_fpn( - pretrained: bool = False, + *, + weights: Optional[FCOS_ResNet50_FPN_Weights] = None, progress: bool = True, - num_classes: int = 91, - pretrained_backbone: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = None, trainable_backbone_layers: Optional[int] = None, - **kwargs, -): + **kwargs: Any, +) -> FCOS: """ Constructs a FCOS model with a ResNet-50-FPN backbone. 
@@ -678,28 +703,34 @@ def fcos_resnet50_fpn( >>> predictions = model(x) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (FCOS_ResNet50_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone trainable_backbone_layers (int, optional): number of trainable (not frozen) resnet layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. Default: None """ - is_trained = pretrained or pretrained_backbone + weights = FCOS_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor( backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) ) model = FCOS(backbone, num_classes, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls["fcos_resnet50_fpn_coco"], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 9f23e66e0c5..3794b253ec7 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -1,16 +1,25 @@ +from typing import Any, Optional + import torch from torch import nn from torchvision.ops import MultiScaleRoIAlign -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import misc as misc_nn_ops -from ..resnet import resnet50 +from ...transforms import ObjectDetectionEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..resnet import ResNet50_Weights, resnet50 from ._utils import overwrite_eps from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .faster_rcnn import FasterRCNN -__all__ = ["KeypointRCNN", "keypointrcnn_resnet50_fpn"] +__all__ = [ + "KeypointRCNN", + "KeypointRCNN_ResNet50_FPN_Weights", + "keypointrcnn_resnet50_fpn", +] class KeypointRCNN(FasterRCNN): @@ -293,22 +302,61 @@ def forward(self, x): ) -model_urls = { - # legacy model for BC reasons, see 
https://github.com/pytorch/vision/issues/1606 - "keypointrcnn_resnet50_fpn_coco_legacy": "https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", - "keypointrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", +_COMMON_META = { + "task": "image_object_detection", + "architecture": "KeypointRCNN", + "publication_year": 2017, + "categories": _COCO_PERSON_CATEGORIES, + "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES, + "interpolation": InterpolationMode.BILINEAR, } +class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): + COCO_LEGACY = Weights( + url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", + transforms=ObjectDetectionEval, + meta={ + **_COMMON_META, + "num_params": 59137258, + "recipe": "https://github.com/pytorch/vision/issues/1606", + "map": 50.6, + "map_kp": 61.1, + }, + ) + COCO_V1 = Weights( + url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", + transforms=ObjectDetectionEval, + meta={ + **_COMMON_META, + "num_params": 59137258, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn", + "map": 54.6, + "map_kp": 65.0, + }, + ) + DEFAULT = COCO_V1 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY + if kwargs["pretrained"] == "legacy" + else KeypointRCNN_ResNet50_FPN_Weights.COCO_V1, + ), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def keypointrcnn_resnet50_fpn( - pretrained=False, - progress=True, - num_classes=2, - num_keypoints=17, - pretrained_backbone=True, - trainable_backbone_layers=None, - **kwargs, -): + *, + weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + num_keypoints: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> KeypointRCNN: """ Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone. @@ -356,31 +404,39 @@ def keypointrcnn_resnet50_fpn( >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (KeypointRCNN_ResNet50_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - num_keypoints (int): number of keypoints, default 17 - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + num_keypoints (int, optional): number of keypoints + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. 
""" - is_trained = pretrained or pretrained_backbone + weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + num_keypoints = _ovewrite_value_param(num_keypoints, len(weights.meta["keypoint_names"])) + else: + if num_classes is None: + num_classes = 2 + if num_keypoints is None: + num_keypoints = 17 + + is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) - if pretrained: - key = "keypointrcnn_resnet50_fpn_coco" - if pretrained == "legacy": - key += "_legacy" - state_dict = load_state_dict_from_url(model_urls[key], progress=progress) - model.load_state_dict(state_dict) - overwrite_eps(model, 0.0) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1: + overwrite_eps(model, 0.0) + return model diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 37f88116c5e..38ba82af01d 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -1,17 +1,23 @@ from collections import OrderedDict +from typing import Any, Optional from torch import nn from torchvision.ops import MultiScaleRoIAlign -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import misc as misc_nn_ops -from ..resnet import resnet50 +from ...transforms import ObjectDetectionEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..resnet import ResNet50_Weights, resnet50 from ._utils import overwrite_eps from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .faster_rcnn import FasterRCNN + __all__ = [ "MaskRCNN", + "MaskRCNN_ResNet50_FPN_Weights", "maskrcnn_resnet50_fpn", ] @@ -296,14 +302,38 @@ def __init__(self, in_channels, dim_reduced, num_classes): # nn.init.constant_(param, 0) -model_urls = { - "maskrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", -} - - +class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", + transforms=ObjectDetectionEval, + meta={ + "task": "image_object_detection", + "architecture": "MaskRCNN", + "publication_year": 2017, + "num_params": 44401393, + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn", + "map": 37.9, + "map_mask": 34.6, + }, + ) + DEFAULT = COCO_V1 + + +@handle_legacy_interface( + weights=("pretrained", 
MaskRCNN_ResNet50_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def maskrcnn_resnet50_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + *, + weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> MaskRCNN: """ Constructs a Mask R-CNN model with a ResNet-50-FPN backbone. @@ -352,27 +382,34 @@ def maskrcnn_resnet50_fpn( >>> torch.onnx.export(model, x, "mask_rcnn.onnx", opset_version = 11) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (MaskRCNN_ResNet50_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. """ - is_trained = pretrained or pretrained_backbone + weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) - model = MaskRCNN(backbone, num_classes, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls["maskrcnn_resnet50_fpn_coco"], progress=progress) - model.load_state_dict(state_dict) - overwrite_eps(model, 0.0) + model = MaskRCNN(backbone, num_classes=num_classes, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1: + overwrite_eps(model, 0.0) + return model diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 4f79b5ddbfc..b1c371583bf 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -1,18 +1,21 @@ import math import warnings from collections import OrderedDict -from typing import Dict, List, Tuple, Optional +from typing import Any, Dict, List, 
Tuple, Optional import torch from torch import nn, Tensor -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import sigmoid_focal_loss from ...ops import boxes as box_ops from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 +from ...transforms import ObjectDetectionEval, InterpolationMode from ...utils import _log_api_usage_once -from ..resnet import resnet50 +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..resnet import ResNet50_Weights, resnet50 from . import _utils as det_utils from ._utils import overwrite_eps from .anchor_utils import AnchorGenerator @@ -20,7 +23,11 @@ from .transform import GeneralizedRCNNTransform -__all__ = ["RetinaNet", "retinanet_resnet50_fpn"] +__all__ = [ + "RetinaNet", + "RetinaNet_ResNet50_FPN_Weights", + "retinanet_resnet50_fpn", +] def _sum(x: List[Tensor]) -> Tensor: @@ -571,14 +578,37 @@ def forward(self, images, targets=None): return self.eager_outputs(losses, detections) -model_urls = { - "retinanet_resnet50_fpn_coco": "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", -} +class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", + transforms=ObjectDetectionEval, + meta={ + "task": "image_object_detection", + "architecture": "RetinaNet", + "publication_year": 2017, + "num_params": 34014999, + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet", + "map": 36.4, + }, + ) + DEFAULT = COCO_V1 +@handle_legacy_interface( + weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def retinanet_resnet50_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + *, + weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> RetinaNet: """ Constructs a RetinaNet model with a ResNet-50-FPN backbone. @@ -618,30 +648,37 @@ def retinanet_resnet50_fpn( >>> predictions = model(x) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (RetinaNet_ResNet50_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. 
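For completeness, an end-to-end inference sketch using the RetinaNet weights defined above; the evaluation preset prepares the input tensor and the category names come from the weight metadata (illustrative only, with a dummy image standing in for real data):

    import torch
    from torchvision.models.detection import retinanet_resnet50_fpn, RetinaNet_ResNet50_FPN_Weights

    weights = RetinaNet_ResNet50_FPN_Weights.COCO_V1
    model = retinanet_resnet50_fpn(weights=weights).eval()
    preprocess = weights.transforms()       # ObjectDetectionEval preset

    img = torch.rand(3, 480, 640)           # stand-in for a real image in [0, 1]
    with torch.no_grad():
        prediction = model([preprocess(img)])[0]

    categories = weights.meta["categories"]
    names = [categories[i] for i in prediction["labels"]]
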
""" - is_trained = pretrained or pretrained_backbone + weights = RetinaNet_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) # skip P2 because it generates too many anchors (according to their paper) backbone = _resnet_fpn_extractor( backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) ) model = RetinaNet(backbone, num_classes, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls["retinanet_resnet50_fpn_coco"], progress=progress) - model.load_state_dict(state_dict) - overwrite_eps(model, 0.0) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1: + overwrite_eps(model, 0.0) + return model diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index b7bbb81111e..ab901449f51 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -4,8 +4,7 @@ import torch.nn.functional as F import torchvision from torch import nn, Tensor -from torchvision.ops import boxes as box_ops -from torchvision.ops import roi_align +from torchvision.ops import boxes as box_ops, roi_align from . import _utils as det_utils diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index 08a9ed68e4e..cf3becc5fc4 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -6,27 +6,42 @@ import torch.nn.functional as F from torch import nn, Tensor -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import boxes as box_ops +from ...transforms import ObjectDetectionEval, InterpolationMode from ...utils import _log_api_usage_once -from .. import vgg +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..vgg import VGG, VGG16_Weights, vgg16 from . import _utils as det_utils from .anchor_utils import DefaultBoxGenerator from .backbone_utils import _validate_trainable_layers from .transform import GeneralizedRCNNTransform -__all__ = ["SSD", "ssd300_vgg16"] -model_urls = { - "ssd300_vgg16_coco": "https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", -} - -backbone_urls = { - # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the - # same input standardization method as the paper. Ref: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth - # Only the `features` weights have proper values, those on the `classifier` module are filled with nans. 
- "vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth" -} +__all__ = [ + "SSD300_VGG16_Weights", + "ssd300_vgg16", +] + + +class SSD300_VGG16_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", + transforms=ObjectDetectionEval, + meta={ + "task": "image_object_detection", + "architecture": "SSD", + "publication_year": 2015, + "num_params": 35641826, + "size": (300, 300), + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16", + "map": 25.1, + }, + ) + DEFAULT = COCO_V1 def _xavier_init(conv: nn.Module): @@ -520,7 +535,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: return OrderedDict([(str(i), v) for i, v in enumerate(output)]) -def _vgg_extractor(backbone: vgg.VGG, highres: bool, trainable_layers: int): +def _vgg_extractor(backbone: VGG, highres: bool, trainable_layers: int): backbone = backbone.features # Gather the indices of maxpools. These are the locations of output blocks. stage_indices = [0] + [i for i, b in enumerate(backbone) if isinstance(b, nn.MaxPool2d)][:-1] @@ -537,14 +552,19 @@ def _vgg_extractor(backbone: vgg.VGG, highres: bool, trainable_layers: int): return SSDFeatureExtractorVGG(backbone, highres) +@handle_legacy_interface( + weights=("pretrained", SSD300_VGG16_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", VGG16_Weights.IMAGENET1K_FEATURES), +) def ssd300_vgg16( - pretrained: bool = False, + *, + weights: Optional[SSD300_VGG16_Weights] = None, progress: bool = True, - num_classes: int = 91, - pretrained_backbone: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[VGG16_Weights] = None, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, -): +) -> SSD: """Constructs an SSD model with input size 300x300 and a VGG16 backbone. Reference: `"SSD: Single Shot MultiBox Detector" `_. @@ -582,31 +602,32 @@ def ssd300_vgg16( >>> predictions = model(x) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (SSD300_VGG16_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (VGG16_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 4. 
""" + weights = SSD300_VGG16_Weights.verify(weights) + weights_backbone = VGG16_Weights.verify(weights_backbone) + if "size" in kwargs: - warnings.warn("The size of the model is already fixed; ignoring the argument.") + warnings.warn("The size of the model is already fixed; ignoring the parameter.") + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 5, 4 + weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 4 ) - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - # Use custom backbones more appropriate for SSD - backbone = vgg.vgg16(pretrained=False, progress=progress) - if pretrained_backbone: - state_dict = load_state_dict_from_url(backbone_urls["vgg16_features"], progress=progress) - backbone.load_state_dict(state_dict) - + backbone = vgg16(weights=weights_backbone, progress=progress) backbone = _vgg_extractor(backbone, False, trainable_backbone_layers) anchor_generator = DefaultBoxGenerator( [[2], [2, 3], [2, 3], [2, 3], [2], [2]], @@ -619,12 +640,10 @@ def ssd300_vgg16( "image_mean": [0.48235, 0.45882, 0.40784], "image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0], # undo the 0-1 scaling of toTensor } - kwargs = {**defaults, **kwargs} + kwargs: Any = {**defaults, **kwargs} model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs) - if pretrained: - weights_name = "ssd300_vgg16_coco" - if model_urls.get(weights_name, None) is None: - raise ValueError(f"No checkpoint is available for model {weights_name}") - state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 1ee59e069ea..a71da6b29ac 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -6,21 +6,24 @@ import torch from torch import nn, Tensor -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops.misc import Conv2dNormActivation +from ...transforms import ObjectDetectionEval, InterpolationMode from ...utils import _log_api_usage_once from .. import mobilenet +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large from . 
import _utils as det_utils from .anchor_utils import DefaultBoxGenerator from .backbone_utils import _validate_trainable_layers from .ssd import SSD, SSDScoringHead -__all__ = ["ssdlite320_mobilenet_v3_large"] - -model_urls = { - "ssdlite320_mobilenet_v3_large_coco": "https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth" -} +__all__ = [ + "SSDLite320_MobileNet_V3_Large_Weights", + "ssdlite320_mobilenet_v3_large", +] # Building blocks of SSDlite as described in section 6.2 of MobileNetV2 paper @@ -178,15 +181,39 @@ def _mobilenet_extractor( return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer) +class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth", + transforms=ObjectDetectionEval, + meta={ + "task": "image_object_detection", + "architecture": "SSDLite", + "publication_year": 2018, + "num_params": 3440060, + "size": (320, 320), + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large", + "map": 21.3, + }, + ) + DEFAULT = COCO_V1 + + +@handle_legacy_interface( + weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) def ssdlite320_mobilenet_v3_large( - pretrained: bool = False, + *, + weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None, progress: bool = True, - num_classes: int = 91, - pretrained_backbone: bool = False, + num_classes: Optional[int] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, trainable_backbone_layers: Optional[int] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, **kwargs: Any, -): +) -> SSD: """Constructs an SSDlite model with input size 320x320 and a MobileNetV3 Large backbone, as described at `"Searching for MobileNetV3" `_ and @@ -203,35 +230,41 @@ def ssdlite320_mobilenet_v3_large( >>> predictions = model(x) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (SSDLite320_MobileNet_V3_Large_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 6. norm_layer (callable, optional): Module specifying the normalization layer to use.
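+
+    Example (a small sketch of the new interface; ``COCO_V1`` is the checkpoint registered above and the
+    320x320 input only mirrors the size recorded in the weight meta-data):
+
+        >>> import torch
+        >>> from torchvision.models.detection import ssdlite320_mobilenet_v3_large, SSDLite320_MobileNet_V3_Large_Weights
+        >>> model = ssdlite320_mobilenet_v3_large(weights=SSDLite320_MobileNet_V3_Large_Weights.COCO_V1).eval()
+        >>> detections = model([torch.rand(3, 320, 320)])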
""" + weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + if "size" in kwargs: - warnings.warn("The size of the model is already fixed; ignoring the argument.") + warnings.warn("The size of the model is already fixed; ignoring the parameter.") + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6 + weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 6 ) - if pretrained: - pretrained_backbone = False - # Enable reduced tail if no pretrained backbone is selected. See Table 6 of MobileNetV3 paper. - reduce_tail = not pretrained_backbone + reduce_tail = weights_backbone is None if norm_layer is None: norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03) - backbone = mobilenet.mobilenet_v3_large( - pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs + backbone = mobilenet_v3_large( + weights=weights_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs ) - if not pretrained_backbone: + if weights_backbone is None: # Change the default initialization scheme if not pretrained _normal_init(backbone) backbone = _mobilenet_extractor( @@ -252,11 +285,11 @@ def ssdlite320_mobilenet_v3_large( "detections_per_img": 300, "topk_candidates": 300, # Rescale the input in a way compatible to the backbone: - # The following mean/std rescale the data from [0, 1] to [-1, 1] + # The following mean/std rescale the data from [0, 1] to [-1, -1] "image_mean": [0.5, 0.5, 0.5], "image_std": [0.5, 0.5, 0.5], } - kwargs = {**defaults, **kwargs} + kwargs: Any = {**defaults, **kwargs} model = SSD( backbone, anchor_generator, @@ -266,10 +299,7 @@ def ssdlite320_mobilenet_v3_large( **kwargs, ) - if pretrained: - weights_name = "ssdlite320_mobilenet_v3_large_coco" - if model_urls.get(weights_name, None) is None: - raise ValueError(f"No checkpoint is available for model {weights_name}") - state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) - model.load_state_dict(state_dict) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index 5393827b293..27325c9016c 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -97,7 +97,8 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG: if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if weights.meta["categories"] is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py index 5988c160aad..3d7baca6284 100644 --- a/torchvision/prototype/models/__init__.py +++ b/torchvision/prototype/models/__init__.py @@ -1,4 +1,3 @@ -from . import detection from . import optical_flow from . 
import segmentation from . import video diff --git a/torchvision/prototype/models/detection/__init__.py b/torchvision/prototype/models/detection/__init__.py deleted file mode 100644 index 4146651c737..00000000000 --- a/torchvision/prototype/models/detection/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .faster_rcnn import * -from .fcos import * -from .keypoint_rcnn import * -from .mask_rcnn import * -from .retinanet import * -from .ssd import * -from .ssdlite import * diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py deleted file mode 100644 index 5abc0eef1c4..00000000000 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ /dev/null @@ -1,226 +0,0 @@ -from typing import Any, Optional, Union - -from torch import nn -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _COCO_CATEGORIES -from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param -from torchvision.models.detection.faster_rcnn import ( - _mobilenet_extractor, - _resnet_fpn_extractor, - _validate_trainable_layers, - AnchorGenerator, - FasterRCNN, - misc_nn_ops, - overwrite_eps, -) -from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large -from torchvision.models.resnet import ResNet50_Weights, resnet50 -from torchvision.transforms import ObjectDetectionEval, InterpolationMode - - -__all__ = [ - "FasterRCNN", - "FasterRCNN_ResNet50_FPN_Weights", - "FasterRCNN_MobileNet_V3_Large_FPN_Weights", - "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights", - "fasterrcnn_resnet50_fpn", - "fasterrcnn_mobilenet_v3_large_fpn", - "fasterrcnn_mobilenet_v3_large_320_fpn", -] - - -_COMMON_META = { - "task": "image_object_detection", - "architecture": "FasterRCNN", - "publication_year": 2015, - "categories": _COCO_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): - COCO_V1 = Weights( - url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", - transforms=ObjectDetectionEval, - meta={ - **_COMMON_META, - "num_params": 41755286, - "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn", - "map": 37.0, - }, - ) - DEFAULT = COCO_V1 - - -class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): - COCO_V1 = Weights( - url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", - transforms=ObjectDetectionEval, - meta={ - **_COMMON_META, - "num_params": 19386354, - "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn", - "map": 32.8, - }, - ) - DEFAULT = COCO_V1 - - -class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): - COCO_V1 = Weights( - url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", - transforms=ObjectDetectionEval, - meta={ - **_COMMON_META, - "num_params": 19386354, - "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn", - "map": 22.8, - }, - ) - DEFAULT = COCO_V1 - - -@handle_legacy_interface( - weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1), - weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), -) -def fasterrcnn_resnet50_fpn( - *, - weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - weights_backbone: 
Optional[ResNet50_Weights] = None, - trainable_backbone_layers: Optional[int] = None, - **kwargs: Any, -) -> FasterRCNN: - weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights) - weights_backbone = ResNet50_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - elif num_classes is None: - num_classes = 91 - - is_trained = weights is not None or weights_backbone is not None - trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) - norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - - backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) - backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) - model = FasterRCNN(backbone, num_classes=num_classes, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1: - overwrite_eps(model, 0.0) - - return model - - -def _fasterrcnn_mobilenet_v3_large_fpn( - *, - weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]], - progress: bool, - num_classes: Optional[int], - weights_backbone: Optional[MobileNet_V3_Large_Weights], - trainable_backbone_layers: Optional[int], - **kwargs: Any, -) -> FasterRCNN: - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - elif num_classes is None: - num_classes = 91 - - is_trained = weights is not None or weights_backbone is not None - trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3) - norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - - backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer) - backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers) - anchor_sizes = ( - ( - 32, - 64, - 128, - 256, - 512, - ), - ) * 3 - aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) - model = FasterRCNN( - backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs - ) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface( - weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1), - weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), -) -def fasterrcnn_mobilenet_v3_large_fpn( - *, - weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, - trainable_backbone_layers: Optional[int] = None, - **kwargs: Any, -) -> FasterRCNN: - weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights) - weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) - - defaults = { - "rpn_score_thresh": 0.05, - } - - kwargs = {**defaults, **kwargs} - return _fasterrcnn_mobilenet_v3_large_fpn( - weights=weights, - progress=progress, - num_classes=num_classes, - weights_backbone=weights_backbone, - trainable_backbone_layers=trainable_backbone_layers, - **kwargs, - ) - - -@handle_legacy_interface( - weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1), - 
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), -) -def fasterrcnn_mobilenet_v3_large_320_fpn( - *, - weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, - trainable_backbone_layers: Optional[int] = None, - **kwargs: Any, -) -> FasterRCNN: - - weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights) - weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) - - defaults = { - "min_size": 320, - "max_size": 640, - "rpn_pre_nms_top_n_test": 150, - "rpn_post_nms_top_n_test": 150, - "rpn_score_thresh": 0.05, - } - - kwargs = {**defaults, **kwargs} - return _fasterrcnn_mobilenet_v3_large_fpn( - weights=weights, - progress=progress, - num_classes=num_classes, - weights_backbone=weights_backbone, - trainable_backbone_layers=trainable_backbone_layers, - **kwargs, - ) diff --git a/torchvision/prototype/models/detection/fcos.py b/torchvision/prototype/models/detection/fcos.py deleted file mode 100644 index 930b26e46c8..00000000000 --- a/torchvision/prototype/models/detection/fcos.py +++ /dev/null @@ -1,78 +0,0 @@ -from typing import Any, Optional - -from torch import nn -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _COCO_CATEGORIES -from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param -from torchvision.models.detection.fcos import ( - _resnet_fpn_extractor, - _validate_trainable_layers, - FCOS, - LastLevelP6P7, - misc_nn_ops, -) -from torchvision.models.resnet import ResNet50_Weights, resnet50 -from torchvision.transforms import ObjectDetectionEval, InterpolationMode - - -__all__ = [ - "FCOS", - "FCOS_ResNet50_FPN_Weights", - "fcos_resnet50_fpn", -] - - -class FCOS_ResNet50_FPN_Weights(WeightsEnum): - COCO_V1 = Weights( - url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", - transforms=ObjectDetectionEval, - meta={ - "task": "image_object_detection", - "architecture": "FCOS", - "publication_year": 2019, - "num_params": 32269600, - "categories": _COCO_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn", - "map": 39.2, - }, - ) - DEFAULT = COCO_V1 - - -@handle_legacy_interface( - weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1), - weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), -) -def fcos_resnet50_fpn( - *, - weights: Optional[FCOS_ResNet50_FPN_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - weights_backbone: Optional[ResNet50_Weights] = None, - trainable_backbone_layers: Optional[int] = None, - **kwargs: Any, -) -> FCOS: - weights = FCOS_ResNet50_FPN_Weights.verify(weights) - weights_backbone = ResNet50_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - elif num_classes is None: - num_classes = 91 - - is_trained = weights is not None or weights_backbone is not None - trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) - norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - - backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) - backbone = _resnet_fpn_extractor( - backbone, 
trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) - ) - model = FCOS(backbone, num_classes, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/detection/keypoint_rcnn.py b/torchvision/prototype/models/detection/keypoint_rcnn.py deleted file mode 100644 index a7780cc9f63..00000000000 --- a/torchvision/prototype/models/detection/keypoint_rcnn.py +++ /dev/null @@ -1,106 +0,0 @@ -from typing import Any, Optional - -from torch import nn -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES -from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param -from torchvision.models.detection.keypoint_rcnn import ( - _resnet_fpn_extractor, - _validate_trainable_layers, - KeypointRCNN, - misc_nn_ops, - overwrite_eps, -) -from torchvision.models.resnet import ResNet50_Weights, resnet50 -from torchvision.transforms import ObjectDetectionEval, InterpolationMode - - -__all__ = [ - "KeypointRCNN", - "KeypointRCNN_ResNet50_FPN_Weights", - "keypointrcnn_resnet50_fpn", -] - - -_COMMON_META = { - "task": "image_object_detection", - "architecture": "KeypointRCNN", - "publication_year": 2017, - "categories": _COCO_PERSON_CATEGORIES, - "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES, - "interpolation": InterpolationMode.BILINEAR, -} - - -class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): - COCO_LEGACY = Weights( - url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", - transforms=ObjectDetectionEval, - meta={ - **_COMMON_META, - "num_params": 59137258, - "recipe": "https://github.com/pytorch/vision/issues/1606", - "map": 50.6, - "map_kp": 61.1, - }, - ) - COCO_V1 = Weights( - url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", - transforms=ObjectDetectionEval, - meta={ - **_COMMON_META, - "num_params": 59137258, - "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn", - "map": 54.6, - "map_kp": 65.0, - }, - ) - DEFAULT = COCO_V1 - - -@handle_legacy_interface( - weights=( - "pretrained", - lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY - if kwargs["pretrained"] == "legacy" - else KeypointRCNN_ResNet50_FPN_Weights.COCO_V1, - ), - weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), -) -def keypointrcnn_resnet50_fpn( - *, - weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - num_keypoints: Optional[int] = None, - weights_backbone: Optional[ResNet50_Weights] = None, - trainable_backbone_layers: Optional[int] = None, - **kwargs: Any, -) -> KeypointRCNN: - weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights) - weights_backbone = ResNet50_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - num_keypoints = _ovewrite_value_param(num_keypoints, len(weights.meta["keypoint_names"])) - else: - if num_classes is None: - num_classes = 2 - if num_keypoints is None: - num_keypoints = 17 - - is_trained = weights is not None or weights_backbone is not None - trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) - norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else 
nn.BatchNorm2d - - backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) - backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) - model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1: - overwrite_eps(model, 0.0) - - return model diff --git a/torchvision/prototype/models/detection/mask_rcnn.py b/torchvision/prototype/models/detection/mask_rcnn.py deleted file mode 100644 index d52ebe61be1..00000000000 --- a/torchvision/prototype/models/detection/mask_rcnn.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import Any, Optional - -from torch import nn -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _COCO_CATEGORIES -from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param -from torchvision.models.detection.mask_rcnn import ( - _resnet_fpn_extractor, - _validate_trainable_layers, - MaskRCNN, - misc_nn_ops, - overwrite_eps, -) -from torchvision.models.resnet import ResNet50_Weights, resnet50 -from torchvision.transforms import ObjectDetectionEval, InterpolationMode - - -__all__ = [ - "MaskRCNN", - "MaskRCNN_ResNet50_FPN_Weights", - "maskrcnn_resnet50_fpn", -] - - -class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): - COCO_V1 = Weights( - url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", - transforms=ObjectDetectionEval, - meta={ - "task": "image_object_detection", - "architecture": "MaskRCNN", - "publication_year": 2017, - "num_params": 44401393, - "categories": _COCO_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn", - "map": 37.9, - "map_mask": 34.6, - }, - ) - DEFAULT = COCO_V1 - - -@handle_legacy_interface( - weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1), - weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), -) -def maskrcnn_resnet50_fpn( - *, - weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - weights_backbone: Optional[ResNet50_Weights] = None, - trainable_backbone_layers: Optional[int] = None, - **kwargs: Any, -) -> MaskRCNN: - weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights) - weights_backbone = ResNet50_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - elif num_classes is None: - num_classes = 91 - - is_trained = weights is not None or weights_backbone is not None - trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) - norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - - backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) - backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) - model = MaskRCNN(backbone, num_classes=num_classes, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1: - overwrite_eps(model, 0.0) - - return model diff --git a/torchvision/prototype/models/detection/retinanet.py b/torchvision/prototype/models/detection/retinanet.py deleted file mode 100644 index 
c4249118b70..00000000000 --- a/torchvision/prototype/models/detection/retinanet.py +++ /dev/null @@ -1,82 +0,0 @@ -from typing import Any, Optional - -from torch import nn -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _COCO_CATEGORIES -from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param -from torchvision.models.detection.retinanet import ( - _resnet_fpn_extractor, - _validate_trainable_layers, - RetinaNet, - LastLevelP6P7, - misc_nn_ops, - overwrite_eps, -) -from torchvision.models.resnet import ResNet50_Weights, resnet50 -from torchvision.transforms import ObjectDetectionEval, InterpolationMode - - -__all__ = [ - "RetinaNet", - "RetinaNet_ResNet50_FPN_Weights", - "retinanet_resnet50_fpn", -] - - -class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): - COCO_V1 = Weights( - url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", - transforms=ObjectDetectionEval, - meta={ - "task": "image_object_detection", - "architecture": "RetinaNet", - "publication_year": 2017, - "num_params": 34014999, - "categories": _COCO_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet", - "map": 36.4, - }, - ) - DEFAULT = COCO_V1 - - -@handle_legacy_interface( - weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1), - weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), -) -def retinanet_resnet50_fpn( - *, - weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - weights_backbone: Optional[ResNet50_Weights] = None, - trainable_backbone_layers: Optional[int] = None, - **kwargs: Any, -) -> RetinaNet: - weights = RetinaNet_ResNet50_FPN_Weights.verify(weights) - weights_backbone = ResNet50_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - elif num_classes is None: - num_classes = 91 - - is_trained = weights is not None or weights_backbone is not None - trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) - norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - - backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) - # skip P2 because it generates too many anchors (according to their paper) - backbone = _resnet_fpn_extractor( - backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) - ) - model = RetinaNet(backbone, num_classes, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1: - overwrite_eps(model, 0.0) - - return model diff --git a/torchvision/prototype/models/detection/ssd.py b/torchvision/prototype/models/detection/ssd.py deleted file mode 100644 index a3c5b965deb..00000000000 --- a/torchvision/prototype/models/detection/ssd.py +++ /dev/null @@ -1,91 +0,0 @@ -import warnings -from typing import Any, Optional - -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _COCO_CATEGORIES -from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param -from torchvision.models.detection.ssd import ( - _validate_trainable_layers, - _vgg_extractor, - DefaultBoxGenerator, - SSD, -) -from 
torchvision.models.vgg import VGG16_Weights, vgg16 -from torchvision.transforms import ObjectDetectionEval, InterpolationMode - - -__all__ = [ - "SSD300_VGG16_Weights", - "ssd300_vgg16", -] - - -class SSD300_VGG16_Weights(WeightsEnum): - COCO_V1 = Weights( - url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", - transforms=ObjectDetectionEval, - meta={ - "task": "image_object_detection", - "architecture": "SSD", - "publication_year": 2015, - "num_params": 35641826, - "size": (300, 300), - "categories": _COCO_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16", - "map": 25.1, - }, - ) - DEFAULT = COCO_V1 - - -@handle_legacy_interface( - weights=("pretrained", SSD300_VGG16_Weights.COCO_V1), - weights_backbone=("pretrained_backbone", VGG16_Weights.IMAGENET1K_FEATURES), -) -def ssd300_vgg16( - *, - weights: Optional[SSD300_VGG16_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - weights_backbone: Optional[VGG16_Weights] = None, - trainable_backbone_layers: Optional[int] = None, - **kwargs: Any, -) -> SSD: - weights = SSD300_VGG16_Weights.verify(weights) - weights_backbone = VGG16_Weights.verify(weights_backbone) - - if "size" in kwargs: - warnings.warn("The size of the model is already fixed; ignoring the parameter.") - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - elif num_classes is None: - num_classes = 91 - - trainable_backbone_layers = _validate_trainable_layers( - weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 4 - ) - - # Use custom backbones more appropriate for SSD - backbone = vgg16(weights=weights_backbone, progress=progress) - backbone = _vgg_extractor(backbone, False, trainable_backbone_layers) - anchor_generator = DefaultBoxGenerator( - [[2], [2, 3], [2, 3], [2, 3], [2], [2]], - scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05], - steps=[8, 16, 32, 64, 100, 300], - ) - - defaults = { - # Rescale the input in a way compatible to the backbone - "image_mean": [0.48235, 0.45882, 0.40784], - "image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0], # undo the 0-1 scaling of toTensor - } - kwargs: Any = {**defaults, **kwargs} - model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/detection/ssdlite.py b/torchvision/prototype/models/detection/ssdlite.py deleted file mode 100644 index d9f2ee58bc6..00000000000 --- a/torchvision/prototype/models/detection/ssdlite.py +++ /dev/null @@ -1,124 +0,0 @@ -import warnings -from functools import partial -from typing import Any, Callable, Optional - -from torch import nn -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _COCO_CATEGORIES -from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param -from torchvision.models.detection.ssdlite import ( - _mobilenet_extractor, - _normal_init, - _validate_trainable_layers, - DefaultBoxGenerator, - det_utils, - SSD, - SSDLiteHead, -) -from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large -from torchvision.transforms import ObjectDetectionEval, InterpolationMode - - -__all__ = [ - "SSDLite320_MobileNet_V3_Large_Weights", - "ssdlite320_mobilenet_v3_large", -] - - 
-class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): - COCO_V1 = Weights( - url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth", - transforms=ObjectDetectionEval, - meta={ - "task": "image_object_detection", - "architecture": "SSDLite", - "publication_year": 2018, - "num_params": 3440060, - "size": (320, 320), - "categories": _COCO_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large", - "map": 21.3, - }, - ) - DEFAULT = COCO_V1 - - -@handle_legacy_interface( - weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.COCO_V1), - weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), -) -def ssdlite320_mobilenet_v3_large( - *, - weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, - trainable_backbone_layers: Optional[int] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None, - **kwargs: Any, -) -> SSD: - weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights) - weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) - - if "size" in kwargs: - warnings.warn("The size of the model is already fixed; ignoring the parameter.") - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - elif num_classes is None: - num_classes = 91 - - trainable_backbone_layers = _validate_trainable_layers( - weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 6 - ) - - # Enable reduced tail if no pretrained backbone is selected. See Table 6 of MobileNetV3 paper. 
- reduce_tail = weights_backbone is None - - if norm_layer is None: - norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03) - - backbone = mobilenet_v3_large( - weights=weights_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs - ) - if weights_backbone is None: - # Change the default initialization scheme if not pretrained - _normal_init(backbone) - backbone = _mobilenet_extractor( - backbone, - trainable_backbone_layers, - norm_layer, - ) - - size = (320, 320) - anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95) - out_channels = det_utils.retrieve_out_channels(backbone, size) - num_anchors = anchor_generator.num_anchors_per_location() - assert len(out_channels) == len(anchor_generator.aspect_ratios) - - defaults = { - "score_thresh": 0.001, - "nms_thresh": 0.55, - "detections_per_img": 300, - "topk_candidates": 300, - # Rescale the input in a way compatible to the backbone: - # The following mean/std rescale the data from [0, 1] to [-1, -1] - "image_mean": [0.5, 0.5, 0.5], - "image_std": [0.5, 0.5, 0.5], - } - kwargs: Any = {**defaults, **kwargs} - model = SSD( - backbone, - anchor_generator, - size, - num_classes, - head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer), - **kwargs, - ) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model From 5a96c9ac79bd55e41a5e4d13e12def142dd16088 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 15 Mar 2022 14:40:26 +0000 Subject: [PATCH 30/45] Porting Optical Flow, Segmentation, Video models (#5619) * Porting raft * Porting video resnet * Porting deeplabv3 * Porting fcn and lraspp * Fixing the tests and linter --- test/test_prototype_models.py | 94 +------ torchvision/models/optical_flow/__init__.py | 2 +- torchvision/models/optical_flow/raft.py | 191 +++++++++++--- torchvision/models/segmentation/__init__.py | 2 +- torchvision/models/segmentation/_utils.py | 8 - torchvision/models/segmentation/deeplabv3.py | 199 ++++++++++---- torchvision/models/segmentation/fcn.py | 139 +++++++--- torchvision/models/segmentation/lraspp.py | 78 ++++-- torchvision/models/video/resnet.py | 157 ++++++++--- torchvision/prototype/__init__.py | 1 - torchvision/prototype/models/__init__.py | 3 - .../prototype/models/optical_flow/__init__.py | 1 - .../prototype/models/optical_flow/raft.py | 249 ------------------ .../prototype/models/segmentation/__init__.py | 3 - .../models/segmentation/deeplabv3.py | 171 ------------ .../prototype/models/segmentation/fcn.py | 115 -------- .../prototype/models/segmentation/lraspp.py | 64 ----- .../prototype/models/video/__init__.py | 1 - torchvision/prototype/models/video/resnet.py | 150 ----------- 19 files changed, 579 insertions(+), 1049 deletions(-) delete mode 100644 torchvision/prototype/models/__init__.py delete mode 100644 torchvision/prototype/models/optical_flow/__init__.py delete mode 100644 torchvision/prototype/models/optical_flow/raft.py delete mode 100644 torchvision/prototype/models/segmentation/__init__.py delete mode 100644 torchvision/prototype/models/segmentation/deeplabv3.py delete mode 100644 torchvision/prototype/models/segmentation/fcn.py delete mode 100644 torchvision/prototype/models/segmentation/lraspp.py delete mode 100644 torchvision/prototype/models/video/__init__.py delete mode 100644 torchvision/prototype/models/video/resnet.py diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 06baea35fa8..65b8ffb9e40 
100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -3,12 +3,9 @@ import pytest import test_models as TM -import torch import torchvision -from common_utils import cpu_and_gpu, needs_cuda from torchvision.models._api import WeightsEnum, Weights from torchvision.models._utils import handle_legacy_interface -from torchvision.prototype import models run_if_test_with_prototype = pytest.mark.skipif( os.getenv("PYTORCH_TEST_WITH_PROTOTYPE") != "1", @@ -76,9 +73,9 @@ def test_get_weight(name, weight): TM.get_models_from_module(torchvision.models) + TM.get_models_from_module(torchvision.models.detection) + TM.get_models_from_module(torchvision.models.quantization) - + TM.get_models_from_module(models.segmentation) - + TM.get_models_from_module(models.video) - + TM.get_models_from_module(models.optical_flow), + + TM.get_models_from_module(torchvision.models.segmentation) + + TM.get_models_from_module(torchvision.models.video) + + TM.get_models_from_module(torchvision.models.optical_flow), ) def test_naming_conventions(model_fn): weights_enum = _get_model_weights(model_fn) @@ -92,9 +89,9 @@ def test_naming_conventions(model_fn): TM.get_models_from_module(torchvision.models) + TM.get_models_from_module(torchvision.models.detection) + TM.get_models_from_module(torchvision.models.quantization) - + TM.get_models_from_module(models.segmentation) - + TM.get_models_from_module(models.video) - + TM.get_models_from_module(models.optical_flow), + + TM.get_models_from_module(torchvision.models.segmentation) + + TM.get_models_from_module(torchvision.models.video) + + TM.get_models_from_module(torchvision.models.optical_flow), ) @run_if_test_with_prototype def test_schema_meta_validation(model_fn): @@ -143,85 +140,6 @@ def test_schema_meta_validation(model_fn): assert not bad_names -@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation)) -@pytest.mark.parametrize("dev", cpu_and_gpu()) -@run_if_test_with_prototype -def test_segmentation_model(model_fn, dev): - TM.test_segmentation_model(model_fn, dev) - - -@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.video)) -@pytest.mark.parametrize("dev", cpu_and_gpu()) -@run_if_test_with_prototype -def test_video_model(model_fn, dev): - TM.test_video_model(model_fn, dev) - - -@needs_cuda -@pytest.mark.parametrize("model_builder", TM.get_models_from_module(models.optical_flow)) -@pytest.mark.parametrize("scripted", (False, True)) -@run_if_test_with_prototype -def test_raft(model_builder, scripted): - TM.test_raft(model_builder, scripted) - - -@pytest.mark.parametrize( - "model_fn", - TM.get_models_from_module(models.segmentation) - + TM.get_models_from_module(models.video) - + TM.get_models_from_module(models.optical_flow), -) -@pytest.mark.parametrize("dev", cpu_and_gpu()) -@run_if_test_with_prototype -def test_old_vs_new_factory(model_fn, dev): - defaults = { - "models": { - "input_shape": (1, 3, 224, 224), - }, - "detection": { - "input_shape": (3, 300, 300), - }, - "quantization": { - "input_shape": (1, 3, 224, 224), - "quantize": True, - }, - "segmentation": { - "input_shape": (1, 3, 520, 520), - }, - "video": { - "input_shape": (1, 3, 4, 112, 112), - }, - "optical_flow": { - "input_shape": (1, 3, 128, 128), - }, - } - model_name = model_fn.__name__ - module_name = model_fn.__module__.split(".")[-2] - kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})} - input_shape = kwargs.pop("input_shape") - kwargs.pop("num_classes", None) # ignore this as it's 
an incompatible speed optimization for pre-trained models - x = torch.rand(input_shape).to(device=dev) - if module_name == "detection": - x = [x] - - if module_name == "optical_flow": - args = [x, x] # RAFT model requires img1, img2 as input - else: - args = [x] - - # compare with new model builder parameterized in the old fashion way - try: - model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev) - model_new = _build_model(model_fn, **kwargs).to(device=dev) - except ModuleNotFoundError: - pytest.skip(f"Model '{model_name}' not available in both modules.") - torch.testing.assert_close(model_new(*args), model_old(*args), rtol=0.0, atol=0.0, check_dtype=False) - - -def test_smoke(): - import torchvision.prototype.models # noqa: F401 - - # With this filter, every unexpected warning will be turned into an error @pytest.mark.filterwarnings("error") class TestHandleLegacyInterface: diff --git a/torchvision/models/optical_flow/__init__.py b/torchvision/models/optical_flow/__init__.py index 9dd32f25dec..89d2302f825 100644 --- a/torchvision/models/optical_flow/__init__.py +++ b/torchvision/models/optical_flow/__init__.py @@ -1 +1 @@ -from .raft import RAFT, raft_large, raft_small +from .raft import * diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index 4dfd232d499..95469f78d3c 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional import torch import torch.nn as nn @@ -8,8 +8,10 @@ from torch.nn.modules.instancenorm import InstanceNorm2d from torchvision.ops import Conv2dNormActivation -from ..._internally_replaced_utils import load_state_dict_from_url +from ...transforms import OpticalFlowEval, InterpolationMode from ...utils import _log_api_usage_once +from .._api import Weights, WeightsEnum +from .._utils import handle_legacy_interface from ._utils import grid_sample, make_coords_grid, upsample_flow @@ -17,15 +19,11 @@ "RAFT", "raft_large", "raft_small", + "Raft_Large_Weights", + "Raft_Small_Weights", ) -_MODELS_URLS = { - "raft_large": "https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", - "raft_small": "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", -} - - class ResidualBlock(nn.Module): """Slightly modified Residual block with extra relu and biases.""" @@ -500,10 +498,139 @@ def forward(self, image1, image2, num_flow_updates: int = 12): return flow_predictions +_COMMON_META = { + "task": "optical_flow", + "architecture": "RAFT", + "publication_year": 2020, + "interpolation": InterpolationMode.BILINEAR, +} + + +class Raft_Large_Weights(WeightsEnum): + C_T_V1 = Weights( + # Chairs + Things, ported from original paper repo (raft-things.pth) + url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth", + transforms=OpticalFlowEval, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/princeton-vl/RAFT", + "sintel_train_cleanpass_epe": 1.4411, + "sintel_train_finalpass_epe": 2.7894, + "kitti_train_per_image_epe": 5.0172, + "kitti_train_f1-all": 17.4506, + }, + ) + + C_T_V2 = Weights( + # Chairs + Things + url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", + transforms=OpticalFlowEval, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "sintel_train_cleanpass_epe": 1.3822, + "sintel_train_finalpass_epe": 2.7161, + 
"kitti_train_per_image_epe": 4.5118, + "kitti_train_f1-all": 16.0679, + }, + ) + + C_T_SKHT_V1 = Weights( + # Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth) + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth", + transforms=OpticalFlowEval, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/princeton-vl/RAFT", + "sintel_test_cleanpass_epe": 1.94, + "sintel_test_finalpass_epe": 3.18, + }, + ) + + C_T_SKHT_V2 = Weights( + # Chairs + Things + Sintel fine-tuning, i.e.: + # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", + transforms=OpticalFlowEval, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "sintel_test_cleanpass_epe": 1.819, + "sintel_test_finalpass_epe": 3.067, + }, + ) + + C_T_SKHT_K_V1 = Weights( + # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth) + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth", + transforms=OpticalFlowEval, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/princeton-vl/RAFT", + "kitti_test_f1-all": 5.10, + }, + ) + + C_T_SKHT_K_V2 = Weights( + # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.: + # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti + # Same as CT_SKHT with extra fine-tuning on Kitti + # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti + url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth", + transforms=OpticalFlowEval, + meta={ + **_COMMON_META, + "num_params": 5257536, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "kitti_test_f1-all": 5.19, + }, + ) + + DEFAULT = C_T_SKHT_V2 + + +class Raft_Small_Weights(WeightsEnum): + C_T_V1 = Weights( + # Chairs + Things, ported from original paper repo (raft-small.pth) + url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", + transforms=OpticalFlowEval, + meta={ + **_COMMON_META, + "num_params": 990162, + "recipe": "https://github.com/princeton-vl/RAFT", + "sintel_train_cleanpass_epe": 2.1231, + "sintel_train_finalpass_epe": 3.2790, + "kitti_train_per_image_epe": 7.6557, + "kitti_train_f1-all": 25.2801, + }, + ) + C_T_V2 = Weights( + # Chairs + Things + url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", + transforms=OpticalFlowEval, + meta={ + **_COMMON_META, + "num_params": 990162, + "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", + "sintel_train_cleanpass_epe": 1.9901, + "sintel_train_finalpass_epe": 3.2831, + "kitti_train_per_image_epe": 7.5978, + "kitti_train_f1-all": 25.2369, + }, + ) + + DEFAULT = C_T_V2 + + def _raft( *, - arch=None, - pretrained=False, + weights=None, progress=False, # Feature encoder feature_encoder_layers, @@ -577,38 +704,34 @@ def _raft( mask_predictor=mask_predictor, **kwargs, # not really needed, all params should be consumed by now ) - if pretrained: - state_dict = load_state_dict_from_url(_MODELS_URLS[arch], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) return model -def raft_large(*, pretrained=False, 
progress=True, **kwargs): +@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2)) +def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs) -> RAFT: """RAFT model from `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. Please see the example below for a tutorial on how to use this model. Args: - pretrained (bool): Whether to use weights that have been pre-trained on - :class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D` - with two fine-tuning steps: - - - one on :class:`~torchvsion.datasets.Sintel` + :class:`~torchvsion.datasets.FlyingThings3D` - - one on :class:`~torchvsion.datasets.KittiFlow`. - - This corresponds to the ``C+T+S/K`` strategy in the paper. - - progress (bool): If True, displays a progress bar of the download to stderr. + weights (Raft_Large_Weights, optional): The pretrained weights for the model + progress (bool): If True, displays a progress bar of the download to stderr + kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class + to override any default. Returns: - nn.Module: The model. + RAFT: The model. """ + weights = Raft_Large_Weights.verify(weights) + return _raft( - arch="raft_large", - pretrained=pretrained, + weights=weights, progress=progress, # Feature encoder feature_encoder_layers=(64, 64, 96, 128, 256), @@ -637,25 +760,27 @@ def raft_large(*, pretrained=False, progress=True, **kwargs): ) -def raft_small(*, pretrained=False, progress=True, **kwargs): +@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2)) +def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs) -> RAFT: """RAFT "small" model from `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. Please see the example below for a tutorial on how to use this model. Args: - pretrained (bool): Whether to use weights that have been pre-trained on - :class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D`. + weights (Raft_Small_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr + kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class + to override any default. Returns: - nn.Module: The model. + RAFT: The model.
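+
+    Example (an illustrative sketch only; ``C_T_V2`` is the checkpoint defined above and the 128x128 inputs are
+    an arbitrary size that is a multiple of 8, matching the 8x downsampling/upsampling used by RAFT):
+
+        >>> import torch
+        >>> from torchvision.models.optical_flow import raft_small, Raft_Small_Weights
+        >>> model = raft_small(weights=Raft_Small_Weights.C_T_V2).eval()
+        >>> img1 = torch.rand(1, 3, 128, 128)
+        >>> img2 = torch.rand(1, 3, 128, 128)
+        >>> flow_predictions = model(img1, img2)
+        >>> flow = flow_predictions[-1]  # the last element is the most refined flow estimate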
""" + weights = Raft_Small_Weights.verify(weights) return _raft( - arch="raft_small", - pretrained=pretrained, + weights=weights, progress=progress, # Feature encoder feature_encoder_layers=(32, 32, 64, 96, 128), diff --git a/torchvision/models/segmentation/__init__.py b/torchvision/models/segmentation/__init__.py index 1765502d693..3d6f37f958a 100644 --- a/torchvision/models/segmentation/__init__.py +++ b/torchvision/models/segmentation/__init__.py @@ -1,3 +1,3 @@ -from .fcn import * from .deeplabv3 import * +from .fcn import * from .lraspp import * diff --git a/torchvision/models/segmentation/_utils.py b/torchvision/models/segmentation/_utils.py index 0bbea5d3e81..44a60a95c54 100644 --- a/torchvision/models/segmentation/_utils.py +++ b/torchvision/models/segmentation/_utils.py @@ -4,7 +4,6 @@ from torch import nn, Tensor from torch.nn import functional as F -from ..._internally_replaced_utils import load_state_dict_from_url from ...utils import _log_api_usage_once @@ -36,10 +35,3 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: result["aux"] = x return result - - -def _load_weights(arch: str, model: nn.Module, model_url: Optional[str], progress: bool) -> None: - if model_url is None: - raise ValueError(f"No checkpoint is available for {arch}") - state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict) diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index 15ab5fffa5e..6e8bf0c398b 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -1,31 +1,31 @@ -from typing import List, Optional +from functools import partial +from typing import Any, List, Optional import torch from torch import nn from torch.nn import functional as F -from .. import mobilenetv3 -from .. 
import resnet -from .._utils import IntermediateLayerGetter -from ._utils import _SimpleSegmentationModel, _load_weights +from ...transforms import SemanticSegmentationEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _VOC_CATEGORIES +from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param +from ..mobilenetv3 import MobileNetV3, MobileNet_V3_Large_Weights, mobilenet_v3_large +from ..resnet import ResNet, resnet50, resnet101, ResNet50_Weights, ResNet101_Weights +from ._utils import _SimpleSegmentationModel from .fcn import FCNHead __all__ = [ "DeepLabV3", + "DeepLabV3_ResNet50_Weights", + "DeepLabV3_ResNet101_Weights", + "DeepLabV3_MobileNet_V3_Large_Weights", + "deeplabv3_mobilenet_v3_large", "deeplabv3_resnet50", "deeplabv3_resnet101", - "deeplabv3_mobilenet_v3_large", ] -model_urls = { - "deeplabv3_resnet50_coco": "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", - "deeplabv3_resnet101_coco": "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", - "deeplabv3_mobilenet_v3_large_coco": "https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", -} - - class DeepLabV3(_SimpleSegmentationModel): """ Implements DeepLabV3 model from @@ -114,7 +114,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def _deeplabv3_resnet( - backbone: resnet.ResNet, + backbone: ResNet, num_classes: int, aux: Optional[bool], ) -> DeepLabV3: @@ -128,8 +128,62 @@ def _deeplabv3_resnet( return DeepLabV3(backbone, classifier, aux_classifier) +_COMMON_META = { + "task": "image_semantic_segmentation", + "architecture": "DeepLabV3", + "publication_year": 2017, + "categories": _VOC_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class DeepLabV3_ResNet50_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", + transforms=partial(SemanticSegmentationEval, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 42004074, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50", + "mIoU": 66.4, + "acc": 92.4, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +class DeepLabV3_ResNet101_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", + transforms=partial(SemanticSegmentationEval, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 60996202, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101", + "mIoU": 67.4, + "acc": 92.4, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", + transforms=partial(SemanticSegmentationEval, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 11029328, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large", + "mIoU": 60.3, + "acc": 91.2, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + def _deeplabv3_mobilenetv3( - backbone: mobilenetv3.MobileNetV3, + backbone: MobileNetV3, num_classes: int, aux: Optional[bool], ) -> DeepLabV3: @@ -151,91 +205,124 @@ def _deeplabv3_mobilenetv3( return DeepLabV3(backbone, classifier, aux_classifier) +@handle_legacy_interface( + weights=("pretrained", 
DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def deeplabv3_resnet50( - pretrained: bool = False, + *, + weights: Optional[DeepLabV3_ResNet50_Weights] = None, progress: bool = True, - num_classes: int = 21, + num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - pretrained_backbone: bool = True, + weights_backbone: Optional[ResNet50_Weights] = None, + **kwargs: Any, ) -> DeepLabV3: """Constructs a DeepLabV3 model with a ResNet-50 backbone. Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC + weights (DeepLabV3_ResNet50_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) + num_classes (int, optional): number of output classes of the model (including the background) aux_loss (bool, optional): If True, it uses an auxiliary loss - pretrained_backbone (bool): If True, the backbone will be pre-trained. + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone """ - if pretrained: - aux_loss = True - pretrained_backbone = False + weights = DeepLabV3_ResNet50_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) - backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) model = _deeplabv3_resnet(backbone, num_classes, aux_loss) - if pretrained: - arch = "deeplabv3_resnet50_coco" - _load_weights(arch, model, model_urls.get(arch, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model +@handle_legacy_interface( + weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1), +) def deeplabv3_resnet101( - pretrained: bool = False, + *, + weights: Optional[DeepLabV3_ResNet101_Weights] = None, progress: bool = True, - num_classes: int = 21, + num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - pretrained_backbone: bool = True, + weights_backbone: Optional[ResNet101_Weights] = None, + **kwargs: Any, ) -> DeepLabV3: """Constructs a DeepLabV3 model with a ResNet-101 backbone. Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC + weights (DeepLabV3_ResNet101_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr num_classes (int): The number of classes aux_loss (bool, optional): If True, include an auxiliary classifier - pretrained_backbone (bool): If True, the backbone will be pre-trained. 
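# Illustrative usage sketch (editor's addition, not part of the patch) for the DeepLabV3 builders
# above. It assumes the new weight enums are re-exported from torchvision.models.segmentation via
# the package __init__, as the galleries in this series do for FCN.
from torchvision.models.resnet import ResNet50_Weights
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights

# Fully pretrained: num_classes (21) and aux_loss (True) are filled in from the weights meta-data.
model = deeplabv3_resnet50(weights=DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1)

# Custom head on an ImageNet backbone: weights=None, so num_classes must be supplied explicitly.
model = deeplabv3_resnet50(weights=None, weights_backbone=ResNet50_Weights.IMAGENET1K_V1, num_classes=4)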
+ weights_backbone (ResNet101_Weights, optional): The pretrained weights for the backbone """ - if pretrained: - aux_loss = True - pretrained_backbone = False + weights = DeepLabV3_ResNet101_Weights.verify(weights) + weights_backbone = ResNet101_Weights.verify(weights_backbone) - backbone = resnet.resnet101(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) model = _deeplabv3_resnet(backbone, num_classes, aux_loss) - if pretrained: - arch = "deeplabv3_resnet101_coco" - _load_weights(arch, model, model_urls.get(arch, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model +@handle_legacy_interface( + weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) def deeplabv3_mobilenet_v3_large( - pretrained: bool = False, + *, + weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None, progress: bool = True, - num_classes: int = 21, + num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - pretrained_backbone: bool = True, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, + **kwargs: Any, ) -> DeepLabV3: """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone. Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC + weights (DeepLabV3_MobileNet_V3_Large_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) + num_classes (int, optional): number of output classes of the model (including the background) aux_loss (bool, optional): If True, it uses an auxiliary loss - pretrained_backbone (bool): If True, the backbone will be pre-trained. 
+ weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone """ - if pretrained: - aux_loss = True - pretrained_backbone = False + weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) - backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True) + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) + elif num_classes is None: + num_classes = 21 + + backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True) model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss) - if pretrained: - arch = "deeplabv3_mobilenet_v3_large_coco" - _load_weights(arch, model, model_urls.get(arch, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index 307781ebf00..5a3ca1f654f 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -1,19 +1,17 @@ -from typing import Optional +from functools import partial +from typing import Any, Optional from torch import nn -from .. import resnet -from .._utils import IntermediateLayerGetter -from ._utils import _SimpleSegmentationModel, _load_weights +from ...transforms import SemanticSegmentationEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _VOC_CATEGORIES +from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param +from ..resnet import ResNet, ResNet50_Weights, ResNet101_Weights, resnet50, resnet101 +from ._utils import _SimpleSegmentationModel -__all__ = ["FCN", "fcn_resnet50", "fcn_resnet101"] - - -model_urls = { - "fcn_resnet50_coco": "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", - "fcn_resnet101_coco": "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", -} +__all__ = ["FCN", "FCN_ResNet50_Weights", "FCN_ResNet101_Weights", "fcn_resnet50", "fcn_resnet101"] class FCN(_SimpleSegmentationModel): @@ -49,8 +47,47 @@ def __init__(self, in_channels: int, channels: int) -> None: super().__init__(*layers) +_COMMON_META = { + "task": "image_semantic_segmentation", + "architecture": "FCN", + "publication_year": 2014, + "categories": _VOC_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} + + +class FCN_ResNet50_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", + transforms=partial(SemanticSegmentationEval, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 35322218, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet50", + "mIoU": 60.5, + "acc": 91.4, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +class FCN_ResNet101_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", + transforms=partial(SemanticSegmentationEval, resize_size=520), + meta={ + **_COMMON_META, + "num_params": 54314346, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet101", + "mIoU": 63.7, + "acc": 91.9, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + def _fcn_resnet( - backbone: resnet.ResNet, + backbone: ResNet, num_classes: 
int, aux: Optional[bool], ) -> FCN: @@ -64,61 +101,83 @@ def _fcn_resnet( return FCN(backbone, classifier, aux_classifier) +@handle_legacy_interface( + weights=("pretrained", FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def fcn_resnet50( - pretrained: bool = False, + *, + weights: Optional[FCN_ResNet50_Weights] = None, progress: bool = True, - num_classes: int = 21, + num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - pretrained_backbone: bool = True, + weights_backbone: Optional[ResNet50_Weights] = None, + **kwargs: Any, ) -> FCN: """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone. Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC + weights (FCN_ResNet50_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) + num_classes (int, optional): number of output classes of the model (including the background) aux_loss (bool, optional): If True, it uses an auxiliary loss - pretrained_backbone (bool): If True, the backbone will be pre-trained. + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone """ - if pretrained: - aux_loss = True - pretrained_backbone = False + weights = FCN_ResNet50_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) + elif num_classes is None: + num_classes = 21 - backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) + backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) model = _fcn_resnet(backbone, num_classes, aux_loss) - if pretrained: - arch = "fcn_resnet50_coco" - _load_weights(arch, model, model_urls.get(arch, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model +@handle_legacy_interface( + weights=("pretrained", FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1), +) def fcn_resnet101( - pretrained: bool = False, + *, + weights: Optional[FCN_ResNet101_Weights] = None, progress: bool = True, - num_classes: int = 21, + num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - pretrained_backbone: bool = True, + weights_backbone: Optional[ResNet101_Weights] = None, + **kwargs: Any, ) -> FCN: """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone. Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC + weights (FCN_ResNet101_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) + num_classes (int, optional): number of output classes of the model (including the background) aux_loss (bool, optional): If True, it uses an auxiliary loss - pretrained_backbone (bool): If True, the backbone will be pre-trained. 
+ weights_backbone (ResNet101_Weights, optional): The pretrained weights for the backbone """ - if pretrained: - aux_loss = True - pretrained_backbone = False + weights = FCN_ResNet101_Weights.verify(weights) + weights_backbone = ResNet101_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param(aux_loss, True) + elif num_classes is None: + num_classes = 21 - backbone = resnet.resnet101(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]) + backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) model = _fcn_resnet(backbone, num_classes, aux_loss) - if pretrained: - arch = "fcn_resnet101_coco" - _load_weights(arch, model, model_urls.get(arch, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index ca73140661b..d1fe15a350d 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -1,21 +1,19 @@ from collections import OrderedDict -from typing import Any, Dict +from functools import partial +from typing import Any, Dict, Optional from torch import nn, Tensor from torch.nn import functional as F +from ...transforms import SemanticSegmentationEval, InterpolationMode from ...utils import _log_api_usage_once -from .. import mobilenetv3 -from .._utils import IntermediateLayerGetter -from ._utils import _load_weights +from .._api import WeightsEnum, Weights +from .._meta import _VOC_CATEGORIES +from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param +from ..mobilenetv3 import MobileNetV3, MobileNet_V3_Large_Weights, mobilenet_v3_large -__all__ = ["LRASPP", "lraspp_mobilenet_v3_large"] - - -model_urls = { - "lraspp_mobilenet_v3_large_coco": "https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", -} +__all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"] class LRASPP(nn.Module): @@ -30,7 +28,7 @@ class LRASPP(nn.Module): "high" for the high level feature map and "low" for the low level feature map. low_channels (int): the number of channels of the low level features. high_channels (int): the number of channels of the high level features. - num_classes (int): number of output classes of the model (including the background). + num_classes (int, optional): number of output classes of the model (including the background). inter_channels (int, optional): the number of channels for intermediate computations. """ @@ -81,7 +79,7 @@ def forward(self, input: Dict[str, Tensor]) -> Tensor: return self.low_classifier(low) + self.high_classifier(x) -def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) -> LRASPP: +def _lraspp_mobilenetv3(backbone: MobileNetV3, num_classes: int) -> LRASPP: backbone = backbone.features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. # The first and last blocks are always included because they are the C0 (conv1) and Cn. 
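# Illustrative usage sketch (editor's addition, not part of the patch) for the segmentation
# builders introduced above, shown here for FCN. The import follows the gallery changes later
# in this series; resize_size=520 comes from the weight definitions above.
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights

weights = FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1
model = fcn_resnet50(weights=weights).eval()   # 21 VOC classes and the aux head come from the meta-data
preprocess = weights.transforms()              # semantic segmentation eval preset with resize_size=520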
@@ -95,31 +93,61 @@ def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) -> return LRASPP(backbone, low_channels, high_channels, num_classes) +class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): + COCO_WITH_VOC_LABELS_V1 = Weights( + url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", + transforms=partial(SemanticSegmentationEval, resize_size=520), + meta={ + "task": "image_semantic_segmentation", + "architecture": "LRASPP", + "publication_year": 2019, + "num_params": 3221538, + "categories": _VOC_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large", + "mIoU": 57.9, + "acc": 91.2, + }, + ) + DEFAULT = COCO_WITH_VOC_LABELS_V1 + + +@handle_legacy_interface( + weights=("pretrained", LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) def lraspp_mobilenet_v3_large( - pretrained: bool = False, + *, + weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None, progress: bool = True, - num_classes: int = 21, - pretrained_backbone: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, **kwargs: Any, ) -> LRASPP: """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone. Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 which - contains the same classes as Pascal VOC + weights (LRASPP_MobileNet_V3_Large_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, the backbone will be pre-trained. 
+ num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone """ if kwargs.pop("aux_loss", False): raise NotImplementedError("This model does not use auxiliary loss") - if pretrained: - pretrained_backbone = False - backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True) + weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 21 + + backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True) model = _lraspp_mobilenetv3(backbone, num_classes) - if pretrained: - arch = "lraspp_mobilenet_v3_large_coco" - _load_weights(arch, model, model_urls.get(arch, None), progress) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index 4ac781a7c4c..a6b779d10f1 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -1,18 +1,25 @@ +from functools import partial from typing import Tuple, Optional, Callable, List, Sequence, Type, Any, Union import torch.nn as nn from torch import Tensor -from ..._internally_replaced_utils import load_state_dict_from_url +from ...transforms import VideoClassificationEval, InterpolationMode from ...utils import _log_api_usage_once +from .._api import WeightsEnum, Weights +from .._meta import _KINETICS400_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["r3d_18", "mc3_18", "r2plus1d_18"] -model_urls = { - "r3d_18": "https://download.pytorch.org/models/r3d_18-b3b3357e.pth", - "mc3_18": "https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", - "r2plus1d_18": "https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", -} +__all__ = [ + "VideoResNet", + "R3D_18_Weights", + "MC3_18_Weights", + "R2Plus1D_18_Weights", + "r3d_18", + "mc3_18", + "r2plus1d_18", +] class Conv3DSimple(nn.Conv3d): @@ -281,80 +288,152 @@ def _make_layer( return nn.Sequential(*layers) -def _video_resnet(arch: str, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: - model = VideoResNet(**kwargs) +def _video_resnet( + block: Type[Union[BasicBlock, Bottleneck]], + conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]], + layers: List[int], + stem: Callable[..., nn.Module], + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> VideoResNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = VideoResNet(block, conv_makers, layers, stem, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) - if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) return model -def r3d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: +_COMMON_META = { + "task": "video_classification", + "publication_year": 2017, + "size": (112, 112), + "min_size": (1, 1), + "categories": _KINETICS400_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": 
"https://github.com/pytorch/vision/tree/main/references/video_classification", +} + + +class R3D_18_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", + transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), + meta={ + **_COMMON_META, + "architecture": "R3D", + "num_params": 33371472, + "acc@1": 52.75, + "acc@5": 75.45, + }, + ) + DEFAULT = KINETICS400_V1 + + +class MC3_18_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", + transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), + meta={ + **_COMMON_META, + "architecture": "MC3", + "num_params": 11695440, + "acc@1": 53.90, + "acc@5": 76.29, + }, + ) + DEFAULT = KINETICS400_V1 + + +class R2Plus1D_18_Weights(WeightsEnum): + KINETICS400_V1 = Weights( + url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", + transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), + meta={ + **_COMMON_META, + "architecture": "R(2+1)D", + "num_params": 31505325, + "acc@1": 57.50, + "acc@5": 78.81, + }, + ) + DEFAULT = KINETICS400_V1 + + +@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1)) +def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: """Construct 18 layer Resnet3D model as in https://arxiv.org/abs/1711.11248 Args: - pretrained (bool): If True, returns a model pre-trained on Kinetics-400 + weights (R3D_18_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr Returns: - nn.Module: R3D-18 network + VideoResNet: R3D-18 network """ + weights = R3D_18_Weights.verify(weights) return _video_resnet( - "r3d_18", - pretrained, + BasicBlock, + [Conv3DSimple] * 4, + [2, 2, 2, 2], + BasicStem, + weights, progress, - block=BasicBlock, - conv_makers=[Conv3DSimple] * 4, - layers=[2, 2, 2, 2], - stem=BasicStem, **kwargs, ) -def mc3_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: +@handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1)) +def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: """Constructor for 18 layer Mixed Convolution network as in https://arxiv.org/abs/1711.11248 Args: - pretrained (bool): If True, returns a model pre-trained on Kinetics-400 + weights (MC3_18_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr Returns: - nn.Module: MC3 Network definition + VideoResNet: MC3 Network definition """ + weights = MC3_18_Weights.verify(weights) + return _video_resnet( - "mc3_18", - pretrained, + BasicBlock, + [Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item] + [2, 2, 2, 2], + BasicStem, + weights, progress, - block=BasicBlock, - conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item] - layers=[2, 2, 2, 2], - stem=BasicStem, **kwargs, ) -def r2plus1d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: +@handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1)) +def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: """Constructor for the 18 layer deep R(2+1)D network as in https://arxiv.org/abs/1711.11248 
Args: - pretrained (bool): If True, returns a model pre-trained on Kinetics-400 + weights (R2Plus1D_18_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr Returns: - nn.Module: R(2+1)D-18 network + VideoResNet: R(2+1)D-18 network """ + weights = R2Plus1D_18_Weights.verify(weights) + return _video_resnet( - "r2plus1d_18", - pretrained, + BasicBlock, + [Conv2Plus1D] * 4, + [2, 2, 2, 2], + R2Plus1dStem, + weights, progress, - block=BasicBlock, - conv_makers=[Conv2Plus1D] * 4, - layers=[2, 2, 2, 2], - stem=R2Plus1dStem, **kwargs, ) diff --git a/torchvision/prototype/__init__.py b/torchvision/prototype/__init__.py index e1be6c81f59..bd35d31dcfd 100644 --- a/torchvision/prototype/__init__.py +++ b/torchvision/prototype/__init__.py @@ -1,5 +1,4 @@ from . import datasets from . import features -from . import models from . import transforms from . import utils diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py deleted file mode 100644 index 3d7baca6284..00000000000 --- a/torchvision/prototype/models/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from . import optical_flow -from . import segmentation -from . import video diff --git a/torchvision/prototype/models/optical_flow/__init__.py b/torchvision/prototype/models/optical_flow/__init__.py deleted file mode 100644 index 9b78f70b768..00000000000 --- a/torchvision/prototype/models/optical_flow/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .raft import RAFT, raft_large, raft_small, Raft_Large_Weights, Raft_Small_Weights diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py deleted file mode 100644 index 33e3243c2a0..00000000000 --- a/torchvision/prototype/models/optical_flow/raft.py +++ /dev/null @@ -1,249 +0,0 @@ -from typing import Optional - -from torch.nn.modules.batchnorm import BatchNorm2d -from torch.nn.modules.instancenorm import InstanceNorm2d -from torchvision.models._api import Weights -from torchvision.models._api import WeightsEnum -from torchvision.models._utils import handle_legacy_interface -from torchvision.models.optical_flow import RAFT -from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock -from torchvision.transforms import OpticalFlowEval, InterpolationMode - - -__all__ = ( - "RAFT", - "raft_large", - "raft_small", - "Raft_Large_Weights", - "Raft_Small_Weights", -) - - -_COMMON_META = { - "task": "optical_flow", - "architecture": "RAFT", - "publication_year": 2020, - "interpolation": InterpolationMode.BILINEAR, -} - - -class Raft_Large_Weights(WeightsEnum): - C_T_V1 = Weights( - # Chairs + Things, ported from original paper repo (raft-things.pth) - url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 5257536, - "recipe": "https://github.com/princeton-vl/RAFT", - "sintel_train_cleanpass_epe": 1.4411, - "sintel_train_finalpass_epe": 2.7894, - "kitti_train_per_image_epe": 5.0172, - "kitti_train_f1-all": 17.4506, - }, - ) - - C_T_V2 = Weights( - # Chairs + Things - url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 5257536, - "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", - "sintel_train_cleanpass_epe": 1.3822, - "sintel_train_finalpass_epe": 2.7161, - "kitti_train_per_image_epe": 4.5118, - 
"kitti_train_f1-all": 16.0679, - }, - ) - - C_T_SKHT_V1 = Weights( - # Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth) - url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 5257536, - "recipe": "https://github.com/princeton-vl/RAFT", - "sintel_test_cleanpass_epe": 1.94, - "sintel_test_finalpass_epe": 3.18, - }, - ) - - C_T_SKHT_V2 = Weights( - # Chairs + Things + Sintel fine-tuning, i.e.: - # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) - # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel - url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 5257536, - "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", - "sintel_test_cleanpass_epe": 1.819, - "sintel_test_finalpass_epe": 3.067, - }, - ) - - C_T_SKHT_K_V1 = Weights( - # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth) - url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 5257536, - "recipe": "https://github.com/princeton-vl/RAFT", - "kitti_test_f1-all": 5.10, - }, - ) - - C_T_SKHT_K_V2 = Weights( - # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.: - # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti - # Same as CT_SKHT with extra fine-tuning on Kitti - # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti - url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 5257536, - "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", - "kitti_test_f1-all": 5.19, - }, - ) - - DEFAULT = C_T_SKHT_V2 - - -class Raft_Small_Weights(WeightsEnum): - C_T_V1 = Weights( - # Chairs + Things, ported from original paper repo (raft-small.pth) - url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 990162, - "recipe": "https://github.com/princeton-vl/RAFT", - "sintel_train_cleanpass_epe": 2.1231, - "sintel_train_finalpass_epe": 3.2790, - "kitti_train_per_image_epe": 7.6557, - "kitti_train_f1-all": 25.2801, - }, - ) - C_T_V2 = Weights( - # Chairs + Things - url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", - transforms=OpticalFlowEval, - meta={ - **_COMMON_META, - "num_params": 990162, - "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow", - "sintel_train_cleanpass_epe": 1.9901, - "sintel_train_finalpass_epe": 3.2831, - "kitti_train_per_image_epe": 7.5978, - "kitti_train_f1-all": 25.2369, - }, - ) - - DEFAULT = C_T_V2 - - -@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2)) -def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs): - """RAFT model from - `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. - - Args: - weights(Raft_Large_weights, optional): pretrained weights to use. - progress (bool): If True, displays a progress bar of the download to stderr - kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class - to override any default. 
- - Returns: - nn.Module: The model. - """ - - weights = Raft_Large_Weights.verify(weights) - - model = _raft( - # Feature encoder - feature_encoder_layers=(64, 64, 96, 128, 256), - feature_encoder_block=ResidualBlock, - feature_encoder_norm_layer=InstanceNorm2d, - # Context encoder - context_encoder_layers=(64, 64, 96, 128, 256), - context_encoder_block=ResidualBlock, - context_encoder_norm_layer=BatchNorm2d, - # Correlation block - corr_block_num_levels=4, - corr_block_radius=4, - # Motion encoder - motion_encoder_corr_layers=(256, 192), - motion_encoder_flow_layers=(128, 64), - motion_encoder_out_channels=128, - # Recurrent block - recurrent_block_hidden_state_size=128, - recurrent_block_kernel_size=((1, 5), (5, 1)), - recurrent_block_padding=((0, 2), (2, 0)), - # Flow head - flow_head_hidden_size=256, - # Mask predictor - use_mask_predictor=True, - **kwargs, - ) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2)) -def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs): - """RAFT "small" model from - `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. - - Args: - weights(Raft_Small_weights, optional): pretrained weights to use. - progress (bool): If True, displays a progress bar of the download to stderr - kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class - to override any default. - - Returns: - nn.Module: The model. - - """ - - weights = Raft_Small_Weights.verify(weights) - - model = _raft( - # Feature encoder - feature_encoder_layers=(32, 32, 64, 96, 128), - feature_encoder_block=BottleneckBlock, - feature_encoder_norm_layer=InstanceNorm2d, - # Context encoder - context_encoder_layers=(32, 32, 64, 96, 160), - context_encoder_block=BottleneckBlock, - context_encoder_norm_layer=None, - # Correlation block - corr_block_num_levels=4, - corr_block_radius=3, - # Motion encoder - motion_encoder_corr_layers=(96,), - motion_encoder_flow_layers=(64, 32), - motion_encoder_out_channels=82, - # Recurrent block - recurrent_block_hidden_state_size=96, - recurrent_block_kernel_size=(3,), - recurrent_block_padding=(1,), - # Flow head - flow_head_hidden_size=128, - # Mask predictor - use_mask_predictor=False, - **kwargs, - ) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - return model diff --git a/torchvision/prototype/models/segmentation/__init__.py b/torchvision/prototype/models/segmentation/__init__.py deleted file mode 100644 index 20273be2170..00000000000 --- a/torchvision/prototype/models/segmentation/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .fcn import * -from .lraspp import * -from .deeplabv3 import * diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py deleted file mode 100644 index 2c8d7f6ad84..00000000000 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ /dev/null @@ -1,171 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _VOC_CATEGORIES -from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param -from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large -from torchvision.models.resnet import resnet50, resnet101, ResNet50_Weights, 
ResNet101_Weights -from torchvision.models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet -from torchvision.transforms import SemanticSegmentationEval, InterpolationMode - - -__all__ = [ - "DeepLabV3", - "DeepLabV3_ResNet50_Weights", - "DeepLabV3_ResNet101_Weights", - "DeepLabV3_MobileNet_V3_Large_Weights", - "deeplabv3_mobilenet_v3_large", - "deeplabv3_resnet50", - "deeplabv3_resnet101", -] - - -_COMMON_META = { - "task": "image_semantic_segmentation", - "architecture": "DeepLabV3", - "publication_year": 2017, - "categories": _VOC_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -class DeepLabV3_ResNet50_Weights(WeightsEnum): - COCO_WITH_VOC_LABELS_V1 = Weights( - url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), - meta={ - **_COMMON_META, - "num_params": 42004074, - "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50", - "mIoU": 66.4, - "acc": 92.4, - }, - ) - DEFAULT = COCO_WITH_VOC_LABELS_V1 - - -class DeepLabV3_ResNet101_Weights(WeightsEnum): - COCO_WITH_VOC_LABELS_V1 = Weights( - url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), - meta={ - **_COMMON_META, - "num_params": 60996202, - "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101", - "mIoU": 67.4, - "acc": 92.4, - }, - ) - DEFAULT = COCO_WITH_VOC_LABELS_V1 - - -class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): - COCO_WITH_VOC_LABELS_V1 = Weights( - url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), - meta={ - **_COMMON_META, - "num_params": 11029328, - "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_large", - "mIoU": 60.3, - "acc": 91.2, - }, - ) - DEFAULT = COCO_WITH_VOC_LABELS_V1 - - -@handle_legacy_interface( - weights=("pretrained", DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1), - weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), -) -def deeplabv3_resnet50( - *, - weights: Optional[DeepLabV3_ResNet50_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet50_Weights] = None, - **kwargs: Any, -) -> DeepLabV3: - weights = DeepLabV3_ResNet50_Weights.verify(weights) - weights_backbone = ResNet50_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - aux_loss = _ovewrite_value_param(aux_loss, True) - elif num_classes is None: - num_classes = 21 - - backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) - model = _deeplabv3_resnet(backbone, num_classes, aux_loss) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface( - weights=("pretrained", DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1), - weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1), -) -def deeplabv3_resnet101( - *, - weights: Optional[DeepLabV3_ResNet101_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - aux_loss: Optional[bool] = None, - weights_backbone: 
Optional[ResNet101_Weights] = None, - **kwargs: Any, -) -> DeepLabV3: - weights = DeepLabV3_ResNet101_Weights.verify(weights) - weights_backbone = ResNet101_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - aux_loss = _ovewrite_value_param(aux_loss, True) - elif num_classes is None: - num_classes = 21 - - backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) - model = _deeplabv3_resnet(backbone, num_classes, aux_loss) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface( - weights=("pretrained", DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1), - weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), -) -def deeplabv3_mobilenet_v3_large( - *, - weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - aux_loss: Optional[bool] = None, - weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, - **kwargs: Any, -) -> DeepLabV3: - weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights) - weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - aux_loss = _ovewrite_value_param(aux_loss, True) - elif num_classes is None: - num_classes = 21 - - backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True) - model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/segmentation/fcn.py b/torchvision/prototype/models/segmentation/fcn.py deleted file mode 100644 index e7b12621940..00000000000 --- a/torchvision/prototype/models/segmentation/fcn.py +++ /dev/null @@ -1,115 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _VOC_CATEGORIES -from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param -from torchvision.models.resnet import ResNet50_Weights, ResNet101_Weights, resnet50, resnet101 -from torchvision.models.segmentation.fcn import FCN, _fcn_resnet -from torchvision.transforms import SemanticSegmentationEval, InterpolationMode - - -__all__ = ["FCN", "FCN_ResNet50_Weights", "FCN_ResNet101_Weights", "fcn_resnet50", "fcn_resnet101"] - - -_COMMON_META = { - "task": "image_semantic_segmentation", - "architecture": "FCN", - "publication_year": 2014, - "categories": _VOC_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, -} - - -class FCN_ResNet50_Weights(WeightsEnum): - COCO_WITH_VOC_LABELS_V1 = Weights( - url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), - meta={ - **_COMMON_META, - "num_params": 35322218, - "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet50", - "mIoU": 60.5, - "acc": 91.4, - }, - ) - DEFAULT = COCO_WITH_VOC_LABELS_V1 - - -class FCN_ResNet101_Weights(WeightsEnum): - COCO_WITH_VOC_LABELS_V1 = Weights( - url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", - 
transforms=partial(SemanticSegmentationEval, resize_size=520), - meta={ - **_COMMON_META, - "num_params": 54314346, - "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet101", - "mIoU": 63.7, - "acc": 91.9, - }, - ) - DEFAULT = COCO_WITH_VOC_LABELS_V1 - - -@handle_legacy_interface( - weights=("pretrained", FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1), - weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), -) -def fcn_resnet50( - *, - weights: Optional[FCN_ResNet50_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet50_Weights] = None, - **kwargs: Any, -) -> FCN: - weights = FCN_ResNet50_Weights.verify(weights) - weights_backbone = ResNet50_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - aux_loss = _ovewrite_value_param(aux_loss, True) - elif num_classes is None: - num_classes = 21 - - backbone = resnet50(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) - model = _fcn_resnet(backbone, num_classes, aux_loss) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -@handle_legacy_interface( - weights=("pretrained", FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1), - weights_backbone=("pretrained_backbone", ResNet101_Weights.IMAGENET1K_V1), -) -def fcn_resnet101( - *, - weights: Optional[FCN_ResNet101_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet101_Weights] = None, - **kwargs: Any, -) -> FCN: - weights = FCN_ResNet101_Weights.verify(weights) - weights_backbone = ResNet101_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - aux_loss = _ovewrite_value_param(aux_loss, True) - elif num_classes is None: - num_classes = 21 - - backbone = resnet101(weights=weights_backbone, replace_stride_with_dilation=[False, True, True]) - model = _fcn_resnet(backbone, num_classes, aux_loss) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/segmentation/lraspp.py b/torchvision/prototype/models/segmentation/lraspp.py deleted file mode 100644 index 21c15373089..00000000000 --- a/torchvision/prototype/models/segmentation/lraspp.py +++ /dev/null @@ -1,64 +0,0 @@ -from functools import partial -from typing import Any, Optional - -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _VOC_CATEGORIES -from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param -from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large -from torchvision.models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3 -from torchvision.transforms import SemanticSegmentationEval, InterpolationMode - - -__all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"] - - -class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): - COCO_WITH_VOC_LABELS_V1 = Weights( - url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), - meta={ - "task": 
"image_semantic_segmentation", - "architecture": "LRASPP", - "publication_year": 2019, - "num_params": 3221538, - "categories": _VOC_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large", - "mIoU": 57.9, - "acc": 91.2, - }, - ) - DEFAULT = COCO_WITH_VOC_LABELS_V1 - - -@handle_legacy_interface( - weights=("pretrained", LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1), - weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), -) -def lraspp_mobilenet_v3_large( - *, - weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None, - progress: bool = True, - num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, - **kwargs: Any, -) -> LRASPP: - if kwargs.pop("aux_loss", False): - raise NotImplementedError("This model does not use auxiliary loss") - - weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights) - weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) - - if weights is not None: - weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - elif num_classes is None: - num_classes = 21 - - backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True) - model = _lraspp_mobilenetv3(backbone, num_classes) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model diff --git a/torchvision/prototype/models/video/__init__.py b/torchvision/prototype/models/video/__init__.py deleted file mode 100644 index b792ca6ecf7..00000000000 --- a/torchvision/prototype/models/video/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .resnet import * diff --git a/torchvision/prototype/models/video/resnet.py b/torchvision/prototype/models/video/resnet.py deleted file mode 100644 index 0f4c0dd1dc9..00000000000 --- a/torchvision/prototype/models/video/resnet.py +++ /dev/null @@ -1,150 +0,0 @@ -from functools import partial -from typing import Any, Callable, List, Optional, Sequence, Type, Union - -from torch import nn -from torchvision.models._api import WeightsEnum, Weights -from torchvision.models._meta import _KINETICS400_CATEGORIES -from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param -from torchvision.models.video.resnet import ( - BasicBlock, - BasicStem, - Bottleneck, - Conv2Plus1D, - Conv3DSimple, - Conv3DNoTemporal, - R2Plus1dStem, - VideoResNet, -) -from torchvision.transforms import VideoClassificationEval, InterpolationMode - - -__all__ = [ - "VideoResNet", - "R3D_18_Weights", - "MC3_18_Weights", - "R2Plus1D_18_Weights", - "r3d_18", - "mc3_18", - "r2plus1d_18", -] - - -def _video_resnet( - block: Type[Union[BasicBlock, Bottleneck]], - conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]], - layers: List[int], - stem: Callable[..., nn.Module], - weights: Optional[WeightsEnum], - progress: bool, - **kwargs: Any, -) -> VideoResNet: - if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - - model = VideoResNet(block, conv_makers, layers, stem, **kwargs) - - if weights is not None: - model.load_state_dict(weights.get_state_dict(progress=progress)) - - return model - - -_COMMON_META = { - "task": "video_classification", - "publication_year": 2017, - "size": (112, 112), - "min_size": (1, 1), - "categories": _KINETICS400_CATEGORIES, - "interpolation": InterpolationMode.BILINEAR, - 
"recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification", -} - - -class R3D_18_Weights(WeightsEnum): - KINETICS400_V1 = Weights( - url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", - transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), - meta={ - **_COMMON_META, - "architecture": "R3D", - "num_params": 33371472, - "acc@1": 52.75, - "acc@5": 75.45, - }, - ) - DEFAULT = KINETICS400_V1 - - -class MC3_18_Weights(WeightsEnum): - KINETICS400_V1 = Weights( - url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", - transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), - meta={ - **_COMMON_META, - "architecture": "MC3", - "num_params": 11695440, - "acc@1": 53.90, - "acc@5": 76.29, - }, - ) - DEFAULT = KINETICS400_V1 - - -class R2Plus1D_18_Weights(WeightsEnum): - KINETICS400_V1 = Weights( - url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", - transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), - meta={ - **_COMMON_META, - "architecture": "R(2+1)D", - "num_params": 31505325, - "acc@1": 57.50, - "acc@5": 78.81, - }, - ) - DEFAULT = KINETICS400_V1 - - -@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1)) -def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: - weights = R3D_18_Weights.verify(weights) - - return _video_resnet( - BasicBlock, - [Conv3DSimple] * 4, - [2, 2, 2, 2], - BasicStem, - weights, - progress, - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1)) -def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: - weights = MC3_18_Weights.verify(weights) - - return _video_resnet( - BasicBlock, - [Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item] - [2, 2, 2, 2], - BasicStem, - weights, - progress, - **kwargs, - ) - - -@handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1)) -def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: - weights = R2Plus1D_18_Weights.verify(weights) - - return _video_resnet( - BasicBlock, - [Conv2Plus1D] * 4, - [2, 2, 2, 2], - R2Plus1dStem, - weights, - progress, - **kwargs, - ) From 6d96ed50a3ed1cd49b9147711925784dae697ba2 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 15 Mar 2022 18:07:46 +0000 Subject: [PATCH 31/45] Porting docs, examples, tutorials and galleries (#5620) * Fix examples, tutorials and gallery * Update gallery/plot_optical_flow.py Co-authored-by: Nicolas Hug * Fix import * Revert hardcoded normalization * fix uncommitted changes * Fix bug * Fix more bugs * Making resize optional for segmentation * Fixing preset * Fix mypy * Fixing documentation strings * Fix flake8 * minor refactoring Co-authored-by: Nicolas Hug --- android/test_app/make_assets.py | 13 +++++-- examples/cpp/hello_world/trace_model.py | 2 +- gallery/plot_optical_flow.py | 34 +++++++--------- gallery/plot_repurposing_annotations.py | 8 ++-- gallery/plot_scripted_tensor_transforms.py | 12 ++---- gallery/plot_visualization_utils.py | 39 +++++++++++++------ ios/VisionTestApp/make_assets.py | 13 +++++-- test/tracing/frcnn/trace_model.py | 2 +- torchvision/models/_utils.py | 2 +- .../models/detection/backbone_utils.py | 27 ++++++++----- torchvision/models/detection/faster_rcnn.py | 8 ++-- 
torchvision/models/detection/fcos.py | 4 +- torchvision/models/detection/keypoint_rcnn.py | 4 +- torchvision/models/detection/mask_rcnn.py | 4 +- torchvision/models/detection/retinanet.py | 4 +- torchvision/models/detection/ssd.py | 2 +- torchvision/models/detection/ssdlite.py | 2 +- torchvision/models/googlenet.py | 2 +- torchvision/models/inception.py | 2 +- torchvision/transforms/_presets.py | 12 +++--- 20 files changed, 115 insertions(+), 81 deletions(-) diff --git a/android/test_app/make_assets.py b/android/test_app/make_assets.py index fedee39fc52..f99933e9a9d 100644 --- a/android/test_app/make_assets.py +++ b/android/test_app/make_assets.py @@ -1,11 +1,18 @@ import torch -import torchvision from torch.utils.mobile_optimizer import optimize_for_mobile +from torchvision.models.detection import ( + fasterrcnn_mobilenet_v3_large_320_fpn, + FasterRCNN_MobileNet_V3_Large_320_FPN_Weights, +) print(torch.__version__) -model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn( - pretrained=True, box_score_thresh=0.7, rpn_post_nms_top_n_test=100, rpn_score_thresh=0.4, rpn_pre_nms_top_n_test=150 +model = fasterrcnn_mobilenet_v3_large_320_fpn( + weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT, + box_score_thresh=0.7, + rpn_post_nms_top_n_test=100, + rpn_score_thresh=0.4, + rpn_pre_nms_top_n_test=150, ) model.eval() diff --git a/examples/cpp/hello_world/trace_model.py b/examples/cpp/hello_world/trace_model.py index c8b8d6911e7..41bbaf8b6dd 100644 --- a/examples/cpp/hello_world/trace_model.py +++ b/examples/cpp/hello_world/trace_model.py @@ -6,7 +6,7 @@ HERE = osp.dirname(osp.abspath(__file__)) ASSETS = osp.dirname(osp.dirname(HERE)) -model = torchvision.models.resnet18(pretrained=False) +model = torchvision.models.resnet18() model.eval() traced_model = torch.jit.script(model) diff --git a/gallery/plot_optical_flow.py b/gallery/plot_optical_flow.py index 505334f36da..770610fb971 100644 --- a/gallery/plot_optical_flow.py +++ b/gallery/plot_optical_flow.py @@ -19,7 +19,6 @@ import torch import matplotlib.pyplot as plt import torchvision.transforms.functional as F -import torchvision.transforms as T plt.rcParams["savefig.bbox"] = "tight" @@ -88,24 +87,19 @@ def plot(imgs, **imshow_kwargs): # reduce the image sizes for the example to run faster. Image dimension must be # divisible by 8. +from torchvision.models.optical_flow import Raft_Large_Weights -def preprocess(batch): - transforms = T.Compose( - [ - T.ConvertImageDtype(torch.float32), - T.Normalize(mean=0.5, std=0.5), # map [0, 1] into [-1, 1] - T.Resize(size=(520, 960)), - ] - ) - batch = transforms(batch) - return batch +weights = Raft_Large_Weights.DEFAULT +transforms = weights.transforms() -# If you can, run this example on a GPU, it will be a lot faster. -device = "cuda" if torch.cuda.is_available() else "cpu" +def preprocess(img1_batch, img2_batch): + img1_batch = F.resize(img1_batch, size=[520, 960]) + img2_batch = F.resize(img2_batch, size=[520, 960]) + return transforms(img1_batch, img2_batch)[:2] + -img1_batch = preprocess(img1_batch).to(device) -img2_batch = preprocess(img2_batch).to(device) +img1_batch, img2_batch = preprocess(img1_batch, img2_batch) print(f"shape = {img1_batch.shape}, dtype = {img1_batch.dtype}") @@ -121,7 +115,10 @@ def preprocess(batch): from torchvision.models.optical_flow import raft_large -model = raft_large(pretrained=True, progress=False).to(device) +# If you can, run this example on a GPU, it will be a lot faster. 
+device = "cuda" if torch.cuda.is_available() else "cpu" + +model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device) model = model.eval() list_of_flows = model(img1_batch.to(device), img2_batch.to(device)) @@ -182,10 +179,9 @@ def preprocess(batch): # from torchvision.io import write_jpeg # for i, (img1, img2) in enumerate(zip(frames, frames[1:])): # # Note: it would be faster to predict batches of flows instead of individual flows -# img1 = preprocess(img1[None]).to(device) -# img2 = preprocess(img2[None]).to(device) +# img1, img2 = preprocess(img1, img2) -# list_of_flows = model(img1_batch, img2_batch) +# list_of_flows = model(img1.to(device), img1.to(device)) # predicted_flow = list_of_flows[-1][0] # flow_img = flow_to_image(predicted_flow).to("cpu") # output_folder = "/tmp/" # Update this to the folder of your choice diff --git a/gallery/plot_repurposing_annotations.py b/gallery/plot_repurposing_annotations.py index fb4835496c3..a826a2523f2 100644 --- a/gallery/plot_repurposing_annotations.py +++ b/gallery/plot_repurposing_annotations.py @@ -139,12 +139,14 @@ def show(imgs): # Here is demo with a Faster R-CNN model loaded from # :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` -from torchvision.models.detection import fasterrcnn_resnet50_fpn +from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights -model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False) +weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT +model = fasterrcnn_resnet50_fpn(weights=weights, progress=False) print(img.size()) -img = F.convert_image_dtype(img, torch.float) +tranforms = weights.transforms() +img, _ = tranforms(img) target = {} target["boxes"] = boxes target["labels"] = labels = torch.ones((masks.size(0),), dtype=torch.int64) diff --git a/gallery/plot_scripted_tensor_transforms.py b/gallery/plot_scripted_tensor_transforms.py index a9205536821..995383d4603 100644 --- a/gallery/plot_scripted_tensor_transforms.py +++ b/gallery/plot_scripted_tensor_transforms.py @@ -85,20 +85,16 @@ def show(imgs): # Let's define a ``Predictor`` module that transforms the input tensor and then # applies an ImageNet model on it. -from torchvision.models import resnet18 +from torchvision.models import resnet18, ResNet18_Weights class Predictor(nn.Module): def __init__(self): super().__init__() - self.resnet18 = resnet18(pretrained=True, progress=False).eval() - self.transforms = nn.Sequential( - T.Resize([256, ]), # We use single int value inside a list due to torchscript type restrictions - T.CenterCrop(224), - T.ConvertImageDtype(torch.float), - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - ) + weights = ResNet18_Weights.DEFAULT + self.resnet18 = resnet18(weights=weights, progress=False).eval() + self.transforms = weights.transforms() def forward(self, x: torch.Tensor) -> torch.Tensor: with torch.no_grad(): diff --git a/gallery/plot_visualization_utils.py b/gallery/plot_visualization_utils.py index 526c8c32493..27fd97681c0 100644 --- a/gallery/plot_visualization_utils.py +++ b/gallery/plot_visualization_utils.py @@ -73,14 +73,17 @@ def show(imgs): # :func:`~torchvision.models.detection.ssd300_vgg16`. For more details # on the output of such models, you may refer to :ref:`instance_seg_output`. 
-from torchvision.models.detection import fasterrcnn_resnet50_fpn -from torchvision.transforms.functional import convert_image_dtype +from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights batch_int = torch.stack([dog1_int, dog2_int]) -batch = convert_image_dtype(batch_int, dtype=torch.float) -model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False) +weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT +transforms = weights.transforms() + +batch, _ = transforms(batch_int) + +model = fasterrcnn_resnet50_fpn(weights=weights, progress=False) model = model.eval() outputs = model(batch) @@ -120,13 +123,15 @@ def show(imgs): # images must be normalized before they're passed to a semantic segmentation # model. -from torchvision.models.segmentation import fcn_resnet50 +from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights +weights = FCN_ResNet50_Weights.DEFAULT +transforms = weights.transforms(resize_size=None) -model = fcn_resnet50(pretrained=True, progress=False) +model = fcn_resnet50(weights=weights, progress=False) model = model.eval() -normalized_batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) +normalized_batch, _ = transforms(batch) output = model(normalized_batch)['out'] print(output.shape, output.min().item(), output.max().item()) @@ -262,8 +267,14 @@ def show(imgs): # of them may not have masks, like # :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`. -from torchvision.models.detection import maskrcnn_resnet50_fpn -model = maskrcnn_resnet50_fpn(pretrained=True, progress=False) +from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights + +weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT +transforms = weights.transforms() + +batch, _ = transforms(batch_int) + +model = maskrcnn_resnet50_fpn(weights=weights, progress=False) model = model.eval() output = model(batch) @@ -378,13 +389,17 @@ def show(imgs): # Note that the keypoint detection model does not need normalized images. 
# -from torchvision.models.detection import keypointrcnn_resnet50_fpn +from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights from torchvision.io import read_image person_int = read_image(str(Path("assets") / "person1.jpg")) -person_float = convert_image_dtype(person_int, dtype=torch.float) -model = keypointrcnn_resnet50_fpn(pretrained=True, progress=False) +weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT +transforms = weights.transforms() + +person_float, _ = transforms(person_int) + +model = keypointrcnn_resnet50_fpn(weights=weights, progress=False) model = model.eval() outputs = model([person_float]) diff --git a/ios/VisionTestApp/make_assets.py b/ios/VisionTestApp/make_assets.py index 0f46364569b..f14223e6a42 100644 --- a/ios/VisionTestApp/make_assets.py +++ b/ios/VisionTestApp/make_assets.py @@ -1,11 +1,18 @@ import torch -import torchvision from torch.utils.mobile_optimizer import optimize_for_mobile +from torchvision.models.detection import ( + fasterrcnn_mobilenet_v3_large_320_fpn, + FasterRCNN_MobileNet_V3_Large_320_FPN_Weights, +) print(torch.__version__) -model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn( - pretrained=True, box_score_thresh=0.7, rpn_post_nms_top_n_test=100, rpn_score_thresh=0.4, rpn_pre_nms_top_n_test=150 +model = fasterrcnn_mobilenet_v3_large_320_fpn( + weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT, + box_score_thresh=0.7, + rpn_post_nms_top_n_test=100, + rpn_score_thresh=0.4, + rpn_pre_nms_top_n_test=150, ) model.eval() diff --git a/test/tracing/frcnn/trace_model.py b/test/tracing/frcnn/trace_model.py index 8cc1d344936..768954d29b2 100644 --- a/test/tracing/frcnn/trace_model.py +++ b/test/tracing/frcnn/trace_model.py @@ -6,7 +6,7 @@ HERE = osp.dirname(osp.abspath(__file__)) ASSETS = osp.dirname(osp.dirname(HERE)) -model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False) +model = torchvision.models.detection.fasterrcnn_resnet50_fpn() model.eval() traced_model = torch.jit.script(model) diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index 9e3a81411a1..08c878a8a67 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -32,7 +32,7 @@ class IntermediateLayerGetter(nn.ModuleDict): Examples:: - >>> m = torchvision.models.resnet18(pretrained=True) + >>> m = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT) >>> # extract layer1 and layer3, giving as names `feat1` and feat2` >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, >>> {'layer1': 'feat1', 'layer3': 'feat2'}) diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index cac96b61f64..b767756692b 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -6,7 +6,8 @@ from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool from .. 
import mobilenet, resnet -from .._utils import IntermediateLayerGetter +from .._api import WeightsEnum +from .._utils import IntermediateLayerGetter, handle_legacy_interface class BackboneWithFPN(nn.Module): @@ -55,9 +56,13 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: return x +@handle_legacy_interface( + weights=("pretrained", True), # type: ignore[arg-type] +) def resnet_fpn_backbone( + *, backbone_name: str, - pretrained: bool, + weights: Optional[WeightsEnum], norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d, trainable_layers: int = 3, returned_layers: Optional[List[int]] = None, @@ -69,7 +74,7 @@ def resnet_fpn_backbone( Examples:: >>> from torchvision.models.detection.backbone_utils import resnet_fpn_backbone - >>> backbone = resnet_fpn_backbone('resnet50', pretrained=True, trainable_layers=3) + >>> backbone = resnet_fpn_backbone('resnet50', weights=ResNet50_Weights.DEFAULT, trainable_layers=3) >>> # get some dummy image >>> x = torch.rand(1,3,64,64) >>> # compute the output @@ -85,7 +90,7 @@ def resnet_fpn_backbone( Args: backbone_name (string): resnet architecture. Possible values are 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2' - pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet + weights (WeightsEnum, optional): The pretrained weights for the model norm_layer (callable): it is recommended to use the default value. For details visit: (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267) trainable_layers (int): number of trainable (not frozen) layers starting from final block. @@ -98,7 +103,7 @@ def resnet_fpn_backbone( a new list of feature maps and their corresponding names. By default a ``LastLevelMaxPool`` is used. 
""" - backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer) + backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer) return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks) @@ -135,13 +140,13 @@ def _resnet_fpn_extractor( def _validate_trainable_layers( - pretrained: bool, + is_trained: bool, trainable_backbone_layers: Optional[int], max_value: int, default_value: int, ) -> int: # don't freeze any layers if pretrained model or backbone is not used - if not pretrained: + if not is_trained: if trainable_backbone_layers is not None: warnings.warn( "Changing trainable_backbone_layers has not effect if " @@ -160,16 +165,20 @@ def _validate_trainable_layers( return trainable_backbone_layers +@handle_legacy_interface( + weights=("pretrained", True), # type: ignore[arg-type] +) def mobilenet_backbone( + *, backbone_name: str, - pretrained: bool, + weights: Optional[WeightsEnum], fpn: bool, norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d, trainable_layers: int = 2, returned_layers: Optional[List[int]] = None, extra_blocks: Optional[ExtraFPNBlock] = None, ) -> nn.Module: - backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer) + backbone = mobilenet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer) return _mobilenet_extractor(backbone, fpn, trainable_layers, returned_layers, extra_blocks) diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index 18872adc029..efbd88906f2 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -117,7 +117,7 @@ class FasterRCNN(GeneralizedRCNN): >>> from torchvision.models.detection.rpn import AnchorGenerator >>> # load a pre-trained model for classification and return >>> # only the features - >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features >>> # FasterRCNN needs to know the number of >>> # output channels in a backbone. 
For mobilenet_v2, it's 1280 >>> # so we need to add it here @@ -415,7 +415,7 @@ def fasterrcnn_resnet50_fpn( Example:: - >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT) >>> # For training >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4) >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4] @@ -532,7 +532,7 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( Example:: - >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=True) + >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) @@ -589,7 +589,7 @@ def fasterrcnn_mobilenet_v3_large_fpn( Example:: - >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True) + >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 8d110d809f7..5627573836a 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -325,7 +325,7 @@ class FCOS(nn.Module): >>> from torchvision.models.detection.anchor_utils import AnchorGenerator >>> # load a pre-trained model for classification and return >>> # only the features - >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features >>> # FCOS needs to know the number of >>> # output channels in a backbone. For mobilenet_v2, it's 1280 >>> # so we need to add it here @@ -697,7 +697,7 @@ def fcos_resnet50_fpn( Example: - >>> model = torchvision.models.detection.fcos_resnet50_fpn(pretrained=True) + >>> model = torchvision.models.detection.fcos_resnet50_fpn(weights=FCOS_ResNet50_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 3794b253ec7..272a6c3debe 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -119,7 +119,7 @@ class KeypointRCNN(FasterRCNN): >>> >>> # load a pre-trained model for classification and return >>> # only the features - >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features >>> # KeypointRCNN needs to know the number of >>> # output channels in a backbone. 
For mobilenet_v2, it's 1280 >>> # so we need to add it here @@ -395,7 +395,7 @@ def keypointrcnn_resnet50_fpn( Example:: - >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=True) + >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 38ba82af01d..04652d5a66f 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -118,7 +118,7 @@ class MaskRCNN(FasterRCNN): >>> >>> # load a pre-trained model for classification and return >>> # only the features - >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features >>> # MaskRCNN needs to know the number of >>> # output channels in a backbone. For mobilenet_v2, it's 1280 >>> # so we need to add it here @@ -373,7 +373,7 @@ def maskrcnn_resnet50_fpn( Example:: - >>> model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) + >>> model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index b1c371583bf..da3a521c36a 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -293,7 +293,7 @@ class RetinaNet(nn.Module): >>> from torchvision.models.detection.anchor_utils import AnchorGenerator >>> # load a pre-trained model for classification and return >>> # only the features - >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features >>> # RetinaNet needs to know the number of >>> # output channels in a backbone. 
For mobilenet_v2, it's 1280 >>> # so we need to add it here @@ -642,7 +642,7 @@ def retinanet_resnet50_fpn( Example:: - >>> model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True) + >>> model = torchvision.models.detection.retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index cf3becc5fc4..32158ecfab3 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -596,7 +596,7 @@ def ssd300_vgg16( Example: - >>> model = torchvision.models.detection.ssd300_vgg16(pretrained=True) + >>> model = torchvision.models.detection.ssd300_vgg16(weights=SSD300_VGG16_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 300), torch.rand(3, 500, 400)] >>> predictions = model(x) diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index a71da6b29ac..bb471fa6fa8 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -224,7 +224,7 @@ def ssdlite320_mobilenet_v3_large( Example: - >>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True) + >>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(weights=SSDLite320_MobileNet_V3_Large_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 320, 320), torch.rand(3, 500, 400)] >>> predictions = model(x) diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index 2cac4a4fbbd..a47b478af03 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -307,7 +307,7 @@ def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = T aux_logits (bool): If True, adds two auxiliary branches that can improve training. Default: *False* when pretrained is True otherwise *True* transform_input (bool): If True, preprocesses the input according to the method with which it - was trained on ImageNet. Default: True if ``pretrained=True``, else False. + was trained on ImageNet. Default: True if ``weights=GoogLeNet_Weights.IMAGENET1K_V1``, else False. """ weights = GoogLeNet_Weights.verify(weights) diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index 1628542482b..44b2bd56feb 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -443,7 +443,7 @@ def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bo aux_logits (bool): If True, add an auxiliary branch that can improve training. Default: *True* transform_input (bool): If True, preprocesses the input according to the method with which it - was trained on ImageNet. Default: True if ``pretrained=True``, else False. + was trained on ImageNet. Default: True if ``weights=Inception_V3_Weights.IMAGENET1K_V1``, else False. """ weights = Inception_V3_Weights.verify(weights) diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index a6b85d05597..1776d876ccb 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -78,27 +78,29 @@ def forward(self, vid: Tensor) -> Tensor: class SemanticSegmentationEval(nn.Module): def __init__( self, - resize_size: int, + resize_size: Optional[int], mean: Tuple[float, ...] = (0.485, 0.456, 0.406), std: Tuple[float, ...] 
= (0.229, 0.224, 0.225), interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation_target: InterpolationMode = InterpolationMode.NEAREST, ) -> None: super().__init__() - self._size = [resize_size] + self._size = [resize_size] if resize_size is not None else None self._mean = list(mean) self._std = list(std) self._interpolation = interpolation self._interpolation_target = interpolation_target def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: - img = F.resize(img, self._size, interpolation=self._interpolation) + if isinstance(self._size, list): + img = F.resize(img, self._size, interpolation=self._interpolation) if not isinstance(img, Tensor): img = F.pil_to_tensor(img) img = F.convert_image_dtype(img, torch.float) img = F.normalize(img, mean=self._mean, std=self._std) if target: - target = F.resize(target, self._size, interpolation=self._interpolation_target) + if isinstance(self._size, list): + target = F.resize(target, self._size, interpolation=self._interpolation_target) if not isinstance(target, Tensor): target = F.pil_to_tensor(target) target = target.squeeze(0).to(torch.int64) @@ -107,7 +109,7 @@ def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor, class OpticalFlowEval(nn.Module): def forward( - self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor] + self, img1: Tensor, img2: Tensor, flow: Optional[Tensor] = None, valid_flow_mask: Optional[Tensor] = None ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: img1, img2, flow, valid_flow_mask = self._pil_or_numpy_to_tensor(img1, img2, flow, valid_flow_mask) From c88b8dcebb7523a9726eb979179b6b6d7197c5b5 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 15 Mar 2022 18:40:20 +0000 Subject: [PATCH 32/45] Resolve conflict --- torchvision/prototype/models/detection/ssdlite.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 torchvision/prototype/models/detection/ssdlite.py diff --git a/torchvision/prototype/models/detection/ssdlite.py b/torchvision/prototype/models/detection/ssdlite.py deleted file mode 100644 index e69de29bb2d..00000000000 From f121ca7e5208f816f31ad6441ec67c6078120082 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 15 Mar 2022 21:21:57 +0000 Subject: [PATCH 33/45] Porting model tests (#5622) * Porting tests * Remove unnecessary variable * Fix linter * Move prototype to extended tests * Fix download models job --- .circleci/config.yml | 27 ++++++--- .circleci/config.yml.in | 27 ++++++--- test/test_backbone_utils.py | 20 +++---- test/test_cpp_models.py | 59 +++++++++---------- ...type_models.py => test_extended_models.py} | 57 +++++++----------- test/test_hub.py | 4 +- test/test_models.py | 13 ++-- .../test_models_detection_negative_samples.py | 22 ++----- test/test_models_detection_utils.py | 10 ++-- test/test_onnx.py | 14 +++-- .../models/detection/backbone_utils.py | 6 +- 11 files changed, 125 insertions(+), 134 deletions(-) rename test/{test_prototype_models.py => test_extended_models.py} (78%) diff --git a/.circleci/config.yml b/.circleci/config.yml index 3de0894304b..254e758ade2 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -335,6 +335,20 @@ jobs: file_or_dir: test/test_onnx.py unittest_prototype: + docker: + - image: circleci/python:3.7 + resource_class: xlarge + steps: + - checkout + - install_torchvision + - install_prototype_dependencies + - pip_install: + args: scipy pycocotools h5py + descr: Install optional 
dependencies + - run_tests_selective: + file_or_dir: test/test_prototype_*.py + + unittest_extended: docker: - image: circleci/python:3.7 resource_class: xlarge @@ -346,18 +360,14 @@ jobs: command: | sudo apt update -qy && sudo apt install -qy parallel wget mkdir -p ~/.cache/torch/hub/checkpoints - python scripts/collect_model_urls.py torchvision/prototype/models \ + python scripts/collect_model_urls.py torchvision/models \ | parallel -j0 'wget --no-verbose -O ~/.cache/torch/hub/checkpoints/`basename {}` {}\?source=ci' - install_torchvision - - install_prototype_dependencies - - pip_install: - args: scipy pycocotools h5py - descr: Install optional dependencies - run: - name: Enable prototype tests - command: echo 'export PYTORCH_TEST_WITH_PROTOTYPE=1' >> $BASH_ENV + name: Enable extended tests + command: echo 'export PYTORCH_TEST_WITH_EXTENDED=1' >> $BASH_ENV - run_tests_selective: - file_or_dir: test/test_prototype_*.py + file_or_dir: test/test_extended_*.py binary_linux_wheel: <<: *binary_common @@ -1608,6 +1618,7 @@ workflows: - unittest_torchhub - unittest_onnx - unittest_prototype + - unittest_extended - unittest_linux_cpu: cu_version: cpu name: unittest_linux_cpu_py3.7 diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index e36b368db1c..6b5b31a3e6b 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -335,6 +335,20 @@ jobs: file_or_dir: test/test_onnx.py unittest_prototype: + docker: + - image: circleci/python:3.7 + resource_class: xlarge + steps: + - checkout + - install_torchvision + - install_prototype_dependencies + - pip_install: + args: scipy pycocotools h5py + descr: Install optional dependencies + - run_tests_selective: + file_or_dir: test/test_prototype_*.py + + unittest_extended: docker: - image: circleci/python:3.7 resource_class: xlarge @@ -346,18 +360,14 @@ jobs: command: | sudo apt update -qy && sudo apt install -qy parallel wget mkdir -p ~/.cache/torch/hub/checkpoints - python scripts/collect_model_urls.py torchvision/prototype/models \ + python scripts/collect_model_urls.py torchvision/models \ | parallel -j0 'wget --no-verbose -O ~/.cache/torch/hub/checkpoints/`basename {}` {}\?source=ci' - install_torchvision - - install_prototype_dependencies - - pip_install: - args: scipy pycocotools h5py - descr: Install optional dependencies - run: - name: Enable prototype tests - command: echo 'export PYTORCH_TEST_WITH_PROTOTYPE=1' >> $BASH_ENV + name: Enable extended tests + command: echo 'export PYTORCH_TEST_WITH_EXTENDED=1' >> $BASH_ENV - run_tests_selective: - file_or_dir: test/test_prototype_*.py + file_or_dir: test/test_extended_*.py binary_linux_wheel: <<: *binary_common @@ -1094,6 +1104,7 @@ workflows: - unittest_torchhub - unittest_onnx - unittest_prototype + - unittest_extended {{ unittest_workflows() }} cmake: diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index 2c9b15d2a60..bee52c06075 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -23,30 +23,30 @@ def get_available_models(): @pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50")) def test_resnet_fpn_backbone(backbone_name): x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu") - model = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False) + model = resnet_fpn_backbone(backbone_name=backbone_name) assert isinstance(model, BackboneWithFPN) y = model(x) assert list(y.keys()) == ["0", "1", "2", "3", "pool"] with pytest.raises(ValueError, match=r"Trainable layers should be in the range"): - 
resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False, trainable_layers=6) + resnet_fpn_backbone(backbone_name=backbone_name, trainable_layers=6) with pytest.raises(ValueError, match=r"Each returned layer should be in the range"): - resnet_fpn_backbone(backbone_name, False, returned_layers=[0, 1, 2, 3]) + resnet_fpn_backbone(backbone_name=backbone_name, returned_layers=[0, 1, 2, 3]) with pytest.raises(ValueError, match=r"Each returned layer should be in the range"): - resnet_fpn_backbone(backbone_name, False, returned_layers=[2, 3, 4, 5]) + resnet_fpn_backbone(backbone_name=backbone_name, returned_layers=[2, 3, 4, 5]) @pytest.mark.parametrize("backbone_name", ("mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small")) def test_mobilenet_backbone(backbone_name): with pytest.raises(ValueError, match=r"Trainable layers should be in the range"): - mobilenet_backbone(backbone_name=backbone_name, pretrained=False, fpn=False, trainable_layers=-1) + mobilenet_backbone(backbone_name=backbone_name, fpn=False, trainable_layers=-1) with pytest.raises(ValueError, match=r"Each returned layer should be in the range"): - mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[-1, 0, 1, 2]) + mobilenet_backbone(backbone_name=backbone_name, fpn=True, returned_layers=[-1, 0, 1, 2]) with pytest.raises(ValueError, match=r"Each returned layer should be in the range"): - mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[3, 4, 5, 6]) - model_fpn = mobilenet_backbone(backbone_name, False, fpn=True) + mobilenet_backbone(backbone_name=backbone_name, fpn=True, returned_layers=[3, 4, 5, 6]) + model_fpn = mobilenet_backbone(backbone_name=backbone_name, fpn=True) assert isinstance(model_fpn, BackboneWithFPN) - model = mobilenet_backbone(backbone_name, False, fpn=False) + model = mobilenet_backbone(backbone_name=backbone_name, fpn=False) assert isinstance(model, torch.nn.Sequential) @@ -100,7 +100,7 @@ def forward(self, x): class TestFxFeatureExtraction: inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device="cpu") - model_defaults = {"num_classes": 1, "pretrained": False} + model_defaults = {"num_classes": 1} leaf_modules = [] def _create_feature_extractor(self, *args, **kwargs): diff --git a/test/test_cpp_models.py b/test/test_cpp_models.py index f7cce7b6c43..d8d0836d499 100644 --- a/test/test_cpp_models.py +++ b/test/test_cpp_models.py @@ -53,50 +53,49 @@ def read_image2(): "see https://github.com/pytorch/vision/issues/1191", ) class Tester(unittest.TestCase): - pretrained = False image = read_image1() def test_alexnet(self): - process_model(models.alexnet(self.pretrained), self.image, _C_tests.forward_alexnet, "Alexnet") + process_model(models.alexnet(), self.image, _C_tests.forward_alexnet, "Alexnet") def test_vgg11(self): - process_model(models.vgg11(self.pretrained), self.image, _C_tests.forward_vgg11, "VGG11") + process_model(models.vgg11(), self.image, _C_tests.forward_vgg11, "VGG11") def test_vgg13(self): - process_model(models.vgg13(self.pretrained), self.image, _C_tests.forward_vgg13, "VGG13") + process_model(models.vgg13(), self.image, _C_tests.forward_vgg13, "VGG13") def test_vgg16(self): - process_model(models.vgg16(self.pretrained), self.image, _C_tests.forward_vgg16, "VGG16") + process_model(models.vgg16(), self.image, _C_tests.forward_vgg16, "VGG16") def test_vgg19(self): - process_model(models.vgg19(self.pretrained), self.image, _C_tests.forward_vgg19, "VGG19") + process_model(models.vgg19(), self.image, _C_tests.forward_vgg19, "VGG19") def 
test_vgg11_bn(self): - process_model(models.vgg11_bn(self.pretrained), self.image, _C_tests.forward_vgg11bn, "VGG11BN") + process_model(models.vgg11_bn(), self.image, _C_tests.forward_vgg11bn, "VGG11BN") def test_vgg13_bn(self): - process_model(models.vgg13_bn(self.pretrained), self.image, _C_tests.forward_vgg13bn, "VGG13BN") + process_model(models.vgg13_bn(), self.image, _C_tests.forward_vgg13bn, "VGG13BN") def test_vgg16_bn(self): - process_model(models.vgg16_bn(self.pretrained), self.image, _C_tests.forward_vgg16bn, "VGG16BN") + process_model(models.vgg16_bn(), self.image, _C_tests.forward_vgg16bn, "VGG16BN") def test_vgg19_bn(self): - process_model(models.vgg19_bn(self.pretrained), self.image, _C_tests.forward_vgg19bn, "VGG19BN") + process_model(models.vgg19_bn(), self.image, _C_tests.forward_vgg19bn, "VGG19BN") def test_resnet18(self): - process_model(models.resnet18(self.pretrained), self.image, _C_tests.forward_resnet18, "Resnet18") + process_model(models.resnet18(), self.image, _C_tests.forward_resnet18, "Resnet18") def test_resnet34(self): - process_model(models.resnet34(self.pretrained), self.image, _C_tests.forward_resnet34, "Resnet34") + process_model(models.resnet34(), self.image, _C_tests.forward_resnet34, "Resnet34") def test_resnet50(self): - process_model(models.resnet50(self.pretrained), self.image, _C_tests.forward_resnet50, "Resnet50") + process_model(models.resnet50(), self.image, _C_tests.forward_resnet50, "Resnet50") def test_resnet101(self): - process_model(models.resnet101(self.pretrained), self.image, _C_tests.forward_resnet101, "Resnet101") + process_model(models.resnet101(), self.image, _C_tests.forward_resnet101, "Resnet101") def test_resnet152(self): - process_model(models.resnet152(self.pretrained), self.image, _C_tests.forward_resnet152, "Resnet152") + process_model(models.resnet152(), self.image, _C_tests.forward_resnet152, "Resnet152") def test_resnext50_32x4d(self): process_model(models.resnext50_32x4d(), self.image, _C_tests.forward_resnext50_32x4d, "ResNext50_32x4d") @@ -111,48 +110,44 @@ def test_wide_resnet101_2(self): process_model(models.wide_resnet101_2(), self.image, _C_tests.forward_wide_resnet101_2, "WideResNet101_2") def test_squeezenet1_0(self): - process_model( - models.squeezenet1_0(self.pretrained), self.image, _C_tests.forward_squeezenet1_0, "Squeezenet1.0" - ) + process_model(models.squeezenet1_0(), self.image, _C_tests.forward_squeezenet1_0, "Squeezenet1.0") def test_squeezenet1_1(self): - process_model( - models.squeezenet1_1(self.pretrained), self.image, _C_tests.forward_squeezenet1_1, "Squeezenet1.1" - ) + process_model(models.squeezenet1_1(), self.image, _C_tests.forward_squeezenet1_1, "Squeezenet1.1") def test_densenet121(self): - process_model(models.densenet121(self.pretrained), self.image, _C_tests.forward_densenet121, "Densenet121") + process_model(models.densenet121(), self.image, _C_tests.forward_densenet121, "Densenet121") def test_densenet169(self): - process_model(models.densenet169(self.pretrained), self.image, _C_tests.forward_densenet169, "Densenet169") + process_model(models.densenet169(), self.image, _C_tests.forward_densenet169, "Densenet169") def test_densenet201(self): - process_model(models.densenet201(self.pretrained), self.image, _C_tests.forward_densenet201, "Densenet201") + process_model(models.densenet201(), self.image, _C_tests.forward_densenet201, "Densenet201") def test_densenet161(self): - process_model(models.densenet161(self.pretrained), self.image, _C_tests.forward_densenet161, "Densenet161") + 
process_model(models.densenet161(), self.image, _C_tests.forward_densenet161, "Densenet161") def test_mobilenet_v2(self): - process_model(models.mobilenet_v2(self.pretrained), self.image, _C_tests.forward_mobilenetv2, "MobileNet") + process_model(models.mobilenet_v2(), self.image, _C_tests.forward_mobilenetv2, "MobileNet") def test_googlenet(self): - process_model(models.googlenet(self.pretrained), self.image, _C_tests.forward_googlenet, "GoogLeNet") + process_model(models.googlenet(), self.image, _C_tests.forward_googlenet, "GoogLeNet") def test_mnasnet0_5(self): - process_model(models.mnasnet0_5(self.pretrained), self.image, _C_tests.forward_mnasnet0_5, "MNASNet0_5") + process_model(models.mnasnet0_5(), self.image, _C_tests.forward_mnasnet0_5, "MNASNet0_5") def test_mnasnet0_75(self): - process_model(models.mnasnet0_75(self.pretrained), self.image, _C_tests.forward_mnasnet0_75, "MNASNet0_75") + process_model(models.mnasnet0_75(), self.image, _C_tests.forward_mnasnet0_75, "MNASNet0_75") def test_mnasnet1_0(self): - process_model(models.mnasnet1_0(self.pretrained), self.image, _C_tests.forward_mnasnet1_0, "MNASNet1_0") + process_model(models.mnasnet1_0(), self.image, _C_tests.forward_mnasnet1_0, "MNASNet1_0") def test_mnasnet1_3(self): - process_model(models.mnasnet1_3(self.pretrained), self.image, _C_tests.forward_mnasnet1_3, "MNASNet1_3") + process_model(models.mnasnet1_3(), self.image, _C_tests.forward_mnasnet1_3, "MNASNet1_3") def test_inception_v3(self): self.image = read_image2() - process_model(models.inception_v3(self.pretrained), self.image, _C_tests.forward_inceptionv3, "Inceptionv3") + process_model(models.inception_v3(), self.image, _C_tests.forward_inceptionv3, "Inceptionv3") if __name__ == "__main__": diff --git a/test/test_prototype_models.py b/test/test_extended_models.py similarity index 78% rename from test/test_prototype_models.py rename to test/test_extended_models.py index 65b8ffb9e40..4bfe03d1ea0 100644 --- a/test/test_prototype_models.py +++ b/test/test_extended_models.py @@ -3,22 +3,16 @@ import pytest import test_models as TM -import torchvision +from torchvision import models from torchvision.models._api import WeightsEnum, Weights from torchvision.models._utils import handle_legacy_interface run_if_test_with_prototype = pytest.mark.skipif( - os.getenv("PYTORCH_TEST_WITH_PROTOTYPE") != "1", - reason="Prototype tests are disabled by default. Set PYTORCH_TEST_WITH_PROTOTYPE=1 to run them.", + os.getenv("PYTORCH_TEST_WITH_EXTENDED") != "1", + reason="Extended tests are disabled by default. 
Set PYTORCH_TEST_WITH_EXTENDED=1 to run them.", ) -def _get_original_model(model_fn): - original_module_name = model_fn.__module__.replace(".prototype", "") - module = importlib.import_module(original_module_name) - return module.__dict__[model_fn.__name__] - - def _get_parent_module(model_fn): parent_module_name = ".".join(model_fn.__module__.split(".")[:-1]) module = importlib.import_module(parent_module_name) @@ -38,44 +32,33 @@ def _get_model_weights(model_fn): return None -def _build_model(fn, **kwargs): - try: - model = fn(**kwargs) - except ValueError as e: - msg = str(e) - if "No checkpoint is available" in msg: - pytest.skip(msg) - raise e - return model.eval() - - @pytest.mark.parametrize( "name, weight", [ - ("ResNet50_Weights.IMAGENET1K_V1", torchvision.models.ResNet50_Weights.IMAGENET1K_V1), - ("ResNet50_Weights.DEFAULT", torchvision.models.ResNet50_Weights.IMAGENET1K_V2), + ("ResNet50_Weights.IMAGENET1K_V1", models.ResNet50_Weights.IMAGENET1K_V1), + ("ResNet50_Weights.DEFAULT", models.ResNet50_Weights.IMAGENET1K_V2), ( "ResNet50_QuantizedWeights.DEFAULT", - torchvision.models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2, + models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2, ), ( "ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1", - torchvision.models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1, + models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1, ), ], ) def test_get_weight(name, weight): - assert torchvision.models.get_weight(name) == weight + assert models.get_weight(name) == weight @pytest.mark.parametrize( "model_fn", - TM.get_models_from_module(torchvision.models) - + TM.get_models_from_module(torchvision.models.detection) - + TM.get_models_from_module(torchvision.models.quantization) - + TM.get_models_from_module(torchvision.models.segmentation) - + TM.get_models_from_module(torchvision.models.video) - + TM.get_models_from_module(torchvision.models.optical_flow), + TM.get_models_from_module(models) + + TM.get_models_from_module(models.detection) + + TM.get_models_from_module(models.quantization) + + TM.get_models_from_module(models.segmentation) + + TM.get_models_from_module(models.video) + + TM.get_models_from_module(models.optical_flow), ) def test_naming_conventions(model_fn): weights_enum = _get_model_weights(model_fn) @@ -86,12 +69,12 @@ def test_naming_conventions(model_fn): @pytest.mark.parametrize( "model_fn", - TM.get_models_from_module(torchvision.models) - + TM.get_models_from_module(torchvision.models.detection) - + TM.get_models_from_module(torchvision.models.quantization) - + TM.get_models_from_module(torchvision.models.segmentation) - + TM.get_models_from_module(torchvision.models.video) - + TM.get_models_from_module(torchvision.models.optical_flow), + TM.get_models_from_module(models) + + TM.get_models_from_module(models.detection) + + TM.get_models_from_module(models.quantization) + + TM.get_models_from_module(models.segmentation) + + TM.get_models_from_module(models.video) + + TM.get_models_from_module(models.optical_flow), ) @run_if_test_with_prototype def test_schema_meta_validation(model_fn): diff --git a/test/test_hub.py b/test/test_hub.py index 5c791bf9d7a..d88c6fa2cd2 100644 --- a/test/test_hub.py +++ b/test/test_hub.py @@ -26,13 +26,13 @@ class TestHub: # Python cache as we run all hub tests in the same python process. 
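The extended-model tests above rely on each weights entry being self-describing: an enum member carries its download URL, its inference preset and a meta dictionary, and can be recovered from its string name. A short sketch of what test_get_weight and test_schema_meta_validation exercise, under the same assumption of a build that includes these patches:

from torchvision import models

# Each member of a weights enum exposes its metadata directly.
for w in models.ResNet50_Weights:
    print(w, w.meta["num_params"], w.meta["acc@1"])

# String-based lookup, mirroring test_get_weight above.
assert models.get_weight("ResNet50_Weights.IMAGENET1K_V1") is models.ResNet50_Weights.IMAGENET1K_V1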
def test_load_from_github(self): - hub_model = hub.load("pytorch/vision", "resnet18", pretrained=True, progress=False) + hub_model = hub.load("pytorch/vision", "resnet18", weights="DEFAULT", progress=False) assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS) def test_set_dir(self): temp_dir = tempfile.gettempdir() hub.set_dir(temp_dir) - hub_model = hub.load("pytorch/vision", "resnet18", pretrained=True, progress=False) + hub_model = hub.load("pytorch/vision", "resnet18", weights="DEFAULT", progress=False) assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS) assert os.path.exists(temp_dir + "/pytorch_vision_master") shutil.rmtree(temp_dir + "/pytorch_vision_master") diff --git a/test/test_models.py b/test/test_models.py index fb024c8da3f..4f24129753f 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -428,7 +428,6 @@ def get_gn(num_channels): def test_inception_v3_eval(): - # replacement for models.inception_v3(pretrained=True) that does not download weights kwargs = {} kwargs["transform_input"] = True kwargs["aux_logits"] = True @@ -444,7 +443,7 @@ def test_inception_v3_eval(): def test_fasterrcnn_double(): - model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False) + model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50) model.double() model.eval() input_shape = (3, 300, 300) @@ -460,7 +459,6 @@ def test_fasterrcnn_double(): def test_googlenet_eval(): - # replacement for models.googlenet(pretrained=True) that does not download weights kwargs = {} kwargs["transform_input"] = True kwargs["aux_logits"] = True @@ -484,7 +482,7 @@ def checkOut(out): assert "scores" in out[0] assert "labels" in out[0] - model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False) + model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50) model.cuda() model.eval() input_shape = (3, 300, 300) @@ -600,7 +598,6 @@ def test_segmentation_model(model_fn, dev): set_rng_seed(0) defaults = { "num_classes": 10, - "pretrained_backbone": False, "input_shape": (1, 3, 32, 32), } model_name = model_fn.__name__ @@ -662,7 +659,6 @@ def test_detection_model(model_fn, dev): set_rng_seed(0) defaults = { "num_classes": 50, - "pretrained_backbone": False, "input_shape": (3, 300, 300), } model_name = model_fn.__name__ @@ -757,7 +753,7 @@ def compute_mean_std(tensor): @pytest.mark.parametrize("model_fn", get_models_from_module(models.detection)) def test_detection_model_validation(model_fn): set_rng_seed(0) - model = model_fn(num_classes=50, pretrained_backbone=False) + model = model_fn(num_classes=50) input_shape = (3, 300, 300) x = [torch.rand(input_shape)] @@ -821,7 +817,6 @@ def test_quantized_classification_model(model_fn): defaults = { "num_classes": 5, "input_shape": (1, 3, 224, 224), - "pretrained": False, "quantize": True, } model_name = model_fn.__name__ @@ -871,7 +866,7 @@ def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_load max_trainable = _model_tests_values[model_name]["max_trainable"] n_trainable_params = [] for trainable_layers in range(0, max_trainable + 1): - model = model_fn(pretrained=False, pretrained_backbone=True, trainable_backbone_layers=trainable_layers) + model = model_fn(weights_backbone="DEFAULT", trainable_backbone_layers=trainable_layers) n_trainable_params.append(len([p for p in model.parameters() if p.requires_grad])) assert n_trainable_params == 
_model_tests_values[model_name]["n_trn_params_per_layer"] diff --git a/test/test_models_detection_negative_samples.py b/test/test_models_detection_negative_samples.py index 7d2953f7e64..3746c2f7920 100644 --- a/test/test_models_detection_negative_samples.py +++ b/test/test_models_detection_negative_samples.py @@ -99,9 +99,7 @@ def test_assign_targets_to_proposals(self): ], ) def test_forward_negative_sample_frcnn(self, name): - model = torchvision.models.detection.__dict__[name]( - num_classes=2, min_size=100, max_size=100, pretrained_backbone=False - ) + model = torchvision.models.detection.__dict__[name](num_classes=2, min_size=100, max_size=100) images, targets = self._make_empty_sample() loss_dict = model(images, targets) @@ -110,9 +108,7 @@ def test_forward_negative_sample_frcnn(self, name): assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.0)) def test_forward_negative_sample_mrcnn(self): - model = torchvision.models.detection.maskrcnn_resnet50_fpn( - num_classes=2, min_size=100, max_size=100, pretrained_backbone=False - ) + model = torchvision.models.detection.maskrcnn_resnet50_fpn(num_classes=2, min_size=100, max_size=100) images, targets = self._make_empty_sample(add_masks=True) loss_dict = model(images, targets) @@ -122,9 +118,7 @@ def test_forward_negative_sample_mrcnn(self): assert_equal(loss_dict["loss_mask"], torch.tensor(0.0)) def test_forward_negative_sample_krcnn(self): - model = torchvision.models.detection.keypointrcnn_resnet50_fpn( - num_classes=2, min_size=100, max_size=100, pretrained_backbone=False - ) + model = torchvision.models.detection.keypointrcnn_resnet50_fpn(num_classes=2, min_size=100, max_size=100) images, targets = self._make_empty_sample(add_keypoints=True) loss_dict = model(images, targets) @@ -134,9 +128,7 @@ def test_forward_negative_sample_krcnn(self): assert_equal(loss_dict["loss_keypoint"], torch.tensor(0.0)) def test_forward_negative_sample_retinanet(self): - model = torchvision.models.detection.retinanet_resnet50_fpn( - num_classes=2, min_size=100, max_size=100, pretrained_backbone=False - ) + model = torchvision.models.detection.retinanet_resnet50_fpn(num_classes=2, min_size=100, max_size=100) images, targets = self._make_empty_sample() loss_dict = model(images, targets) @@ -144,9 +136,7 @@ def test_forward_negative_sample_retinanet(self): assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0)) def test_forward_negative_sample_fcos(self): - model = torchvision.models.detection.fcos_resnet50_fpn( - num_classes=2, min_size=100, max_size=100, pretrained_backbone=False - ) + model = torchvision.models.detection.fcos_resnet50_fpn(num_classes=2, min_size=100, max_size=100) images, targets = self._make_empty_sample() loss_dict = model(images, targets) @@ -155,7 +145,7 @@ def test_forward_negative_sample_fcos(self): assert_equal(loss_dict["bbox_ctrness"], torch.tensor(0.0)) def test_forward_negative_sample_ssd(self): - model = torchvision.models.detection.ssd300_vgg16(num_classes=2, pretrained_backbone=False) + model = torchvision.models.detection.ssd300_vgg16(num_classes=2) images, targets = self._make_empty_sample() loss_dict = model(images, targets) diff --git a/test/test_models_detection_utils.py b/test/test_models_detection_utils.py index 44abfd51a7f..5cfc7e04d3f 100644 --- a/test/test_models_detection_utils.py +++ b/test/test_models_detection_utils.py @@ -40,7 +40,7 @@ def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params): # be frozen for each trainable_backbone_layers parameter value # i.e all 53 params are 
frozen if trainable_backbone_layers=0 # ad first 24 params are frozen if trainable_backbone_layers=2 - model = backbone_utils.resnet_fpn_backbone("resnet50", pretrained=False, trainable_layers=train_layers) + model = backbone_utils.resnet_fpn_backbone("resnet50", trainable_layers=train_layers) # boolean list that is true if the param at that index is frozen is_frozen = [not parameter.requires_grad for _, parameter in model.named_parameters()] # check that expected initial number of layers are frozen @@ -49,18 +49,18 @@ def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params): def test_validate_resnet_inputs_detection(self): # default number of backbone layers to train ret = backbone_utils._validate_trainable_layers( - pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3 + is_trained=True, trainable_backbone_layers=None, max_value=5, default_value=3 ) assert ret == 3 # can't go beyond 5 with pytest.raises(ValueError, match=r"Trainable backbone layers should be in the range"): ret = backbone_utils._validate_trainable_layers( - pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3 + is_trained=True, trainable_backbone_layers=6, max_value=5, default_value=3 ) - # if not pretrained, should use all trainable layers and warn + # if not trained, should use all trainable layers and warn with pytest.warns(UserWarning): ret = backbone_utils._validate_trainable_layers( - pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3 + is_trained=False, trainable_backbone_layers=0, max_value=5, default_value=3 ) assert ret == 5 diff --git a/test/test_onnx.py b/test/test_onnx.py index b725cdf2b90..375d0fd1c6f 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -430,7 +430,9 @@ def get_test_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: def test_faster_rcnn(self): images, test_images = self.get_test_images() dummy_image = [torch.ones(3, 100, 100) * 0.3] - model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300) + model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn( + weights=models.detection.faster_rcnn.FasterRCNN_ResNet50_FPN_Weights.DEFAULT, min_size=200, max_size=300 + ) model.eval() model(images) # Test exported model on images of different size, or dummy input @@ -486,7 +488,9 @@ def test_paste_mask_in_image(self): def test_mask_rcnn(self): images, test_images = self.get_test_images() dummy_image = [torch.ones(3, 100, 100) * 0.3] - model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300) + model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn( + weights=models.detection.mask_rcnn.MaskRCNN_ResNet50_FPN_Weights.DEFAULT, min_size=200, max_size=300 + ) model.eval() model(images) # Test exported model on images of different size, or dummy input @@ -548,7 +552,9 @@ def test_heatmaps_to_keypoints(self): def test_keypoint_rcnn(self): images, test_images = self.get_test_images() dummy_images = [torch.ones(3, 100, 100) * 0.3] - model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300) + model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn( + weights=models.detection.keypoint_rcnn.KeypointRCNN_ResNet50_FPN_Weights.DEFAULT, min_size=200, max_size=300 + ) model.eval() model(images) self.run_model( @@ -570,7 +576,7 @@ def test_keypoint_rcnn(self): ) def test_shufflenet_v2_dynamic_axes(self): - model = models.shufflenet_v2_x0_5(pretrained=True) 
+ model = models.shufflenet_v2_x0_5(weights=models.ShuffleNet_V2_X0_5_Weights.DEFAULT) dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True) test_inputs = torch.cat([dummy_input, dummy_input, dummy_input], 0) diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index b767756692b..20b78254fcc 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -62,7 +62,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: def resnet_fpn_backbone( *, backbone_name: str, - weights: Optional[WeightsEnum], + weights: Optional[WeightsEnum] = None, norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d, trainable_layers: int = 3, returned_layers: Optional[List[int]] = None, @@ -171,8 +171,8 @@ def _validate_trainable_layers( def mobilenet_backbone( *, backbone_name: str, - weights: Optional[WeightsEnum], - fpn: bool, + weights: Optional[WeightsEnum] = None, + fpn: bool = True, norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d, trainable_layers: int = 2, returned_layers: Optional[List[int]] = None, From db0fd2730f747f60264216138e1e8ac9940b709b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 16 Mar 2022 12:58:43 +0000 Subject: [PATCH 34/45] Update CI on Multiweight branch to use the new weight download approach (#5628) * port Pad to prototype transforms (#5621) * port Pad to prototype transforms * use literal * Bump up LibTorchvision version number for Podspec to release Cocoapods (#5624) Co-authored-by: Anton Thomma Co-authored-by: Vasilis Vryniotis * pre-download model weights in CI docs build (#5625) * pre-download model weights in CI docs build * move changes into template * change docs image * Regenerated config.yml Co-authored-by: Philip Meier Co-authored-by: Anton Thomma <11010310+thommaa@users.noreply.github.com> Co-authored-by: Anton Thomma --- .circleci/config.yml | 32 +++++-- .circleci/config.yml.in | 32 +++++-- ios/LibTorchvision.podspec | 4 +- scripts/collect_model_urls.py | 18 ++-- test/test_prototype_transforms.py | 1 + torchvision/prototype/transforms/__init__.py | 1 + torchvision/prototype/transforms/_geometry.py | 91 +++++++++++++------ 7 files changed, 120 insertions(+), 59 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 254e758ade2..86aa08d65b6 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -174,6 +174,26 @@ commands: - store_test_results: path: test-results + download_model_weights: + parameters: + extract_roots: + type: string + default: "torchvision/models" + background: + type: boolean + default: true + steps: + - apt_install: + args: parallel wget + descr: Install download utilitites + - run: + name: Download model weights + background: << parameters.background >> + command: | + mkdir -p ~/.cache/torch/hub/checkpoints + python scripts/collect_model_urls.py << parameters.extract_roots >> \ + | parallel -j0 'wget --no-verbose -O ~/.cache/torch/hub/checkpoints/`basename {}` {}\?source=ci' + binary_common: &binary_common parameters: # Edit these defaults to do a release @@ -354,14 +374,7 @@ jobs: resource_class: xlarge steps: - checkout - - run: - name: Download model weights - background: true - command: | - sudo apt update -qy && sudo apt install -qy parallel wget - mkdir -p ~/.cache/torch/hub/checkpoints - python scripts/collect_model_urls.py torchvision/models \ - | parallel -j0 'wget --no-verbose -O ~/.cache/torch/hub/checkpoints/`basename {}` {}\?source=ci' + - 
download_model_weights - install_torchvision - run: name: Enable extended tests @@ -1021,12 +1034,13 @@ jobs: build_docs: <<: *binary_common docker: - - image: "pytorch/manylinux-cuda100" + - image: circleci/python:3.7 resource_class: 2xlarge+ steps: - attach_workspace: at: ~/workspace - checkout + - download_model_weights - run: name: Setup command: .circleci/unittest/linux/scripts/setup_env.sh diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index 6b5b31a3e6b..df4ad5cb310 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -174,6 +174,26 @@ commands: - store_test_results: path: test-results + download_model_weights: + parameters: + extract_roots: + type: string + default: "torchvision/models" + background: + type: boolean + default: true + steps: + - apt_install: + args: parallel wget + descr: Install download utilitites + - run: + name: Download model weights + background: << parameters.background >> + command: | + mkdir -p ~/.cache/torch/hub/checkpoints + python scripts/collect_model_urls.py << parameters.extract_roots >> \ + | parallel -j0 'wget --no-verbose -O ~/.cache/torch/hub/checkpoints/`basename {}` {}\?source=ci' + binary_common: &binary_common parameters: # Edit these defaults to do a release @@ -354,14 +374,7 @@ jobs: resource_class: xlarge steps: - checkout - - run: - name: Download model weights - background: true - command: | - sudo apt update -qy && sudo apt install -qy parallel wget - mkdir -p ~/.cache/torch/hub/checkpoints - python scripts/collect_model_urls.py torchvision/models \ - | parallel -j0 'wget --no-verbose -O ~/.cache/torch/hub/checkpoints/`basename {}` {}\?source=ci' + - download_model_weights - install_torchvision - run: name: Enable extended tests @@ -1021,12 +1034,13 @@ jobs: build_docs: <<: *binary_common docker: - - image: "pytorch/manylinux-cuda100" + - image: circleci/python:3.7 resource_class: 2xlarge+ steps: - attach_workspace: at: ~/workspace - checkout + - download_model_weights - run: name: Setup command: .circleci/unittest/linux/scripts/setup_env.sh diff --git a/ios/LibTorchvision.podspec b/ios/LibTorchvision.podspec index 2260bc5fcfc..3e760b77edd 100644 --- a/ios/LibTorchvision.podspec +++ b/ios/LibTorchvision.podspec @@ -1,8 +1,8 @@ -pytorch_version = '1.10.0' +pytorch_version = '1.11.0' Pod::Spec.new do |s| s.name = 'LibTorchvision' - s.version = '0.11.1' + s.version = '0.12.0' s.authors = 'PyTorch Team' s.license = { :type => 'BSD' } s.homepage = 'https://github.com/pytorch/vision' diff --git a/scripts/collect_model_urls.py b/scripts/collect_model_urls.py index 3554e80b1ed..2acba6cbbda 100644 --- a/scripts/collect_model_urls.py +++ b/scripts/collect_model_urls.py @@ -2,21 +2,19 @@ import re import sys -MODEL_URL_PATTERN = re.compile(r"https://download[.]pytorch[.]org/models/.*?[.]pth") +MODEL_URL_PATTERN = re.compile(r"https://download[.]pytorch[.]org/models/.+?[.]pth") -def main(root): +def main(*roots): model_urls = set() - for path in pathlib.Path(root).glob("**/*"): - if path.name.startswith("_") or not path.suffix == ".py": - continue - - with open(path, "r") as file: - for line in file: - model_urls.update(MODEL_URL_PATTERN.findall(line)) + for root in roots: + for path in pathlib.Path(root).rglob("*.py"): + with open(path, "r") as file: + for line in file: + model_urls.update(MODEL_URL_PATTERN.findall(line)) print("\n".join(sorted(model_urls))) if __name__ == "__main__": - main(sys.argv[1]) + main(*sys.argv[1:]) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 
5b0693a2e78..b6085bb1c71 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -71,6 +71,7 @@ class TestSmoke: transforms.CenterCrop([16, 16]), transforms.ConvertImageDtype(), transforms.RandomHorizontalFlip(), + transforms.Pad(5), ) def test_common(self, transform, input): transform(input) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 48581a23930..be192fa3f5d 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -13,6 +13,7 @@ TenCrop, BatchMultiCrop, RandomHorizontalFlip, + Pad, RandomZoomOut, ) from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 3276a8b2ab2..6a05a8895f7 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -1,5 +1,6 @@ import collections.abc import math +import numbers import warnings from typing import Any, Dict, List, Union, Sequence, Tuple, cast @@ -9,6 +10,7 @@ from torchvision.prototype.transforms import Transform, functional as F from torchvision.transforms.functional import pil_to_tensor, InterpolationMode from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int +from typing_extensions import Literal from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor @@ -272,42 +274,31 @@ def apply_recursively(obj: Any) -> Any: return apply_recursively(inputs if len(inputs) > 1 else inputs[0]) -class RandomZoomOut(Transform): +class Pad(Transform): def __init__( - self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5 + self, + padding: Union[int, Sequence[int]], + fill: Union[float, Sequence[float]] = 0.0, + padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", ) -> None: super().__init__() + if not isinstance(padding, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate padding arg") - if fill is None: - fill = 0.0 - self.fill = fill - - self.side_range = side_range - if side_range[0] < 1.0 or side_range[0] > side_range[1]: - raise ValueError(f"Invalid canvas side range provided {side_range}.") - - self.p = p - - def _get_params(self, sample: Any) -> Dict[str, Any]: - image = query_image(sample) - orig_c, orig_h, orig_w = get_image_dimensions(image) - - r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) - canvas_width = int(orig_w * r) - canvas_height = int(orig_h * r) + if not isinstance(fill, (numbers.Number, str, tuple, list)): + raise TypeError("Got inappropriate fill arg") - r = torch.rand(2) - left = int((canvas_width - orig_w) * r[0]) - top = int((canvas_height - orig_h) * r[1]) - right = canvas_width - (left + orig_w) - bottom = canvas_height - (top + orig_h) - padding = [left, top, right, bottom] + if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: + raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") - fill = self.fill - if not isinstance(fill, collections.abc.Sequence): - fill = [fill] * orig_c + if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]: + raise ValueError( + f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple" + ) - return dict(padding=padding, fill=fill) + self.padding = padding + self.fill = fill + 
self.padding_mode = padding_mode def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if isinstance(input, features.Image) or is_simple_tensor(input): @@ -349,6 +340,48 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: else: return input + +class RandomZoomOut(Transform): + def __init__( + self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5 + ) -> None: + super().__init__() + + if fill is None: + fill = 0.0 + self.fill = fill + + self.side_range = side_range + if side_range[0] < 1.0 or side_range[0] > side_range[1]: + raise ValueError(f"Invalid canvas side range provided {side_range}.") + + self.p = p + + def _get_params(self, sample: Any) -> Dict[str, Any]: + image = query_image(sample) + orig_c, orig_h, orig_w = get_image_dimensions(image) + + r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) + canvas_width = int(orig_w * r) + canvas_height = int(orig_h * r) + + r = torch.rand(2) + left = int((canvas_width - orig_w) * r[0]) + top = int((canvas_height - orig_h) * r[1]) + right = canvas_width - (left + orig_w) + bottom = canvas_height - (top + orig_h) + padding = [left, top, right, bottom] + + fill = self.fill + if not isinstance(fill, collections.abc.Sequence): + fill = [fill] * orig_c + + return dict(padding=padding, fill=fill) + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + transform = Pad(**params, padding_mode="constant") + return transform(input) + def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] if torch.rand(1) >= self.p: From 0a612cb2011d9e271c4cf92a0bfef415892fb018 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 17 Mar 2022 13:13:51 +0000 Subject: [PATCH 35/45] Porting reference scripts and updating presets (#5629) * Making _preset.py classes * Remove support of targets on presets. 
* Rewriting the video preset * Adding tests to check that the bundled transforms are JIT scriptable * Rename all presets from *Eval to *Inference * Minor refactoring * Remove --prototype and --pretrained from reference scripts * remove pretained_backbone refs * Corrections and simplifications * Fixing bug * Fixing linter * Fix flake8 * restore documentation example * minor fixes * fix optical flow missing param * Fixing commands * Adding weights_backbone support in detection and segmentation * Updating the commands for InceptionV3 --- docs/source/models.rst | 54 +---------- gallery/plot_optical_flow.py | 2 +- gallery/plot_repurposing_annotations.py | 2 +- gallery/plot_visualization_utils.py | 8 +- references/classification/README.md | 26 ++---- references/classification/train.py | 42 +-------- .../classification/train_quantization.py | 23 +---- references/classification/utils.py | 8 +- references/detection/README.md | 18 ++-- references/detection/train.py | 46 ++-------- references/optical_flow/README.md | 4 +- references/optical_flow/train.py | 37 ++------ references/segmentation/README.md | 12 +-- references/segmentation/train.py | 49 ++-------- references/video_classification/presets.py | 4 +- references/video_classification/train.py | 40 ++------ test/test_extended_models.py | 65 ++++++++++++- test/test_models.py | 3 +- torchvision/models/alexnet.py | 4 +- torchvision/models/convnext.py | 10 +- torchvision/models/densenet.py | 10 +- torchvision/models/detection/faster_rcnn.py | 8 +- torchvision/models/detection/fcos.py | 4 +- torchvision/models/detection/keypoint_rcnn.py | 6 +- torchvision/models/detection/mask_rcnn.py | 4 +- torchvision/models/detection/retinanet.py | 4 +- torchvision/models/detection/ssd.py | 4 +- torchvision/models/detection/ssdlite.py | 4 +- torchvision/models/efficientnet.py | 26 +++--- torchvision/models/googlenet.py | 4 +- torchvision/models/inception.py | 4 +- torchvision/models/mnasnet.py | 6 +- torchvision/models/mobilenetv2.py | 6 +- torchvision/models/mobilenetv3.py | 8 +- torchvision/models/optical_flow/raft.py | 18 ++-- torchvision/models/quantization/googlenet.py | 4 +- torchvision/models/quantization/inception.py | 4 +- .../models/quantization/mobilenetv2.py | 4 +- .../models/quantization/mobilenetv3.py | 4 +- torchvision/models/quantization/resnet.py | 12 +-- .../models/quantization/shufflenetv2.py | 6 +- torchvision/models/regnet.py | 58 ++++++------ torchvision/models/resnet.py | 34 +++---- torchvision/models/segmentation/deeplabv3.py | 8 +- torchvision/models/segmentation/fcn.py | 6 +- torchvision/models/segmentation/lraspp.py | 4 +- torchvision/models/shufflenetv2.py | 6 +- torchvision/models/squeezenet.py | 6 +- torchvision/models/vgg.py | 20 ++-- torchvision/models/video/resnet.py | 8 +- torchvision/models/vision_transformer.py | 10 +- torchvision/transforms/__init__.py | 7 -- torchvision/transforms/_presets.py | 91 +++++++++---------- 53 files changed, 343 insertions(+), 522 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index 50af05360e4..39543cb8027 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -98,58 +98,6 @@ You can construct a model with random weights by calling its constructor: convnext_large = models.convnext_large() We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`. -These can be constructed by passing ``pretrained=True``: - -.. 
code:: python - - import torchvision.models as models - resnet18 = models.resnet18(pretrained=True) - alexnet = models.alexnet(pretrained=True) - squeezenet = models.squeezenet1_0(pretrained=True) - vgg16 = models.vgg16(pretrained=True) - densenet = models.densenet161(pretrained=True) - inception = models.inception_v3(pretrained=True) - googlenet = models.googlenet(pretrained=True) - shufflenet = models.shufflenet_v2_x1_0(pretrained=True) - mobilenet_v2 = models.mobilenet_v2(pretrained=True) - mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True) - mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True) - resnext50_32x4d = models.resnext50_32x4d(pretrained=True) - wide_resnet50_2 = models.wide_resnet50_2(pretrained=True) - mnasnet = models.mnasnet1_0(pretrained=True) - efficientnet_b0 = models.efficientnet_b0(pretrained=True) - efficientnet_b1 = models.efficientnet_b1(pretrained=True) - efficientnet_b2 = models.efficientnet_b2(pretrained=True) - efficientnet_b3 = models.efficientnet_b3(pretrained=True) - efficientnet_b4 = models.efficientnet_b4(pretrained=True) - efficientnet_b5 = models.efficientnet_b5(pretrained=True) - efficientnet_b6 = models.efficientnet_b6(pretrained=True) - efficientnet_b7 = models.efficientnet_b7(pretrained=True) - efficientnet_v2_s = models.efficientnet_v2_s(pretrained=True) - efficientnet_v2_m = models.efficientnet_v2_m(pretrained=True) - efficientnet_v2_l = models.efficientnet_v2_l(pretrained=True) - regnet_y_400mf = models.regnet_y_400mf(pretrained=True) - regnet_y_800mf = models.regnet_y_800mf(pretrained=True) - regnet_y_1_6gf = models.regnet_y_1_6gf(pretrained=True) - regnet_y_3_2gf = models.regnet_y_3_2gf(pretrained=True) - regnet_y_8gf = models.regnet_y_8gf(pretrained=True) - regnet_y_16gf = models.regnet_y_16gf(pretrained=True) - regnet_y_32gf = models.regnet_y_32gf(pretrained=True) - regnet_x_400mf = models.regnet_x_400mf(pretrained=True) - regnet_x_800mf = models.regnet_x_800mf(pretrained=True) - regnet_x_1_6gf = models.regnet_x_1_6gf(pretrained=True) - regnet_x_3_2gf = models.regnet_x_3_2gf(pretrained=True) - regnet_x_8gf = models.regnet_x_8gf(pretrained=True) - regnet_x_16gf = models.regnet_x_16gf(pretrainedTrue) - regnet_x_32gf = models.regnet_x_32gf(pretrained=True) - vit_b_16 = models.vit_b_16(pretrained=True) - vit_b_32 = models.vit_b_32(pretrained=True) - vit_l_16 = models.vit_l_16(pretrained=True) - vit_l_32 = models.vit_l_32(pretrained=True) - convnext_tiny = models.convnext_tiny(pretrained=True) - convnext_small = models.convnext_small(pretrained=True) - convnext_base = models.convnext_base(pretrained=True) - convnext_large = models.convnext_large(pretrained=True) Instancing a pre-trained model will download its weights to a cache directory. This directory can be set using the `TORCH_HOME` environment variable. See @@ -525,7 +473,7 @@ Obtaining a pre-trained quantized model can be done with a few lines of code: .. 
code:: python import torchvision.models as models - model = models.quantization.mobilenet_v2(pretrained=True, quantize=True) + model = models.quantization.mobilenet_v2(weights=MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1, quantize=True) model.eval() # run the model with quantized inputs and weights out = model(torch.rand(1, 3, 224, 224)) diff --git a/gallery/plot_optical_flow.py b/gallery/plot_optical_flow.py index 770610fb971..5149ebc541b 100644 --- a/gallery/plot_optical_flow.py +++ b/gallery/plot_optical_flow.py @@ -96,7 +96,7 @@ def plot(imgs, **imshow_kwargs): def preprocess(img1_batch, img2_batch): img1_batch = F.resize(img1_batch, size=[520, 960]) img2_batch = F.resize(img2_batch, size=[520, 960]) - return transforms(img1_batch, img2_batch)[:2] + return transforms(img1_batch, img2_batch) img1_batch, img2_batch = preprocess(img1_batch, img2_batch) diff --git a/gallery/plot_repurposing_annotations.py b/gallery/plot_repurposing_annotations.py index a826a2523f2..7bb68617a17 100644 --- a/gallery/plot_repurposing_annotations.py +++ b/gallery/plot_repurposing_annotations.py @@ -146,7 +146,7 @@ def show(imgs): print(img.size()) tranforms = weights.transforms() -img, _ = tranforms(img) +img = tranforms(img) target = {} target["boxes"] = boxes target["labels"] = labels = torch.ones((masks.size(0),), dtype=torch.int64) diff --git a/gallery/plot_visualization_utils.py b/gallery/plot_visualization_utils.py index 27fd97681c0..7f92d54ebdd 100644 --- a/gallery/plot_visualization_utils.py +++ b/gallery/plot_visualization_utils.py @@ -81,7 +81,7 @@ def show(imgs): weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT transforms = weights.transforms() -batch, _ = transforms(batch_int) +batch = transforms(batch_int) model = fasterrcnn_resnet50_fpn(weights=weights, progress=False) model = model.eval() @@ -131,7 +131,7 @@ def show(imgs): model = fcn_resnet50(weights=weights, progress=False) model = model.eval() -normalized_batch, _ = transforms(batch) +normalized_batch = transforms(batch) output = model(normalized_batch)['out'] print(output.shape, output.min().item(), output.max().item()) @@ -272,7 +272,7 @@ def show(imgs): weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT transforms = weights.transforms() -batch, _ = transforms(batch_int) +batch = transforms(batch_int) model = maskrcnn_resnet50_fpn(weights=weights, progress=False) model = model.eval() @@ -397,7 +397,7 @@ def show(imgs): weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT transforms = weights.transforms() -person_float, _ = transforms(person_int) +person_float = transforms(person_int) model = keypointrcnn_resnet50_fpn(weights=weights, progress=False) model = model.eval() diff --git a/references/classification/README.md b/references/classification/README.md index 173fb454995..c274c997791 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -43,7 +43,7 @@ Since it expects tensors with a size of N x 3 x 299 x 299, to validate the model ``` torchrun --nproc_per_node=8 train.py --model inception_v3\ - --val-resize-size 342 --val-crop-size 299 --train-crop-size 299 --test-only --pretrained + --test-only --weights Inception_V3_Weights.IMAGENET1K_V1 ``` ### ResNet @@ -96,22 +96,14 @@ The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTo All models were trained using Bicubic interpolation and each have custom crop and resize sizes. 
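Under the multi-weight API these preprocessing parameters no longer need to be spelled out on the command line: each weights enum bundles its own inference preset via `weights.transforms()`. A minimal sketch of what that looks like in code — assuming the `EfficientNet_B0_Weights.IMAGENET1K_V1` enum used in the commands below is importable from `torchvision.models`, and with a random tensor standing in for a real image:

```
import torch
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

weights = EfficientNet_B0_Weights.IMAGENET1K_V1
model = efficientnet_b0(weights=weights).eval()

# The bundled preset already applies the matching resize/crop sizes and
# bicubic interpolation for this variant, so no extra flags are needed.
preprocess = weights.transforms()
batch = preprocess(torch.rand(3, 512, 512)).unsqueeze(0)
with torch.no_grad():
    scores = model(batch)
```
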
To validate the models use the following commands: ``` -torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --interpolation bicubic\ - --val-resize-size 256 --val-crop-size 224 --train-crop-size 224 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --interpolation bicubic\ - --val-resize-size 256 --val-crop-size 240 --train-crop-size 240 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --interpolation bicubic\ - --val-resize-size 288 --val-crop-size 288 --train-crop-size 288 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --interpolation bicubic\ - --val-resize-size 320 --val-crop-size 300 --train-crop-size 300 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --interpolation bicubic\ - --val-resize-size 384 --val-crop-size 380 --train-crop-size 380 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b5 --interpolation bicubic\ - --val-resize-size 456 --val-crop-size 456 --train-crop-size 456 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --interpolation bicubic\ - --val-resize-size 528 --val-crop-size 528 --train-crop-size 528 --test-only --pretrained -torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --interpolation bicubic\ - --val-resize-size 600 --val-crop-size 600 --train-crop-size 600 --test-only --pretrained +torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --test-only --weights EfficientNet_B0_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --test-only --weights EfficientNet_B1_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --test-only --weights EfficientNet_B2_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --test-only --weights EfficientNet_B3_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --test-only --weights EfficientNet_B4_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b5 --test-only --weights EfficientNet_B5_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --test-only --weights EfficientNet_B6_Weights.IMAGENET1K_V1 +torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --test-only --weights EfficientNet_B7_Weights.IMAGENET1K_V1 ``` diff --git a/references/classification/train.py b/references/classification/train.py index 569cf3009e7..eb8b56c1ad0 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -15,12 +15,6 @@ from torchvision.transforms.functional import InterpolationMode -try: - from torchvision import prototype -except ImportError: - prototype = None - - def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): model.train() metric_logger = utils.MetricLogger(delimiter=" ") @@ -154,18 +148,13 @@ def load_data(traindir, valdir, args): print(f"Loading dataset_test from {cache_path}") dataset_test, _ = torch.load(cache_path) else: - if not args.prototype: + if args.weights and args.test_only: + weights = torchvision.models.get_weight(args.weights) + preprocessing = weights.transforms() + else: preprocessing = presets.ClassificationPresetEval( crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation ) - else: - if args.weights: - weights = prototype.models.get_weight(args.weights) - 
preprocessing = weights.transforms() - else: - preprocessing = prototype.transforms.ImageClassificationEval( - crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation - ) dataset_test = torchvision.datasets.ImageFolder( valdir, @@ -191,10 +180,6 @@ def load_data(traindir, valdir, args): def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -236,10 +221,7 @@ def main(args): ) print("Creating model") - if not args.prototype: - model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes) - else: - model = prototype.models.__dict__[args.model](weights=args.weights, num_classes=num_classes) + model = torchvision.models.__dict__[args.model](weights=args.weights, num_classes=num_classes) model.to(device) if args.distributed and args.sync_bn: @@ -446,12 +428,6 @@ def get_args_parser(add_help=True): help="Only test the model", action="store_true", ) - parser.add_argument( - "--pretrained", - dest="pretrained", - help="Use pre-trained models from the modelzoo", - action="store_true", - ) parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)") parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)") @@ -496,14 +472,6 @@ def get_args_parser(add_help=True): parser.add_argument( "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)" ) - - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") return parser diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py index 111777a860b..c0e5af1dcfc 100644 --- a/references/classification/train_quantization.py +++ b/references/classification/train_quantization.py @@ -12,17 +12,7 @@ from train import train_one_epoch, evaluate, load_data -try: - from torchvision import prototype -except ImportError: - prototype = None - - def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. 
Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -56,10 +46,7 @@ def main(args): print("Creating model", args.model) # when training quantized models, we always start from a pre-trained fp32 reference model - if not args.prototype: - model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only) - else: - model = prototype.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only) + model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only) model.to(device) if not (args.test_only or args.post_training_quantize): @@ -264,14 +251,6 @@ def get_args_parser(add_help=True): "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)" ) parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)") - - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") return parser diff --git a/references/classification/utils.py b/references/classification/utils.py index 7f573415c4c..27398d97234 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -330,22 +330,22 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T from torchvision import models as M # Classification - model = M.mobilenet_v3_large(pretrained=False) + model = M.mobilenet_v3_large() print(store_model_weights(model, './class.pth')) # Quantized Classification - model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False) + model = M.quantization.mobilenet_v3_large(quantize=False) model.fuse_model(is_qat=True) model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack') _ = torch.ao.quantization.prepare_qat(model, inplace=True) print(store_model_weights(model, './qat.pth')) # Object Detection - model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, pretrained_backbone=False) + model = M.detection.fasterrcnn_mobilenet_v3_large_fpn() print(store_model_weights(model, './obj.pth')) # Segmentation - model = M.segmentation.deeplabv3_mobilenet_v3_large(pretrained=False, pretrained_backbone=False, aux_loss=True) + model = M.segmentation.deeplabv3_mobilenet_v3_large(aux_loss=True) print(store_model_weights(model, './segm.pth', strict=False)) Args: diff --git a/references/detection/README.md b/references/detection/README.md index 3695644138b..aec7c10e1b5 100644 --- a/references/detection/README.md +++ b/references/detection/README.md @@ -24,35 +24,35 @@ Except otherwise noted, all models have been trained on 8x V100 GPUs. 
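The commands below now pass the backbone initialisation explicitly through `--weights-backbone` instead of relying on an implicit `pretrained_backbone=True`. A rough sketch of the builder call this maps to — the enum value is taken from the commands below, and its import path from `torchvision.models` is assumed:

```
from torchvision.models import ResNet50_Weights
from torchvision.models.detection import fasterrcnn_resnet50_fpn

# Detection heads start from random weights (weights=None) while the
# backbone is initialised from ImageNet, matching the reference runs.
model = fasterrcnn_resnet50_fpn(
    weights=None,
    weights_backbone=ResNet50_Weights.IMAGENET1K_V1,
    num_classes=91,
)
```

Passing `weights_backbone=None` instead trains the whole network from scratch.
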
``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ### Faster R-CNN MobileNetV3-Large FPN ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model fasterrcnn_mobilenet_v3_large_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` ### Faster R-CNN MobileNetV3-Large 320 FPN ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model fasterrcnn_mobilenet_v3_large_320_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` ### FCOS ResNet-50 FPN ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model fcos_resnet50_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp + --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ### RetinaNet ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model retinanet_resnet50_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ### SSD300 VGG16 @@ -60,7 +60,7 @@ torchrun --nproc_per_node=8 train.py\ torchrun --nproc_per_node=8 train.py\ --dataset coco --model ssd300_vgg16 --epochs 120\ --lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\ - --weight-decay 0.0005 --data-augmentation ssd + --weight-decay 0.0005 --data-augmentation ssd --weights-backbone VGG16_Weights.IMAGENET1K_FEATURES ``` ### SSDlite320 MobileNetV3-Large @@ -68,7 +68,7 @@ torchrun --nproc_per_node=8 train.py\ torchrun --nproc_per_node=8 train.py\ --dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\ --aspect-ratio-group-factor 3 --lr-scheduler cosineannealinglr --lr 0.15 --batch-size 24\ - --weight-decay 0.00004 --data-augmentation ssdlite + --weight-decay 0.00004 --data-augmentation ssdlite --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` @@ -76,7 +76,7 @@ torchrun --nproc_per_node=8 train.py\ ``` torchrun --nproc_per_node=8 train.py\ --dataset coco --model maskrcnn_resnet50_fpn --epochs 26\ - --lr-steps 16 22 --aspect-ratio-group-factor 3 + --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` @@ -84,5 +84,5 @@ torchrun --nproc_per_node=8 train.py\ ``` torchrun --nproc_per_node=8 train.py\ --dataset coco_kp --model keypointrcnn_resnet50_fpn --epochs 46\ - --lr-steps 36 43 --aspect-ratio-group-factor 3 + --lr-steps 36 43 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` diff --git a/references/detection/train.py b/references/detection/train.py index 3909e6413d0..0e0a0d70fad 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -33,12 +33,6 @@ from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups -try: - from torchvision import prototype -except ImportError: - prototype = None - - def get_dataset(name, image_set, transform, data_path): paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)} p, ds_fn, num_classes = paths[name] @@ -50,14 +44,12 @@ def get_dataset(name, image_set, 
transform, data_path): def get_transform(train, args): if train: return presets.DetectionPresetTrain(args.data_augmentation) - elif not args.prototype: - return presets.DetectionPresetEval() + elif args.weights and args.test_only: + weights = torchvision.models.get_weight(args.weights) + trans = weights.transforms() + return lambda img, target=None: (trans(img), target) else: - if args.weights: - weights = prototype.models.get_weight(args.weights) - return weights.transforms() - else: - return prototype.transforms.ObjectDetectionEval() + return presets.DetectionPresetEval() def get_args_parser(add_help=True): @@ -132,25 +124,12 @@ def get_args_parser(add_help=True): help="Only test the model", action="store_true", ) - parser.add_argument( - "--pretrained", - dest="pretrained", - help="Use pre-trained models from the modelzoo", - action="store_true", - ) # distributed training parameters parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") - - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") + parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load") # Mixed precision training parameters parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") @@ -159,10 +138,6 @@ def get_args_parser(add_help=True): def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. 
Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -204,12 +179,9 @@ def main(args): if "rcnn" in args.model: if args.rpn_score_thresh is not None: kwargs["rpn_score_thresh"] = args.rpn_score_thresh - if not args.prototype: - model = torchvision.models.detection.__dict__[args.model]( - pretrained=args.pretrained, num_classes=num_classes, **kwargs - ) - else: - model = prototype.models.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs) + model = torchvision.models.detection.__dict__[args.model]( + weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs + ) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) diff --git a/references/optical_flow/README.md b/references/optical_flow/README.md index a7620ce4be6..a7ac0223739 100644 --- a/references/optical_flow/README.md +++ b/references/optical_flow/README.md @@ -51,7 +51,7 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \ ### Evaluation ``` -torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained +torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset sintel --batch-size 1 --dataset-root $dataset_root --model raft_large --weights Raft_Large_Weights.C_T_SKHT_V2 ``` This should give an epe of about 1.3822 on the clean pass and 2.7161 on the @@ -67,6 +67,6 @@ Sintel val final epe: 2.7161 1px: 0.8528 3px: 0.9204 5px: 0.9392 per_image_epe: You can also evaluate on Kitti train: ``` -torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset kitti --batch-size 1 --dataset-root $dataset_root --model raft_large --pretrained +torchrun --nproc_per_node 1 --nnodes 1 train.py --val-dataset kitti --batch-size 1 --dataset-root $dataset_root --model raft_large --weights Raft_Large_Weights.C_T_SKHT_V2 Kitti val epe: 4.7968 1px: 0.6388 3px: 0.8197 5px: 0.8661 per_image_epe: 4.5118 f1: 16.0679 ``` diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 83952242eb9..1a50d1c617d 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -9,11 +9,6 @@ from presets import OpticalFlowPresetTrain, OpticalFlowPresetEval from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K -try: - from torchvision import prototype -except ImportError: - prototype = None - def get_train_dataset(stage, dataset_root): if stage == "chairs": @@ -138,12 +133,10 @@ def inner_loop(blob): def evaluate(model, args): val_datasets = args.val_dataset or [] - if args.prototype: - if args.weights: - weights = prototype.models.get_weight(args.weights) - preprocessing = weights.transforms() - else: - preprocessing = prototype.transforms.OpticalFlowEval() + if args.weights and args.test_only: + weights = torchvision.models.get_weight(args.weights) + trans = weights.transforms() + preprocessing = lambda img1, img2, flow=None, valid=None: trans(img1, img2) + (flow, valid) # noqa: E731 else: preprocessing = OpticalFlowPresetEval() @@ -201,20 +194,14 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args): def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. 
Please pass the --prototype argument.") utils.setup_ddp(args) + args.test_only = args.train_dataset is None if args.distributed and args.device == "cpu": raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun") device = torch.device(args.device) - if args.prototype: - model = prototype.models.optical_flow.__dict__[args.model](weights=args.weights) - else: - model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained) + model = torchvision.models.optical_flow.__dict__[args.model](weights=args.weights) if args.distributed: model = model.to(args.local_rank) @@ -228,7 +215,7 @@ def main(args): checkpoint = torch.load(args.resume, map_location="cpu") model_without_ddp.load_state_dict(checkpoint["model"]) - if args.train_dataset is None: + if args.test_only: # Set deterministic CUDNN algorithms, since they can affect epe a fair bit. torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True @@ -356,8 +343,7 @@ def get_args_parser(add_help=True): parser.add_argument( "--model", type=str, default="raft_large", help="The name of the model to use - either raft_large or raft_small" ) - # TODO: resume, pretrained, and weights should be in an exclusive arg group - parser.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights") + # TODO: resume and weights should be in an exclusive arg group parser.add_argument( "--num_flow_updates", @@ -376,13 +362,6 @@ def get_args_parser(add_help=True): required=True, ) - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.") parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)") diff --git a/references/segmentation/README.md b/references/segmentation/README.md index e9b5391215a..2c7391c8380 100644 --- a/references/segmentation/README.md +++ b/references/segmentation/README.md @@ -14,30 +14,30 @@ You must modify the following flags: ## fcn_resnet50 ``` -torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss +torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet50 --aux-loss --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ## fcn_resnet101 ``` -torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet101 --aux-loss +torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model fcn_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1 ``` ## deeplabv3_resnet50 ``` -torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet50 --aux-loss +torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet50 --aux-loss --weights-backbone ResNet50_Weights.IMAGENET1K_V1 ``` ## deeplabv3_resnet101 ``` -torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss +torchrun --nproc_per_node=8 train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss --weights-backbone ResNet101_Weights.IMAGENET1K_V1 ``` ## deeplabv3_mobilenet_v3_large ``` -torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001 +torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model 
deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` ## lraspp_mobilenet_v3_large ``` -torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001 +torchrun --nproc_per_node=8 train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1 ``` diff --git a/references/segmentation/train.py b/references/segmentation/train.py index 5dc03945bd7..b4e55acd407 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -11,12 +11,6 @@ from torch import nn -try: - from torchvision import prototype -except ImportError: - prototype = None - - def get_dataset(dir_path, name, image_set, transform): def sbd(*args, **kwargs): return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs) @@ -35,14 +29,12 @@ def sbd(*args, **kwargs): def get_transform(train, args): if train: return presets.SegmentationPresetTrain(base_size=520, crop_size=480) - elif not args.prototype: - return presets.SegmentationPresetEval(base_size=520) + elif args.weights and args.test_only: + weights = torchvision.models.get_weight(args.weights) + trans = weights.transforms() + return lambda img, target=None: (trans(img), target) else: - if args.weights: - weights = prototype.models.get_weight(args.weights) - return weights.transforms() - else: - return prototype.transforms.SemanticSegmentationEval(resize_size=520) + return presets.SegmentationPresetEval(base_size=520) def criterion(inputs, target): @@ -100,10 +92,6 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. 
Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -135,16 +123,9 @@ def main(args): dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn ) - if not args.prototype: - model = torchvision.models.segmentation.__dict__[args.model]( - pretrained=args.pretrained, - num_classes=num_classes, - aux_loss=args.aux_loss, - ) - else: - model = prototype.models.segmentation.__dict__[args.model]( - weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss - ) + model = torchvision.models.segmentation.__dict__[args.model]( + weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, aux_loss=args.aux_loss + ) model.to(device) if args.distributed: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -272,24 +253,12 @@ def get_args_parser(add_help=True): help="Only test the model", action="store_true", ) - parser.add_argument( - "--pretrained", - dest="pretrained", - help="Use pre-trained models from the modelzoo", - action="store_true", - ) # distributed training parameters parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") + parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load") # Mixed precision training parameters parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") diff --git a/references/video_classification/presets.py b/references/video_classification/presets.py index 04039c9a4f1..d24169e42dd 100644 --- a/references/video_classification/presets.py +++ b/references/video_classification/presets.py @@ -6,8 +6,8 @@ class VideoClassificationPresetTrain: def __init__( self, - resize_size, crop_size, + resize_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989), hflip_prob=0.5, @@ -27,7 +27,7 @@ def __call__(self, x): class VideoClassificationPresetEval: - def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)): + def __init__(self, crop_size, resize_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)): self.transforms = transforms.Compose( [ ConvertBHWCtoBCHW(), diff --git a/references/video_classification/train.py b/references/video_classification/train.py index d36785ddf96..da7ef9fc607 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -12,11 +12,6 @@ from torch.utils.data.dataloader import default_collate from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler -try: - from torchvision import prototype -except ImportError: - prototype = None - def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None): model.train() @@ -96,10 +91,6 @@ def collate_fn(batch): def main(args): - if args.prototype and prototype is None: - raise ImportError("The prototype module couldn't be found. 
Please install the latest torchvision nightly.") - if not args.prototype and args.weights: - raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -120,7 +111,7 @@ def main(args): print("Loading training data") st = time.time() cache_path = _get_cache_path(traindir) - transform_train = presets.VideoClassificationPresetTrain((128, 171), (112, 112)) + transform_train = presets.VideoClassificationPresetTrain(crop_size=(112, 112), resize_size=(128, 171)) if args.cache_dataset and os.path.exists(cache_path): print(f"Loading dataset_train from {cache_path}") @@ -150,14 +141,11 @@ def main(args): print("Loading validation data") cache_path = _get_cache_path(valdir) - if not args.prototype: - transform_test = presets.VideoClassificationPresetEval(resize_size=(128, 171), crop_size=(112, 112)) + if args.weights and args.test_only: + weights = torchvision.models.get_weight(args.weights) + transform_test = weights.transforms() else: - if args.weights: - weights = prototype.models.get_weight(args.weights) - transform_test = weights.transforms() - else: - transform_test = prototype.transforms.VideoClassificationEval(crop_size=(112, 112), resize_size=(128, 171)) + transform_test = presets.VideoClassificationPresetEval(crop_size=(112, 112), resize_size=(128, 171)) if args.cache_dataset and os.path.exists(cache_path): print(f"Loading dataset_test from {cache_path}") @@ -208,10 +196,7 @@ def main(args): ) print("Creating model") - if not args.prototype: - model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) - else: - model = prototype.models.video.__dict__[args.model](weights=args.weights) + model = torchvision.models.video.__dict__[args.model](weights=args.weights) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -352,24 +337,11 @@ def parse_args(): help="Only test the model", action="store_true", ) - parser.add_argument( - "--pretrained", - dest="pretrained", - help="Use pre-trained models from the modelzoo", - action="store_true", - ) # distributed training parameters parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") - # Prototype models only - parser.add_argument( - "--prototype", - dest="prototype", - help="Use prototype model builders instead those from main area", - action="store_true", - ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") # Mixed precision training parameters diff --git a/test/test_extended_models.py b/test/test_extended_models.py index 4bfe03d1ea0..a07b501e15b 100644 --- a/test/test_extended_models.py +++ b/test/test_extended_models.py @@ -3,12 +3,14 @@ import pytest import test_models as TM +import torch from torchvision import models from torchvision.models._api import WeightsEnum, Weights from torchvision.models._utils import handle_legacy_interface -run_if_test_with_prototype = pytest.mark.skipif( - os.getenv("PYTORCH_TEST_WITH_EXTENDED") != "1", + +run_if_test_with_extended = pytest.mark.skipif( + os.getenv("PYTORCH_TEST_WITH_EXTENDED", "0") != "1", reason="Extended tests are disabled by default. 
Set PYTORCH_TEST_WITH_EXTENDED=1 to run them.", ) @@ -76,7 +78,7 @@ def test_naming_conventions(model_fn): + TM.get_models_from_module(models.video) + TM.get_models_from_module(models.optical_flow), ) -@run_if_test_with_prototype +@run_if_test_with_extended def test_schema_meta_validation(model_fn): classification_fields = ["size", "categories", "acc@1", "acc@5", "min_size"] defaults = { @@ -123,6 +125,63 @@ def test_schema_meta_validation(model_fn): assert not bad_names +@pytest.mark.parametrize( + "model_fn", + TM.get_models_from_module(models) + + TM.get_models_from_module(models.detection) + + TM.get_models_from_module(models.quantization) + + TM.get_models_from_module(models.segmentation) + + TM.get_models_from_module(models.video) + + TM.get_models_from_module(models.optical_flow), +) +@run_if_test_with_extended +def test_transforms_jit(model_fn): + model_name = model_fn.__name__ + weights_enum = _get_model_weights(model_fn) + if len(weights_enum) == 0: + pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.") + + defaults = { + "models": { + "input_shape": (1, 3, 224, 224), + }, + "detection": { + "input_shape": (3, 300, 300), + }, + "quantization": { + "input_shape": (1, 3, 224, 224), + }, + "segmentation": { + "input_shape": (1, 3, 520, 520), + }, + "video": { + "input_shape": (1, 4, 112, 112, 3), + }, + "optical_flow": { + "input_shape": (1, 3, 128, 128), + }, + } + module_name = model_fn.__module__.split(".")[-2] + + kwargs = {**defaults[module_name], **TM._model_params.get(model_name, {})} + input_shape = kwargs.pop("input_shape") + x = torch.rand(input_shape) + if module_name == "optical_flow": + args = (x, x) + else: + args = (x,) + + problematic_weights = [] + for w in weights_enum: + transforms = w.transforms() + try: + TM._check_jit_scriptable(transforms, args) + except Exception: + problematic_weights.append(w) + + assert not problematic_weights + + # With this filter, every unexpected warning will be turned into an error @pytest.mark.filterwarnings("error") class TestHandleLegacyInterface: diff --git a/test/test_models.py b/test/test_models.py index 5bef9e24d9f..0d45d61df13 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -133,8 +133,7 @@ def get_export_import_copy(m): if eager_out is None: with torch.no_grad(), freeze_rng_state(): - if unwrapper: - eager_out = nn_module(*args) + eager_out = nn_module(*args) with torch.no_grad(), freeze_rng_state(): script_out = sm(*args) diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index 4df533000f9..6ee5b98c673 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -55,7 +55,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AlexNet_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ "task": "image_classification", "architecture": "AlexNet", diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 8d25e77eaa1..8774b9a1bc2 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -7,7 +7,7 @@ 
from ..ops.misc import Conv2dNormActivation from ..ops.stochastic_depth import StochasticDepth -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -218,7 +218,7 @@ def _convnext( class ConvNeXt_Tiny_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=236), + transforms=partial(ImageClassification, crop_size=224, resize_size=236), meta={ **_COMMON_META, "num_params": 28589128, @@ -232,7 +232,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum): class ConvNeXt_Small_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_small-0c510722.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=230), + transforms=partial(ImageClassification, crop_size=224, resize_size=230), meta={ **_COMMON_META, "num_params": 50223688, @@ -246,7 +246,7 @@ class ConvNeXt_Small_Weights(WeightsEnum): class ConvNeXt_Base_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_base-6075fbad.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 88591464, @@ -260,7 +260,7 @@ class ConvNeXt_Base_Weights(WeightsEnum): class ConvNeXt_Large_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_large-ea097f82.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 197767336, diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index b0de4529902..2ffb29c54cb 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -9,7 +9,7 @@ import torch.utils.checkpoint as cp from torch import Tensor -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -280,7 +280,7 @@ def _densenet( class DenseNet121_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet121-a639ec97.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 7978856, @@ -294,7 +294,7 @@ class DenseNet121_Weights(WeightsEnum): class DenseNet161_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet161-8d451a50.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 28681000, @@ -308,7 +308,7 @@ class DenseNet161_Weights(WeightsEnum): class DenseNet169_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 14149480, @@ -322,7 +322,7 @@ class 
DenseNet169_Weights(WeightsEnum): class DenseNet201_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/densenet201-c1103571.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 20013928, diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index b5a1df4502c..7d18fbe90a3 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -5,7 +5,7 @@ from torchvision.ops import MultiScaleRoIAlign from ...ops import misc as misc_nn_ops -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_value_param @@ -336,7 +336,7 @@ def forward(self, x): class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 41755286, @@ -350,7 +350,7 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 19386354, @@ -364,7 +364,7 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 19386354, diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 7948cd76ab2..27e54a565f2 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -11,7 +11,7 @@ from ...ops import boxes as box_ops from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES @@ -646,7 +646,7 @@ def forward( class FCOS_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ "task": "image_object_detection", "architecture": "FCOS", diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 522545293a0..2a554a6f56e 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -5,7 +5,7 @@ from torchvision.ops import MultiScaleRoIAlign from ...ops import misc as misc_nn_ops -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES from .._utils import 
handle_legacy_interface, _ovewrite_value_param @@ -318,7 +318,7 @@ def forward(self, x): class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): COCO_LEGACY = Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 59137258, @@ -329,7 +329,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): ) COCO_V1 = Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ **_COMMON_META, "num_params": 59137258, diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 47b32984991..fb60ffcbb0a 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -5,7 +5,7 @@ from torchvision.ops import MultiScaleRoIAlign from ...ops import misc as misc_nn_ops -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_value_param @@ -308,7 +308,7 @@ def __init__(self, in_channels, dim_reduced, num_classes): class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ "task": "image_object_detection", "architecture": "MaskRCNN", diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 048d68b52dc..49b9acf45e4 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -10,7 +10,7 @@ from ...ops import boxes as box_ops from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES @@ -588,7 +588,7 @@ def forward(self, images, targets=None): class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ "task": "image_object_detection", "architecture": "RetinaNet", diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index b907b3fccf8..c30919e621c 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ...ops import boxes as box_ops -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES @@ -28,7 +28,7 @@ class SSD300_VGG16_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ "task": "image_object_detection", "architecture": "SSD", diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 
dad28cfed13..93023337d11 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ...ops.misc import Conv2dNormActivation -from ...transforms import ObjectDetectionEval, InterpolationMode +from ...transforms._presets import ObjectDetection, InterpolationMode from ...utils import _log_api_usage_once from .. import mobilenet from .._api import WeightsEnum, Weights @@ -187,7 +187,7 @@ def _mobilenet_extractor( class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): COCO_V1 = Weights( url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth", - transforms=ObjectDetectionEval, + transforms=ObjectDetection, meta={ "task": "image_object_detection", "architecture": "SSDLite", diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index 9665c169bbf..b9d3b9b30c9 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -10,7 +10,7 @@ from torchvision.ops import StochasticDepth from ..ops.misc import Conv2dNormActivation, SqueezeExcitation -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -458,7 +458,7 @@ class EfficientNet_B0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", transforms=partial( - ImageClassificationEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -475,7 +475,7 @@ class EfficientNet_B1_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", transforms=partial( - ImageClassificationEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -488,7 +488,7 @@ class EfficientNet_B1_Weights(WeightsEnum): IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth", transforms=partial( - ImageClassificationEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR + ImageClassification, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR ), meta={ **_COMMON_META_V1, @@ -507,7 +507,7 @@ class EfficientNet_B2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", transforms=partial( - ImageClassificationEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -524,7 +524,7 @@ class EfficientNet_B3_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", transforms=partial( - ImageClassificationEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -541,7 +541,7 @@ class EfficientNet_B4_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( 
url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", transforms=partial( - ImageClassificationEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -558,7 +558,7 @@ class EfficientNet_B5_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", transforms=partial( - ImageClassificationEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -575,7 +575,7 @@ class EfficientNet_B6_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", transforms=partial( - ImageClassificationEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -592,7 +592,7 @@ class EfficientNet_B7_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", transforms=partial( - ImageClassificationEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC + ImageClassification, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC ), meta={ **_COMMON_META_V1, @@ -609,7 +609,7 @@ class EfficientNet_V2_S_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth", transforms=partial( - ImageClassificationEval, + ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BILINEAR, @@ -629,7 +629,7 @@ class EfficientNet_V2_M_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth", transforms=partial( - ImageClassificationEval, + ImageClassification, crop_size=480, resize_size=480, interpolation=InterpolationMode.BILINEAR, @@ -649,7 +649,7 @@ class EfficientNet_V2_L_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth", transforms=partial( - ImageClassificationEval, + ImageClassification, crop_size=480, resize_size=480, interpolation=InterpolationMode.BICUBIC, diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index e09e6788097..ced92571974 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from torch import Tensor -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -278,7 +278,7 @@ def forward(self, x: Tensor) -> Tensor: class GoogLeNet_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/googlenet-1378be20.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ "task": "image_classification", "architecture": "GoogLeNet", diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index 24d084b62d2..816fab45549 100644 --- a/torchvision/models/inception.py 
+++ b/torchvision/models/inception.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from torch import nn, Tensor -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -410,7 +410,7 @@ def forward(self, x: Tensor) -> Tensor: class Inception_V3_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", - transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342), + transforms=partial(ImageClassification, crop_size=299, resize_size=342), meta={ "task": "image_classification", "architecture": "InceptionV3", diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 287911edbec..578e77f7934 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -6,7 +6,7 @@ import torch.nn as nn from torch import Tensor -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -226,7 +226,7 @@ def _load_from_state_dict( class MNASNet0_5_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 2218512, @@ -245,7 +245,7 @@ class MNASNet0_75_Weights(WeightsEnum): class MNASNet1_0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 4383312, diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index 1e19db1a314..085049117ec 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -7,7 +7,7 @@ from torch import nn from ..ops.misc import Conv2dNormActivation -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -209,7 +209,7 @@ def forward(self, x: Tensor) -> Tensor: class MobileNet_V2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2", @@ -219,7 +219,7 @@ class MobileNet_V2_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 3a98456416d..91e1ea91a94 
100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -6,7 +6,7 @@ from torch import nn, Tensor from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -317,7 +317,7 @@ def _mobilenet_v3( class MobileNet_V3_Large_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 5483032, @@ -328,7 +328,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 5483032, @@ -343,7 +343,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum): class MobileNet_V3_Small_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 2542856, diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index a506224d4b3..244d2b2fac1 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -8,7 +8,7 @@ from torch.nn.modules.instancenorm import InstanceNorm2d from torchvision.ops import Conv2dNormActivation -from ...transforms import OpticalFlowEval, InterpolationMode +from ...transforms._presets import OpticalFlow, InterpolationMode from ...utils import _log_api_usage_once from .._api import Weights, WeightsEnum from .._utils import handle_legacy_interface @@ -523,7 +523,7 @@ class Raft_Large_Weights(WeightsEnum): C_T_V1 = Weights( # Chairs + Things, ported from original paper repo (raft-things.pth) url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 5257536, @@ -538,7 +538,7 @@ class Raft_Large_Weights(WeightsEnum): C_T_V2 = Weights( # Chairs + Things url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 5257536, @@ -553,7 +553,7 @@ class Raft_Large_Weights(WeightsEnum): C_T_SKHT_V1 = Weights( # Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth) url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 5257536, @@ -568,7 +568,7 @@ class Raft_Large_Weights(WeightsEnum): # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 5257536, @@ -581,7 +581,7 @@ class Raft_Large_Weights(WeightsEnum): 
C_T_SKHT_K_V1 = Weights( # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth) url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 5257536, @@ -596,7 +596,7 @@ class Raft_Large_Weights(WeightsEnum): # Same as CT_SKHT with extra fine-tuning on Kitti # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 5257536, @@ -612,7 +612,7 @@ class Raft_Small_Weights(WeightsEnum): C_T_V1 = Weights( # Chairs + Things, ported from original paper repo (raft-small.pth) url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 990162, @@ -626,7 +626,7 @@ class Raft_Small_Weights(WeightsEnum): C_T_V2 = Weights( # Chairs + Things url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", - transforms=OpticalFlowEval, + transforms=OpticalFlow, meta={ **_COMMON_META, "num_params": 990162, diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index befc2299c06..9944e470352 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -7,7 +7,7 @@ from torch import Tensor from torch.nn import functional as F -from ...transforms import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassification, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -109,7 +109,7 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: class GoogLeNet_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ "task": "image_classification", "architecture": "GoogLeNet", diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index 697d99d4027..9a732f79fb7 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -9,7 +9,7 @@ from torchvision.models import inception as inception_module from torchvision.models.inception import InceptionOutputs, Inception_V3_Weights -from ...transforms import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassification, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -175,7 +175,7 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: class Inception_V3_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth", - transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342), + transforms=partial(ImageClassification, crop_size=299, resize_size=342), meta={ "task": "image_classification", "architecture": "InceptionV3", diff --git 
a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 40f5cb544fd..1def3d24b28 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -7,7 +7,7 @@ from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, MobileNet_V2_Weights from ...ops.misc import Conv2dNormActivation -from ...transforms import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassification, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -67,7 +67,7 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None: class MobileNet_V2_QuantizedWeights(WeightsEnum): IMAGENET1K_QNNPACK_V1 = Weights( url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ "task": "image_classification", "architecture": "MobileNetV2", diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 4b79b7f26ae..4a203ca7095 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -6,7 +6,7 @@ from torch.ao.quantization import QuantStub, DeQuantStub from ...ops.misc import Conv2dNormActivation, SqueezeExcitation -from ...transforms import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassification, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -157,7 +157,7 @@ def _mobilenet_v3_model( class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): IMAGENET1K_QNNPACK_V1 = Weights( url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ "task": "image_classification", "architecture": "MobileNetV3", diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index 666b1b23163..ab512a7413f 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -13,7 +13,7 @@ ResNeXt101_32X8D_Weights, ) -from ...transforms import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassification, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -161,7 +161,7 @@ def _resnet( class ResNet18_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -178,7 +178,7 @@ class ResNet18_QuantizedWeights(WeightsEnum): class ResNet50_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": 
"ResNet", @@ -191,7 +191,7 @@ class ResNet50_QuantizedWeights(WeightsEnum): ) IMAGENET1K_FBGEMM_V2 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -208,7 +208,7 @@ class ResNet50_QuantizedWeights(WeightsEnum): class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -221,7 +221,7 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): ) IMAGENET1K_FBGEMM_V2 = Weights( url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNeXt", diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index c5bfe698636..a3a26120479 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -6,7 +6,7 @@ from torch import Tensor from torchvision.models import shufflenetv2 -from ...transforms import ImageClassificationEval, InterpolationMode +from ...transforms._presets import ImageClassification, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import handle_legacy_interface, _ovewrite_named_param @@ -118,7 +118,7 @@ def _shufflenetv2( class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 1366792, @@ -133,7 +133,7 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): IMAGENET1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 2278604, diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 1015c21b858..72093686d84 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ..ops.misc import Conv2dNormActivation, SqueezeExcitation -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -416,7 +416,7 @@ def _regnet( class RegNet_Y_400MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 4344144, @@ -427,7 +427,7 @@ 
class RegNet_Y_400MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 4344144, @@ -442,7 +442,7 @@ class RegNet_Y_400MF_Weights(WeightsEnum): class RegNet_Y_800MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 6432512, @@ -453,7 +453,7 @@ class RegNet_Y_800MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 6432512, @@ -468,7 +468,7 @@ class RegNet_Y_800MF_Weights(WeightsEnum): class RegNet_Y_1_6GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 11202430, @@ -479,7 +479,7 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 11202430, @@ -494,7 +494,7 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum): class RegNet_Y_3_2GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 19436338, @@ -505,7 +505,7 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 19436338, @@ -520,7 +520,7 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum): class RegNet_Y_8GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 39381472, @@ -531,7 +531,7 @@ class RegNet_Y_8GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 39381472, @@ -546,7 +546,7 @@ class RegNet_Y_8GF_Weights(WeightsEnum): class RegNet_Y_16GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + 
transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 83590140, @@ -557,7 +557,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 83590140, @@ -572,7 +572,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum): class RegNet_Y_32GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 145046770, @@ -583,7 +583,7 @@ class RegNet_Y_32GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 145046770, @@ -603,7 +603,7 @@ class RegNet_Y_128GF_Weights(WeightsEnum): class RegNet_X_400MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 5495976, @@ -614,7 +614,7 @@ class RegNet_X_400MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 5495976, @@ -629,7 +629,7 @@ class RegNet_X_400MF_Weights(WeightsEnum): class RegNet_X_800MF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 7259656, @@ -640,7 +640,7 @@ class RegNet_X_800MF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 7259656, @@ -655,7 +655,7 @@ class RegNet_X_800MF_Weights(WeightsEnum): class RegNet_X_1_6GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 9190136, @@ -666,7 +666,7 @@ class RegNet_X_1_6GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 9190136, @@ -681,7 +681,7 @@ class RegNet_X_1_6GF_Weights(WeightsEnum): class RegNet_X_3_2GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( 
url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 15296552, @@ -692,7 +692,7 @@ class RegNet_X_3_2GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 15296552, @@ -707,7 +707,7 @@ class RegNet_X_3_2GF_Weights(WeightsEnum): class RegNet_X_8GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 39572648, @@ -718,7 +718,7 @@ class RegNet_X_8GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 39572648, @@ -733,7 +733,7 @@ class RegNet_X_8GF_Weights(WeightsEnum): class RegNet_X_16GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 54278536, @@ -744,7 +744,7 @@ class RegNet_X_16GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 54278536, @@ -759,7 +759,7 @@ class RegNet_X_16GF_Weights(WeightsEnum): class RegNet_X_32GF_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 107811560, @@ -770,7 +770,7 @@ class RegNet_X_32GF_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "num_params": 107811560, diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 159749df006..8f44e553296 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -5,7 +5,7 @@ import torch.nn as nn from torch import Tensor -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -313,7 +313,7 @@ def _resnet( class ResNet18_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet18-f37072fd.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, 
crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -330,7 +330,7 @@ class ResNet18_Weights(WeightsEnum): class ResNet34_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet34-b627a593.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -347,7 +347,7 @@ class ResNet34_Weights(WeightsEnum): class ResNet50_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet50-0676ba61.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -360,7 +360,7 @@ class ResNet50_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -377,7 +377,7 @@ class ResNet50_Weights(WeightsEnum): class ResNet101_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet101-63fe2227.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -390,7 +390,7 @@ class ResNet101_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -407,7 +407,7 @@ class ResNet101_Weights(WeightsEnum): class ResNet152_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnet152-394f9c45.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNet", @@ -420,7 +420,7 @@ class ResNet152_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnet152-f82ba261.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNet", @@ -437,7 +437,7 @@ class ResNet152_Weights(WeightsEnum): class ResNeXt50_32X4D_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -450,7 +450,7 @@ class ResNeXt50_32X4D_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -467,7 +467,7 @@ class ResNeXt50_32X4D_Weights(WeightsEnum): class ResNeXt101_32X8D_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + 
transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -480,7 +480,7 @@ class ResNeXt101_32X8D_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "ResNeXt", @@ -497,7 +497,7 @@ class ResNeXt101_32X8D_Weights(WeightsEnum): class Wide_ResNet50_2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "WideResNet", @@ -510,7 +510,7 @@ class Wide_ResNet50_2_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "WideResNet", @@ -527,7 +527,7 @@ class Wide_ResNet50_2_Weights(WeightsEnum): class Wide_ResNet101_2_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "architecture": "WideResNet", @@ -540,7 +540,7 @@ class Wide_ResNet101_2_Weights(WeightsEnum): ) IMAGENET1K_V2 = Weights( url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232), + transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, "architecture": "WideResNet", diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index 6e8bf0c398b..41ab34bae07 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -5,7 +5,7 @@ from torch import nn from torch.nn import functional as F -from ...transforms import SemanticSegmentationEval, InterpolationMode +from ...transforms._presets import SemanticSegmentation, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param @@ -140,7 +140,7 @@ def _deeplabv3_resnet( class DeepLabV3_ResNet50_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), + transforms=partial(SemanticSegmentation, resize_size=520), meta={ **_COMMON_META, "num_params": 42004074, @@ -155,7 +155,7 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum): class DeepLabV3_ResNet101_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), + transforms=partial(SemanticSegmentation, resize_size=520), meta={ **_COMMON_META, "num_params": 60996202, @@ -170,7 +170,7 @@ class DeepLabV3_ResNet101_Weights(WeightsEnum): class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( 
url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), + transforms=partial(SemanticSegmentation, resize_size=520), meta={ **_COMMON_META, "num_params": 11029328, diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index 5a3ca1f654f..6a760be36dc 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -3,7 +3,7 @@ from torch import nn -from ...transforms import SemanticSegmentationEval, InterpolationMode +from ...transforms._presets import SemanticSegmentation, InterpolationMode from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils import IntermediateLayerGetter, handle_legacy_interface, _ovewrite_value_param @@ -59,7 +59,7 @@ def __init__(self, in_channels: int, channels: int) -> None: class FCN_ResNet50_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), + transforms=partial(SemanticSegmentation, resize_size=520), meta={ **_COMMON_META, "num_params": 35322218, @@ -74,7 +74,7 @@ class FCN_ResNet50_Weights(WeightsEnum): class FCN_ResNet101_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), + transforms=partial(SemanticSegmentation, resize_size=520), meta={ **_COMMON_META, "num_params": 54314346, diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index d1fe15a350d..33684526c6b 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -5,7 +5,7 @@ from torch import nn, Tensor from torch.nn import functional as F -from ...transforms import SemanticSegmentationEval, InterpolationMode +from ...transforms._presets import SemanticSegmentation, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES @@ -96,7 +96,7 @@ def _lraspp_mobilenetv3(backbone: MobileNetV3, num_classes: int) -> LRASPP: class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): COCO_WITH_VOC_LABELS_V1 = Weights( url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", - transforms=partial(SemanticSegmentationEval, resize_size=520), + transforms=partial(SemanticSegmentation, resize_size=520), meta={ "task": "image_semantic_segmentation", "architecture": "LRASPP", diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index b38c0ac2974..e988b819078 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -5,7 +5,7 @@ import torch.nn as nn from torch import Tensor -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -198,7 +198,7 @@ def _shufflenetv2( class ShuffleNet_V2_X0_5_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 1366792, @@ -212,7 
+212,7 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum): class ShuffleNet_V2_X1_0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 2278604, diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index d495b3148e5..bde8b5efcfd 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.init as init -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -128,7 +128,7 @@ def _squeezenet( class SqueezeNet1_0_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "min_size": (21, 21), @@ -143,7 +143,7 @@ class SqueezeNet1_0_Weights(WeightsEnum): class SqueezeNet1_1_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "min_size": (17, 17), diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index 27325c9016c..93bfd5e6ba3 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -120,7 +120,7 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b class VGG11_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg11-8a719046.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 132863336, @@ -134,7 +134,7 @@ class VGG11_Weights(WeightsEnum): class VGG11_BN_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 132868840, @@ -148,7 +148,7 @@ class VGG11_BN_Weights(WeightsEnum): class VGG13_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg13-19584684.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 133047848, @@ -162,7 +162,7 @@ class VGG13_Weights(WeightsEnum): class VGG13_BN_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 133053736, @@ -176,7 +176,7 @@ class 
VGG13_BN_Weights(WeightsEnum): class VGG16_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg16-397923af.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 138357544, @@ -190,7 +190,7 @@ class VGG16_Weights(WeightsEnum): IMAGENET1K_FEATURES = Weights( url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth", transforms=partial( - ImageClassificationEval, + ImageClassification, crop_size=224, mean=(0.48235, 0.45882, 0.40784), std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0), @@ -210,7 +210,7 @@ class VGG16_Weights(WeightsEnum): class VGG16_BN_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 138365992, @@ -224,7 +224,7 @@ class VGG16_BN_Weights(WeightsEnum): class VGG19_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 143667240, @@ -238,7 +238,7 @@ class VGG19_Weights(WeightsEnum): class VGG19_BN_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 143678248, diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index a6b779d10f1..618ddb96ba2 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -4,7 +4,7 @@ import torch.nn as nn from torch import Tensor -from ...transforms import VideoClassificationEval, InterpolationMode +from ...transforms._presets import VideoClassification, InterpolationMode from ...utils import _log_api_usage_once from .._api import WeightsEnum, Weights from .._meta import _KINETICS400_CATEGORIES @@ -322,7 +322,7 @@ def _video_resnet( class R3D_18_Weights(WeightsEnum): KINETICS400_V1 = Weights( url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", - transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), + transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)), meta={ **_COMMON_META, "architecture": "R3D", @@ -337,7 +337,7 @@ class R3D_18_Weights(WeightsEnum): class MC3_18_Weights(WeightsEnum): KINETICS400_V1 = Weights( url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", - transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), + transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)), meta={ **_COMMON_META, "architecture": "MC3", @@ -352,7 +352,7 @@ class MC3_18_Weights(WeightsEnum): class R2Plus1D_18_Weights(WeightsEnum): KINETICS400_V1 = Weights( url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", - transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)), + transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)), meta={ **_COMMON_META, "architecture": "R(2+1)D", diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index 
801e7adc981..fb34cf3c8e1 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -7,7 +7,7 @@ import torch.nn as nn from ..ops.misc import Conv2dNormActivation -from ..transforms import ImageClassificationEval, InterpolationMode +from ..transforms._presets import ImageClassification, InterpolationMode from ..utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES @@ -317,7 +317,7 @@ def _vision_transformer( class ViT_B_16_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_b_16-c867db91.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 86567656, @@ -334,7 +334,7 @@ class ViT_B_16_Weights(WeightsEnum): class ViT_B_32_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 88224232, @@ -351,7 +351,7 @@ class ViT_B_32_Weights(WeightsEnum): class ViT_L_16_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", - transforms=partial(ImageClassificationEval, crop_size=224, resize_size=242), + transforms=partial(ImageClassification, crop_size=224, resize_size=242), meta={ **_COMMON_META, "num_params": 304326632, @@ -368,7 +368,7 @@ class ViT_L_16_Weights(WeightsEnum): class ViT_L_32_Weights(WeightsEnum): IMAGENET1K_V1 = Weights( url="https://download.pytorch.org/models/vit_l_32-c7638314.pth", - transforms=partial(ImageClassificationEval, crop_size=224), + transforms=partial(ImageClassification, crop_size=224), meta={ **_COMMON_META, "num_params": 306535400, diff --git a/torchvision/transforms/__init__.py b/torchvision/transforms/__init__.py index 94ec34ebe98..77680a14f0d 100644 --- a/torchvision/transforms/__init__.py +++ b/torchvision/transforms/__init__.py @@ -1,9 +1,2 @@ from .transforms import * from .autoaugment import * -from ._presets import ( - ObjectDetectionEval, - ImageClassificationEval, - SemanticSegmentationEval, - VideoClassificationEval, - OpticalFlowEval, -) diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index 1776d876ccb..0bfb1cf9b38 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -1,4 +1,8 @@ -from typing import Dict, Optional, Tuple +""" +This file is part of the private API. Please do not use directly these classes as they will be modified on +future versions without warning. The classes should be accessed only via the transforms argument of Weights. 
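For reference, the intended call pattern is to pull the preset off a weights enum rather than import it from torchvision.transforms (which, per the __init__.py diff above, no longer re-exports the presets). A minimal sketch, assuming the ResNet-50 builder and weights enum ported elsewhere in this series; the input tensor is a stand-in:

    import torch
    from torchvision.models.resnet import resnet50, ResNet50_Weights

    weights = ResNet50_Weights.IMAGENET1K_V2
    model = resnet50(weights=weights).eval()
    preprocess = weights.transforms()  # instantiates the ImageClassification preset (resize 232, center crop 224)
    img = torch.randint(0, 256, (3, 300, 400), dtype=torch.uint8)  # stand-in for a decoded image
    logits = model(preprocess(img).unsqueeze(0))

The renamed detection, segmentation, optical-flow and video presets are obtained the same way, e.g. FasterRCNN_ResNet50_FPN_Weights.COCO_V1.transforms() yields an ObjectDetection instance.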
+""" +from typing import Optional, Tuple import torch from torch import Tensor, nn @@ -7,24 +11,22 @@ __all__ = [ - "ObjectDetectionEval", - "ImageClassificationEval", - "VideoClassificationEval", - "SemanticSegmentationEval", - "OpticalFlowEval", + "ObjectDetection", + "ImageClassification", + "VideoClassification", + "SemanticSegmentation", + "OpticalFlow", ] -class ObjectDetectionEval(nn.Module): - def forward( - self, img: Tensor, target: Optional[Dict[str, Tensor]] = None - ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: +class ObjectDetection(nn.Module): + def forward(self, img: Tensor) -> Tensor: if not isinstance(img, Tensor): img = F.pil_to_tensor(img) - return F.convert_image_dtype(img, torch.float), target + return F.convert_image_dtype(img, torch.float) -class ImageClassificationEval(nn.Module): +class ImageClassification(nn.Module): def __init__( self, crop_size: int, @@ -50,7 +52,7 @@ def forward(self, img: Tensor) -> Tensor: return img -class VideoClassificationEval(nn.Module): +class VideoClassification(nn.Module): def __init__( self, crop_size: Tuple[int, int], @@ -67,55 +69,59 @@ def __init__( self._interpolation = interpolation def forward(self, vid: Tensor) -> Tensor: - vid = vid.permute(0, 3, 1, 2) # (T, H, W, C) => (T, C, H, W) + need_squeeze = False + if vid.ndim < 5: + vid = vid.unsqueeze(dim=0) + need_squeeze = True + + vid = vid.permute(0, 1, 4, 2, 3) # (N, T, H, W, C) => (N, T, C, H, W) + N, T, C, H, W = vid.shape + vid = vid.view(-1, C, H, W) vid = F.resize(vid, self._size, interpolation=self._interpolation) vid = F.center_crop(vid, self._crop_size) vid = F.convert_image_dtype(vid, torch.float) vid = F.normalize(vid, mean=self._mean, std=self._std) - return vid.permute(1, 0, 2, 3) # (T, C, H, W) => (C, T, H, W) + vid = vid.view(N, T, C, H, W) + vid = vid.permute(0, 2, 1, 3, 4) # (N, T, C, H, W) => (N, C, T, H, W) + + if need_squeeze: + vid = vid.squeeze(dim=0) + return vid -class SemanticSegmentationEval(nn.Module): +class SemanticSegmentation(nn.Module): def __init__( self, resize_size: Optional[int], mean: Tuple[float, ...] = (0.485, 0.456, 0.406), std: Tuple[float, ...] 
= (0.229, 0.224, 0.225), interpolation: InterpolationMode = InterpolationMode.BILINEAR, - interpolation_target: InterpolationMode = InterpolationMode.NEAREST, ) -> None: super().__init__() self._size = [resize_size] if resize_size is not None else None self._mean = list(mean) self._std = list(std) self._interpolation = interpolation - self._interpolation_target = interpolation_target - def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: + def forward(self, img: Tensor) -> Tensor: if isinstance(self._size, list): img = F.resize(img, self._size, interpolation=self._interpolation) if not isinstance(img, Tensor): img = F.pil_to_tensor(img) img = F.convert_image_dtype(img, torch.float) img = F.normalize(img, mean=self._mean, std=self._std) - if target: - if isinstance(self._size, list): - target = F.resize(target, self._size, interpolation=self._interpolation_target) - if not isinstance(target, Tensor): - target = F.pil_to_tensor(target) - target = target.squeeze(0).to(torch.int64) - return img, target - + return img -class OpticalFlowEval(nn.Module): - def forward( - self, img1: Tensor, img2: Tensor, flow: Optional[Tensor] = None, valid_flow_mask: Optional[Tensor] = None - ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: - img1, img2, flow, valid_flow_mask = self._pil_or_numpy_to_tensor(img1, img2, flow, valid_flow_mask) +class OpticalFlow(nn.Module): + def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]: + if not isinstance(img1, Tensor): + img1 = F.pil_to_tensor(img1) + if not isinstance(img2, Tensor): + img2 = F.pil_to_tensor(img2) - img1 = F.convert_image_dtype(img1, torch.float32) - img2 = F.convert_image_dtype(img2, torch.float32) + img1 = F.convert_image_dtype(img1, torch.float) + img2 = F.convert_image_dtype(img2, torch.float) # map [0, 1] into [-1, 1] img1 = F.normalize(img1, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) @@ -124,19 +130,4 @@ def forward( img1 = img1.contiguous() img2 = img2.contiguous() - return img1, img2, flow, valid_flow_mask - - def _pil_or_numpy_to_tensor( - self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor] - ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: - if not isinstance(img1, Tensor): - img1 = F.pil_to_tensor(img1) - if not isinstance(img2, Tensor): - img2 = F.pil_to_tensor(img2) - - if flow is not None and not isinstance(flow, Tensor): - flow = torch.from_numpy(flow) - if valid_flow_mask is not None and not isinstance(valid_flow_mask, Tensor): - valid_flow_mask = torch.from_numpy(valid_flow_mask) - - return img1, img2, flow, valid_flow_mask + return img1, img2 From f51d5f8b34b66af415fa8ea02cfd09e4877b706e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 21 Mar 2022 12:39:47 +0000 Subject: [PATCH 36/45] Setting `weights_backbone` to its fully BC value (#5653) * Replace default `weights_backbone=None` with its BC values. 
* Fixing tests * Fix linter --- test/test_backbone_utils.py | 18 +++++++-------- test/test_models.py | 10 +++++---- .../test_models_detection_negative_samples.py | 22 ++++++++++++++----- test/test_models_detection_utils.py | 2 +- test/tracing/frcnn/trace_model.py | 2 +- .../models/detection/backbone_utils.py | 6 ++--- torchvision/models/detection/faster_rcnn.py | 6 ++--- torchvision/models/detection/fcos.py | 2 +- torchvision/models/detection/keypoint_rcnn.py | 2 +- torchvision/models/detection/mask_rcnn.py | 2 +- torchvision/models/detection/retinanet.py | 2 +- torchvision/models/detection/ssd.py | 2 +- torchvision/models/detection/ssdlite.py | 2 +- torchvision/models/segmentation/deeplabv3.py | 6 ++--- torchvision/models/segmentation/fcn.py | 4 ++-- torchvision/models/segmentation/lraspp.py | 2 +- 16 files changed, 51 insertions(+), 39 deletions(-) diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index bee52c06075..60d8f8d167d 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -23,30 +23,30 @@ def get_available_models(): @pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50")) def test_resnet_fpn_backbone(backbone_name): x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu") - model = resnet_fpn_backbone(backbone_name=backbone_name) + model = resnet_fpn_backbone(backbone_name=backbone_name, weights=None) assert isinstance(model, BackboneWithFPN) y = model(x) assert list(y.keys()) == ["0", "1", "2", "3", "pool"] with pytest.raises(ValueError, match=r"Trainable layers should be in the range"): - resnet_fpn_backbone(backbone_name=backbone_name, trainable_layers=6) + resnet_fpn_backbone(backbone_name=backbone_name, weights=None, trainable_layers=6) with pytest.raises(ValueError, match=r"Each returned layer should be in the range"): - resnet_fpn_backbone(backbone_name=backbone_name, returned_layers=[0, 1, 2, 3]) + resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[0, 1, 2, 3]) with pytest.raises(ValueError, match=r"Each returned layer should be in the range"): - resnet_fpn_backbone(backbone_name=backbone_name, returned_layers=[2, 3, 4, 5]) + resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[2, 3, 4, 5]) @pytest.mark.parametrize("backbone_name", ("mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small")) def test_mobilenet_backbone(backbone_name): with pytest.raises(ValueError, match=r"Trainable layers should be in the range"): - mobilenet_backbone(backbone_name=backbone_name, fpn=False, trainable_layers=-1) + mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False, trainable_layers=-1) with pytest.raises(ValueError, match=r"Each returned layer should be in the range"): - mobilenet_backbone(backbone_name=backbone_name, fpn=True, returned_layers=[-1, 0, 1, 2]) + mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[-1, 0, 1, 2]) with pytest.raises(ValueError, match=r"Each returned layer should be in the range"): - mobilenet_backbone(backbone_name=backbone_name, fpn=True, returned_layers=[3, 4, 5, 6]) - model_fpn = mobilenet_backbone(backbone_name=backbone_name, fpn=True) + mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[3, 4, 5, 6]) + model_fpn = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True) assert isinstance(model_fpn, BackboneWithFPN) - model = mobilenet_backbone(backbone_name=backbone_name, fpn=False) + model = 
mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False) assert isinstance(model, torch.nn.Sequential) diff --git a/test/test_models.py b/test/test_models.py index 0d45d61df13..bbd96c1c3a4 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -428,7 +428,7 @@ def test_inception_v3_eval(): def test_fasterrcnn_double(): - model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50) + model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, weights=None, weights_backbone=None) model.double() model.eval() input_shape = (3, 300, 300) @@ -467,7 +467,7 @@ def checkOut(out): assert "scores" in out[0] assert "labels" in out[0] - model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50) + model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, weights=None, weights_backbone=None) model.cuda() model.eval() input_shape = (3, 300, 300) @@ -583,6 +583,7 @@ def test_segmentation_model(model_fn, dev): set_rng_seed(0) defaults = { "num_classes": 10, + "weights_backbone": None, "input_shape": (1, 3, 32, 32), } model_name = model_fn.__name__ @@ -644,6 +645,7 @@ def test_detection_model(model_fn, dev): set_rng_seed(0) defaults = { "num_classes": 50, + "weights_backbone": None, "input_shape": (3, 300, 300), } model_name = model_fn.__name__ @@ -738,7 +740,7 @@ def compute_mean_std(tensor): @pytest.mark.parametrize("model_fn", get_models_from_module(models.detection)) def test_detection_model_validation(model_fn): set_rng_seed(0) - model = model_fn(num_classes=50) + model = model_fn(num_classes=50, weights=None, weights_backbone=None) input_shape = (3, 300, 300) x = [torch.rand(input_shape)] @@ -851,7 +853,7 @@ def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_load max_trainable = _model_tests_values[model_name]["max_trainable"] n_trainable_params = [] for trainable_layers in range(0, max_trainable + 1): - model = model_fn(weights_backbone="DEFAULT", trainable_backbone_layers=trainable_layers) + model = model_fn(weights=None, weights_backbone="DEFAULT", trainable_backbone_layers=trainable_layers) n_trainable_params.append(len([p for p in model.parameters() if p.requires_grad])) assert n_trainable_params == _model_tests_values[model_name]["n_trn_params_per_layer"] diff --git a/test/test_models_detection_negative_samples.py b/test/test_models_detection_negative_samples.py index 3746c2f7920..c4efbd96cf3 100644 --- a/test/test_models_detection_negative_samples.py +++ b/test/test_models_detection_negative_samples.py @@ -99,7 +99,9 @@ def test_assign_targets_to_proposals(self): ], ) def test_forward_negative_sample_frcnn(self, name): - model = torchvision.models.detection.__dict__[name](num_classes=2, min_size=100, max_size=100) + model = torchvision.models.detection.__dict__[name]( + weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100 + ) images, targets = self._make_empty_sample() loss_dict = model(images, targets) @@ -108,7 +110,9 @@ def test_forward_negative_sample_frcnn(self, name): assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.0)) def test_forward_negative_sample_mrcnn(self): - model = torchvision.models.detection.maskrcnn_resnet50_fpn(num_classes=2, min_size=100, max_size=100) + model = torchvision.models.detection.maskrcnn_resnet50_fpn( + weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100 + ) images, targets = self._make_empty_sample(add_masks=True) loss_dict = model(images, targets) @@ -118,7 +122,9 @@ def test_forward_negative_sample_mrcnn(self): 
assert_equal(loss_dict["loss_mask"], torch.tensor(0.0)) def test_forward_negative_sample_krcnn(self): - model = torchvision.models.detection.keypointrcnn_resnet50_fpn(num_classes=2, min_size=100, max_size=100) + model = torchvision.models.detection.keypointrcnn_resnet50_fpn( + weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100 + ) images, targets = self._make_empty_sample(add_keypoints=True) loss_dict = model(images, targets) @@ -128,7 +134,9 @@ def test_forward_negative_sample_krcnn(self): assert_equal(loss_dict["loss_keypoint"], torch.tensor(0.0)) def test_forward_negative_sample_retinanet(self): - model = torchvision.models.detection.retinanet_resnet50_fpn(num_classes=2, min_size=100, max_size=100) + model = torchvision.models.detection.retinanet_resnet50_fpn( + weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100 + ) images, targets = self._make_empty_sample() loss_dict = model(images, targets) @@ -136,7 +144,9 @@ def test_forward_negative_sample_retinanet(self): assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0)) def test_forward_negative_sample_fcos(self): - model = torchvision.models.detection.fcos_resnet50_fpn(num_classes=2, min_size=100, max_size=100) + model = torchvision.models.detection.fcos_resnet50_fpn( + weights=None, weights_backbone=None, num_classes=2, min_size=100, max_size=100 + ) images, targets = self._make_empty_sample() loss_dict = model(images, targets) @@ -145,7 +155,7 @@ def test_forward_negative_sample_fcos(self): assert_equal(loss_dict["bbox_ctrness"], torch.tensor(0.0)) def test_forward_negative_sample_ssd(self): - model = torchvision.models.detection.ssd300_vgg16(num_classes=2) + model = torchvision.models.detection.ssd300_vgg16(weights=None, weights_backbone=None, num_classes=2) images, targets = self._make_empty_sample() loss_dict = model(images, targets) diff --git a/test/test_models_detection_utils.py b/test/test_models_detection_utils.py index 5cfc7e04d3f..a160113cbbf 100644 --- a/test/test_models_detection_utils.py +++ b/test/test_models_detection_utils.py @@ -40,7 +40,7 @@ def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params): # be frozen for each trainable_backbone_layers parameter value # i.e all 53 params are frozen if trainable_backbone_layers=0 # ad first 24 params are frozen if trainable_backbone_layers=2 - model = backbone_utils.resnet_fpn_backbone("resnet50", trainable_layers=train_layers) + model = backbone_utils.resnet_fpn_backbone("resnet50", weights=None, trainable_layers=train_layers) # boolean list that is true if the param at that index is frozen is_frozen = [not parameter.requires_grad for _, parameter in model.named_parameters()] # check that expected initial number of layers are frozen diff --git a/test/tracing/frcnn/trace_model.py b/test/tracing/frcnn/trace_model.py index 768954d29b2..b5ec50bdab1 100644 --- a/test/tracing/frcnn/trace_model.py +++ b/test/tracing/frcnn/trace_model.py @@ -6,7 +6,7 @@ HERE = osp.dirname(osp.abspath(__file__)) ASSETS = osp.dirname(osp.dirname(HERE)) -model = torchvision.models.detection.fasterrcnn_resnet50_fpn() +model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None, weights_backbone=None) model.eval() traced_model = torch.jit.script(model) diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index 20b78254fcc..b767756692b 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -62,7 
+62,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: def resnet_fpn_backbone( *, backbone_name: str, - weights: Optional[WeightsEnum] = None, + weights: Optional[WeightsEnum], norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d, trainable_layers: int = 3, returned_layers: Optional[List[int]] = None, @@ -171,8 +171,8 @@ def _validate_trainable_layers( def mobilenet_backbone( *, backbone_name: str, - weights: Optional[WeightsEnum] = None, - fpn: bool = True, + weights: Optional[WeightsEnum], + fpn: bool, norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d, trainable_layers: int = 2, returned_layers: Optional[List[int]] = None, diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index 7d18fbe90a3..2c1e6358c58 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -384,7 +384,7 @@ def fasterrcnn_resnet50_fpn( weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[ResNet50_Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> FasterRCNN: @@ -529,7 +529,7 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> FasterRCNN: @@ -586,7 +586,7 @@ def fasterrcnn_mobilenet_v3_large_fpn( weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> FasterRCNN: diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 27e54a565f2..7b1b3f87ba8 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -670,7 +670,7 @@ def fcos_resnet50_fpn( weights: Optional[FCOS_ResNet50_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[ResNet50_Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> FCOS: diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 2a554a6f56e..dc03c693e1c 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -356,7 +356,7 @@ def keypointrcnn_resnet50_fpn( progress: bool = True, num_classes: Optional[int] = None, num_keypoints: Optional[int] = None, - weights_backbone: Optional[ResNet50_Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> KeypointRCNN: diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index fb60ffcbb0a..a6cb731c0df 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ 
b/torchvision/models/detection/mask_rcnn.py @@ -333,7 +333,7 @@ def maskrcnn_resnet50_fpn( weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[ResNet50_Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> MaskRCNN: diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 49b9acf45e4..2242c1e09bb 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -612,7 +612,7 @@ def retinanet_resnet50_fpn( weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[ResNet50_Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> RetinaNet: diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index c30919e621c..a3b8ffda178 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -570,7 +570,7 @@ def ssd300_vgg16( weights: Optional[SSD300_VGG16_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[VGG16_Weights] = None, + weights_backbone: Optional[VGG16_Weights] = VGG16_Weights.IMAGENET1K_FEATURES, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> SSD: diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 93023337d11..5fb5b402fef 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -212,7 +212,7 @@ def ssdlite320_mobilenet_v3_large( weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, trainable_backbone_layers: Optional[int] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, **kwargs: Any, diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index 41ab34bae07..092a81f643b 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -215,7 +215,7 @@ def deeplabv3_resnet50( progress: bool = True, num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet50_Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, **kwargs: Any, ) -> DeepLabV3: """Constructs a DeepLabV3 model with a ResNet-50 backbone. @@ -256,7 +256,7 @@ def deeplabv3_resnet101( progress: bool = True, num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet101_Weights] = None, + weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1, **kwargs: Any, ) -> DeepLabV3: """Constructs a DeepLabV3 model with a ResNet-101 backbone. 
@@ -297,7 +297,7 @@ def deeplabv3_mobilenet_v3_large( progress: bool = True, num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, **kwargs: Any, ) -> DeepLabV3: """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone. diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index 6a760be36dc..6b6d14ffe32 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -111,7 +111,7 @@ def fcn_resnet50( progress: bool = True, num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet50_Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, **kwargs: Any, ) -> FCN: """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone. @@ -152,7 +152,7 @@ def fcn_resnet101( progress: bool = True, num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet101_Weights] = None, + weights_backbone: Optional[ResNet101_Weights] = ResNet101_Weights.IMAGENET1K_V1, **kwargs: Any, ) -> FCN: """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone. diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index 33684526c6b..fc6d14d366b 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -121,7 +121,7 @@ def lraspp_mobilenet_v3_large( weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1, **kwargs: Any, ) -> LRASPP: """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone. From 8b9631b861da917cabeeb04de629b08fd9a06b4c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 21 Mar 2022 13:38:54 +0000 Subject: [PATCH 37/45] Update docs. 
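Note: the docstring examples below are updated to spell out the new-style arguments so they
no longer go through the legacy `pretrained` deprecation path. As a rough sketch of the
calling convention they illustrate (enum names are shown only for illustration, assuming
they are re-exported from torchvision.models as elsewhere in this port):

    from torchvision import models as M

    # Untrained network: pass weights=None explicitly instead of relying on pretrained=False.
    model = M.mobilenet_v3_large(weights=None)

    # Pretrained network: pass the weight enum member instead of pretrained=True.
    model = M.mobilenet_v3_large(weights=M.MobileNet_V3_Large_Weights.IMAGENET1K_V1)

    # Detection and segmentation builders take both weights and weights_backbone.
    det = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None)
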
--- references/classification/utils.py | 8 ++++---- torchvision/models/detection/ssdlite.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/references/classification/utils.py b/references/classification/utils.py index 27398d97234..32658a7c137 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -330,22 +330,22 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T from torchvision import models as M # Classification - model = M.mobilenet_v3_large() + model = M.mobilenet_v3_large(weights=None) print(store_model_weights(model, './class.pth')) # Quantized Classification - model = M.quantization.mobilenet_v3_large(quantize=False) + model = M.quantization.mobilenet_v3_large(weights=None, quantize=False) model.fuse_model(is_qat=True) model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack') _ = torch.ao.quantization.prepare_qat(model, inplace=True) print(store_model_weights(model, './qat.pth')) # Object Detection - model = M.detection.fasterrcnn_mobilenet_v3_large_fpn() + model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None) print(store_model_weights(model, './obj.pth')) # Segmentation - model = M.segmentation.deeplabv3_mobilenet_v3_large(aux_loss=True) + model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True) print(store_model_weights(model, './segm.pth', strict=False)) Args: diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 5fb5b402fef..2e890356417 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -291,7 +291,7 @@ def ssdlite320_mobilenet_v3_large( "detections_per_img": 300, "topk_candidates": 300, # Rescale the input in a way compatible to the backbone: - # The following mean/std rescale the data from [0, 1] to [-1, -1] + # The following mean/std rescale the data from [0, 1] to [-1, 1] "image_mean": [0.5, 0.5, 0.5], "image_std": [0.5, 0.5, 0.5], } From fea9e39b7170b216a3f328046b9213b634641f6e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 21 Mar 2022 14:59:05 +0000 Subject: [PATCH 38/45] Update preprocessing on reference scripts. 
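Note: the weight-bundled presets in this series now operate on images only (targets were
dropped from their forward signatures), so the `--test-only` evaluation paths below wrap
the preset and handle the targets themselves. A minimal sketch of the detection-side
wrapper, with the weight name assumed purely for illustration:

    import torchvision

    # Any detection weights entry exposing .transforms() would do here.
    weights = torchvision.models.get_weight("FasterRCNN_ResNet50_FPN_Weights.COCO_V1")
    trans = weights.transforms()

    def preprocessing(img, target):
        # The preset converts and normalizes the image only; the target passes through
        # untouched. Segmentation and optical-flow scripts do their own target handling.
        return trans(img), target
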
--- references/detection/train.py | 2 +- references/optical_flow/train.py | 5 ++++- references/segmentation/train.py | 9 ++++++++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/references/detection/train.py b/references/detection/train.py index 0e0a0d70fad..a309bb2a7ea 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -47,7 +47,7 @@ def get_transform(train, args): elif args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) trans = weights.transforms() - return lambda img, target=None: (trans(img), target) + return lambda img, target: (trans(img), target) else: return presets.DetectionPresetEval() diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 1a50d1c617d..d665997e7c9 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -136,7 +136,10 @@ def evaluate(model, args): if args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) trans = weights.transforms() - preprocessing = lambda img1, img2, flow=None, valid=None: trans(img1, img2) + (flow, valid) # noqa: E731 + preprocessing = lambda img1, img2, flow, valid: trans(img1, img2) + ( # noqa: E731 + torch.from_numpy(flow), + torch.from_numpy(valid), + ) else: preprocessing = OpticalFlowPresetEval() diff --git a/references/segmentation/train.py b/references/segmentation/train.py index b4e55acd407..23d8e87bbc0 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -9,6 +9,7 @@ import utils from coco_utils import get_coco from torch import nn +from torchvision.transforms import functional as F, InterpolationMode def get_dataset(dir_path, name, image_set, transform): @@ -32,7 +33,13 @@ def get_transform(train, args): elif args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) trans = weights.transforms() - return lambda img, target=None: (trans(img), target) + + def preprocess(img, target): + img = trans(img) + size = F.get_dimensions(img)[1:] + target = F.resize(target, size, interpolation=InterpolationMode.NEAREST) + return img, F.pil_to_tensor(target) + return preprocess else: return presets.SegmentationPresetEval(base_size=520) From 59bdfbdfae0ae1bdb2740748431fa85d8466d858 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 21 Mar 2022 15:15:29 +0000 Subject: [PATCH 39/45] Change qat/ptq to their full values. 
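Note: spelling the values out keeps the meta entries readable when they are surfaced
directly in documentation. A small sketch of how the field is consumed (the enum member
name is assumed for illustration; the meta layout matches the entries below):

    from torchvision.models.quantization import MobileNet_V2_QuantizedWeights

    w = MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1  # member name assumed
    print(w.meta["backend"])        # "qnnpack"
    print(w.meta["quantization"])   # "Quantization Aware Training" rather than the terse "qat"
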
--- torchvision/models/quantization/googlenet.py | 2 +- torchvision/models/quantization/inception.py | 2 +- torchvision/models/quantization/mobilenetv2.py | 2 +- torchvision/models/quantization/mobilenetv3.py | 2 +- torchvision/models/quantization/resnet.py | 2 +- torchvision/models/quantization/shufflenetv2.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index 9944e470352..1794c834eea 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -120,7 +120,7 @@ class GoogLeNet_QuantizedWeights(WeightsEnum): "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR, "backend": "fbgemm", - "quantization": "ptq", + "quantization": "Post Training Quantization", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", "unquantized": GoogLeNet_Weights.IMAGENET1K_V1, "acc@1": 69.826, diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index 9a732f79fb7..ff5c9a37365 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -186,7 +186,7 @@ class Inception_V3_QuantizedWeights(WeightsEnum): "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR, "backend": "fbgemm", - "quantization": "ptq", + "quantization": "Post Training Quantization", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", "unquantized": Inception_V3_Weights.IMAGENET1K_V1, "acc@1": 77.176, diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 1def3d24b28..d9554e0ba9f 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -78,7 +78,7 @@ class MobileNet_V2_QuantizedWeights(WeightsEnum): "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR, "backend": "qnnpack", - "quantization": "qat", + "quantization": "Quantization Aware Training", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2", "unquantized": MobileNet_V2_Weights.IMAGENET1K_V1, "acc@1": 71.658, diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 4a203ca7095..88907ec210a 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -168,7 +168,7 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR, "backend": "qnnpack", - "quantization": "qat", + "quantization": "Quantization Aware Training", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3", "unquantized": MobileNet_V3_Large_Weights.IMAGENET1K_V1, "acc@1": 73.004, diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index ab512a7413f..a781f320000 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -153,7 +153,7 @@ def _resnet( "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR, "backend": "fbgemm", - "quantization": "ptq", + "quantization": "Post Training Quantization", "recipe": 
"https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", } diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index a3a26120479..1f4f1890e07 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -110,7 +110,7 @@ def _shufflenetv2( "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR, "backend": "fbgemm", - "quantization": "ptq", + "quantization": "Post Training Quantization", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", } From e85f4947fddf172059f7eb255a0ffc2df44281e4 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 21 Mar 2022 17:18:02 +0000 Subject: [PATCH 40/45] Refactoring preprocessing --- references/optical_flow/train.py | 13 +++++++++---- references/segmentation/train.py | 5 +++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index d665997e7c9..5070cb554d4 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -136,10 +136,15 @@ def evaluate(model, args): if args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) trans = weights.transforms() - preprocessing = lambda img1, img2, flow, valid: trans(img1, img2) + ( # noqa: E731 - torch.from_numpy(flow), - torch.from_numpy(valid), - ) + + def preprocessing(img1, img2, flow, valid_flow_mask): + img1, img2 = trans(img1, img2) + if flow is not None and not isinstance(flow, torch.Tensor): + flow = torch.from_numpy(flow) + if valid_flow_mask is not None and not isinstance(valid_flow_mask, torch.Tensor): + valid_flow_mask = torch.from_numpy(valid_flow_mask) + return img1, img2, flow, valid_flow_mask + else: preprocessing = OpticalFlowPresetEval() diff --git a/references/segmentation/train.py b/references/segmentation/train.py index 23d8e87bbc0..e8570ab7f69 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -34,12 +34,13 @@ def get_transform(train, args): weights = torchvision.models.get_weight(args.weights) trans = weights.transforms() - def preprocess(img, target): + def preprocessing(img, target): img = trans(img) size = F.get_dimensions(img)[1:] target = F.resize(target, size, interpolation=InterpolationMode.NEAREST) return img, F.pil_to_tensor(target) - return preprocess + + return preprocessing else: return presets.SegmentationPresetEval(base_size=520) From 0b0fc82a8c5cdd437c96603193faea3d999119e6 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 21 Mar 2022 18:51:12 +0000 Subject: [PATCH 41/45] Fix video preset --- references/video_classification/train.py | 2 -- torchvision/transforms/_presets.py | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index da7ef9fc607..918a012282e 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -96,8 +96,6 @@ def main(args): utils.init_distributed_mode(args) print(args) - print("torch version: ", torch.__version__) - print("torchvision version: ", torchvision.__version__) device = torch.device(args.device) diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index 0bfb1cf9b38..42575024827 100644 --- a/torchvision/transforms/_presets.py +++ 
b/torchvision/transforms/_presets.py @@ -81,6 +81,7 @@ def forward(self, vid: Tensor) -> Tensor: vid = F.center_crop(vid, self._crop_size) vid = F.convert_image_dtype(vid, torch.float) vid = F.normalize(vid, mean=self._mean, std=self._std) + H, W = self._crop_size vid = vid.view(N, T, C, H, W) vid = vid.permute(0, 2, 1, 3, 4) # (N, T, C, H, W) => (N, C, T, H, W) From 903a3d534c0f8fdf2077f7ffb979a559c30ef6bb Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 21 Mar 2022 20:15:11 +0000 Subject: [PATCH 42/45] No initialization on VGG if pretrained --- torchvision/models/vgg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index 93bfd5e6ba3..c245eef6482 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -97,6 +97,7 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG: if weights is not None: + kwargs["init_weights"] = False if weights.meta["categories"] is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) From db9139a79888875923f6d305fe38d1c04a7e0650 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 22 Mar 2022 13:06:10 +0000 Subject: [PATCH 43/45] Fix warning messages for backbone utils. --- torchvision/models/_api.py | 36 +++++++++++++++++++ .../models/detection/backbone_utils.py | 12 +++++-- 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index d841415a45a..7cd97e5ad6e 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -3,6 +3,7 @@ import sys from collections import OrderedDict from dataclasses import dataclass, fields +from inspect import signature from typing import Any, Callable, Dict from torchvision._utils import StrEnum @@ -105,3 +106,38 @@ def get_weight(name: str) -> WeightsEnum: raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.") return weights_enum.from_str(value_name) + + +def get_enum_from_fn(fn: Callable) -> WeightsEnum: + """ + Internal method that gets the weight enum of a specific model builder method. + Might be removed after the handle_legacy_interface is removed. + + Args: + fn (Callable): The builder method used to create the model. + weight_name (str): The name of the weight enum entry of the specific model. + Returns: + WeightsEnum: The requested weight enum. + """ + sig = signature(fn) + if "weights" not in sig.parameters: + raise ValueError("The method is missing the 'weights' argument.") + + ann = signature(fn).parameters["weights"].annotation + weights_enum = None + if isinstance(ann, type) and issubclass(ann, WeightsEnum): + weights_enum = ann + else: + # handle cases like Union[Optional, T] + # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8 + for t in ann.__args__: # type: ignore[union-attr] + if isinstance(t, type) and issubclass(t, WeightsEnum): + weights_enum = t + break + + if weights_enum is None: + raise ValueError( + "The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct." 
+ ) + + return weights_enum diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index b767756692b..24215322b84 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -6,7 +6,7 @@ from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool from .. import mobilenet, resnet -from .._api import WeightsEnum +from .._api import WeightsEnum, get_enum_from_fn from .._utils import IntermediateLayerGetter, handle_legacy_interface @@ -57,7 +57,10 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: @handle_legacy_interface( - weights=("pretrained", True), # type: ignore[arg-type] + weights=( + "pretrained", + lambda kwargs: get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"), + ), ) def resnet_fpn_backbone( *, @@ -166,7 +169,10 @@ def _validate_trainable_layers( @handle_legacy_interface( - weights=("pretrained", True), # type: ignore[arg-type] + weights=( + "pretrained", + lambda kwargs: get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"), + ), ) def mobilenet_backbone( *, From 545e4318f5552d84155dfaf17e249eb6bf4fe09d Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 22 Mar 2022 13:56:23 +0000 Subject: [PATCH 44/45] Adding star to all preset constructors. --- references/classification/presets.py | 2 ++ references/detection/train.py | 2 +- references/optical_flow/presets.py | 1 + references/segmentation/presets.py | 4 ++-- references/video_classification/presets.py | 3 ++- torchvision/transforms/_presets.py | 3 +++ 6 files changed, 11 insertions(+), 4 deletions(-) diff --git a/references/classification/presets.py b/references/classification/presets.py index 418ef3e2e07..6bc38e72953 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -6,6 +6,7 @@ class ClassificationPresetTrain: def __init__( self, + *, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), @@ -46,6 +47,7 @@ def __call__(self, img): class ClassificationPresetEval: def __init__( self, + *, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), diff --git a/references/detection/train.py b/references/detection/train.py index a309bb2a7ea..b6634061503 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -43,7 +43,7 @@ def get_dataset(name, image_set, transform, data_path): def get_transform(train, args): if train: - return presets.DetectionPresetTrain(args.data_augmentation) + return presets.DetectionPresetTrain(data_augmentation=args.data_augmentation) elif args.weights and args.test_only: weights = torchvision.models.get_weight(args.weights) trans = weights.transforms() diff --git a/references/optical_flow/presets.py b/references/optical_flow/presets.py index 43ff4a24f3b..32d9542e692 100644 --- a/references/optical_flow/presets.py +++ b/references/optical_flow/presets.py @@ -22,6 +22,7 @@ def forward(self, img1, img2, flow, valid): class OpticalFlowPresetTrain(torch.nn.Module): def __init__( self, + *, # RandomResizeAndCrop params crop_size, min_scale=-0.2, diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py index 8cada98ac95..ed02ae660e4 100644 --- a/references/segmentation/presets.py +++ b/references/segmentation/presets.py @@ -3,7 +3,7 @@ class SegmentationPresetTrain: - def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 
0.225)): + def __init__(self, *, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): min_size = int(0.5 * base_size) max_size = int(2.0 * base_size) @@ -25,7 +25,7 @@ def __call__(self, img, target): class SegmentationPresetEval: - def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): + def __init__(self, *, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): self.transforms = T.Compose( [ T.RandomResize(base_size, base_size), diff --git a/references/video_classification/presets.py b/references/video_classification/presets.py index d24169e42dd..c12d00a022b 100644 --- a/references/video_classification/presets.py +++ b/references/video_classification/presets.py @@ -6,6 +6,7 @@ class VideoClassificationPresetTrain: def __init__( self, + *, crop_size, resize_size, mean=(0.43216, 0.394666, 0.37645), @@ -27,7 +28,7 @@ def __call__(self, x): class VideoClassificationPresetEval: - def __init__(self, crop_size, resize_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)): + def __init__(self, *, crop_size, resize_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)): self.transforms = transforms.Compose( [ ConvertBHWCtoBCHW(), diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index 42575024827..4d503f44cc5 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -29,6 +29,7 @@ def forward(self, img: Tensor) -> Tensor: class ImageClassification(nn.Module): def __init__( self, + *, crop_size: int, resize_size: int = 256, mean: Tuple[float, ...] = (0.485, 0.456, 0.406), @@ -55,6 +56,7 @@ def forward(self, img: Tensor) -> Tensor: class VideoClassification(nn.Module): def __init__( self, + *, crop_size: Tuple[int, int], resize_size: Tuple[int, int], mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645), @@ -93,6 +95,7 @@ def forward(self, vid: Tensor) -> Tensor: class SemanticSegmentation(nn.Module): def __init__( self, + *, resize_size: Optional[int], mean: Tuple[float, ...] = (0.485, 0.456, 0.406), std: Tuple[float, ...] = (0.229, 0.224, 0.225), From ff3fd54099ac6e1305ba3089e68f1b2c1d0b839d Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 22 Mar 2022 15:21:51 +0000 Subject: [PATCH 45/45] Fix mypy. --- torchvision/models/_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/_api.py b/torchvision/models/_api.py index 7cd97e5ad6e..e47eaf73aab 100644 --- a/torchvision/models/_api.py +++ b/torchvision/models/_api.py @@ -4,7 +4,7 @@ from collections import OrderedDict from dataclasses import dataclass, fields from inspect import signature -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, cast from torchvision._utils import StrEnum @@ -140,4 +140,4 @@ def get_enum_from_fn(fn: Callable) -> WeightsEnum: "The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct." ) - return weights_enum + return cast(WeightsEnum, weights_enum)
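
For reference, a hedged usage sketch of the helper that the last two patches introduce and
then type-fix: given a builder whose `weights` parameter is annotated with its weights enum
(the convention used throughout this series), `get_enum_from_fn` recovers that enum so the
legacy `pretrained=True` path can be mapped onto a concrete member. It is a private helper;
the import below is shown only for illustration.

    from torchvision.models import resnet
    from torchvision.models._api import get_enum_from_fn

    # Assumes resnet50's `weights` parameter is annotated with Optional[ResNet50_Weights],
    # matching the builders ported in this series.
    weights_enum = get_enum_from_fn(resnet.resnet50)   # -> ResNet50_Weights
    weights = weights_enum.from_str("IMAGENET1K_V1")   # the member used as the BC default
    print(weights)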