From d3fa35439f815376d223d4cbc59180a9523aa06e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 29 Nov 2021 12:38:37 +0000 Subject: [PATCH 1/4] Rename classes Weights => WeightsEnum and WeightEntry => Weights. --- test/test_prototype_models.py | 1 + torchvision/prototype/models/_api.py | 36 +++++------ torchvision/prototype/models/_utils.py | 4 +- torchvision/prototype/models/alexnet.py | 6 +- torchvision/prototype/models/densenet.py | 22 +++---- .../prototype/models/detection/faster_rcnn.py | 14 ++--- .../models/detection/keypoint_rcnn.py | 8 +-- .../prototype/models/detection/mask_rcnn.py | 6 +- .../prototype/models/detection/retinanet.py | 6 +- torchvision/prototype/models/detection/ssd.py | 6 +- .../prototype/models/detection/ssdlite.py | 6 +- torchvision/prototype/models/efficientnet.py | 36 +++++------ torchvision/prototype/models/googlenet.py | 6 +- torchvision/prototype/models/inception.py | 6 +- torchvision/prototype/models/mnasnet.py | 16 ++--- torchvision/prototype/models/mobilenetv2.py | 6 +- torchvision/prototype/models/mobilenetv3.py | 14 ++--- .../models/quantization/googlenet.py | 6 +- .../models/quantization/inception.py | 6 +- .../models/quantization/mobilenetv2.py | 6 +- .../models/quantization/mobilenetv3.py | 8 +-- .../prototype/models/quantization/resnet.py | 20 +++---- .../models/quantization/shufflenetv2.py | 12 ++-- torchvision/prototype/models/regnet.py | 60 +++++++++---------- torchvision/prototype/models/resnet.py | 54 ++++++++--------- .../models/segmentation/deeplabv3.py | 14 ++--- .../prototype/models/segmentation/fcn.py | 10 ++-- .../prototype/models/segmentation/lraspp.py | 6 +- torchvision/prototype/models/shufflenetv2.py | 16 ++--- torchvision/prototype/models/squeezenet.py | 10 ++-- torchvision/prototype/models/vgg.py | 38 ++++++------ torchvision/prototype/models/video/resnet.py | 16 ++--- .../prototype/models/vision_transformer.py | 12 ++-- 33 files changed, 247 insertions(+), 246 deletions(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index f53299bcf51..81d7e30863a 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -39,6 +39,7 @@ def get_models_with_module_names(module): [ (models.resnet50, "ImageNet1K_RefV1", models.ResNet50Weights.ImageNet1K_RefV1), (models.resnet50, "default", models.ResNet50Weights.ImageNet1K_RefV2), + (models.quantization.resnet50, "default", models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV2), ( models.quantization.resnet50, "ImageNet1K_FBGEMM_RefV1", diff --git a/torchvision/prototype/models/_api.py b/torchvision/prototype/models/_api.py index 6291f8eedcf..2935039e087 100644 --- a/torchvision/prototype/models/_api.py +++ b/torchvision/prototype/models/_api.py @@ -7,11 +7,11 @@ from ..._internally_replaced_utils import load_state_dict_from_url -__all__ = ["Weights", "WeightEntry", "get_weight"] +__all__ = ["WeightsEnum", "Weights", "get_weight"] @dataclass -class WeightEntry: +class Weights: """ This class is used to group important attributes associated with the pre-trained weights. @@ -33,17 +33,17 @@ class WeightEntry: default: bool -class Weights(Enum): +class WeightsEnum(Enum): """ This class is the parent class of all model weights. Each model building method receives an optional `weights` parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type - `WeightEntry`. + `Weights`. Args: - value (WeightEntry): The data class entry with the weight information. 
+ value (Weights): The data class entry with the weight information. """ - def __init__(self, value: WeightEntry): + def __init__(self, value: Weights): self._value_ = value @classmethod @@ -58,7 +58,7 @@ def verify(cls, obj: Any) -> Any: return obj @classmethod - def from_str(cls, value: str) -> "Weights": + def from_str(cls, value: str) -> "WeightsEnum": for v in cls: if v._name_ == value or (value == "default" and v.default): return v @@ -71,14 +71,14 @@ def __repr__(self): return f"{self.__class__.__name__}.{self._name_}" def __getattr__(self, name): - # Be able to fetch WeightEntry attributes directly - for f in fields(WeightEntry): + # Be able to fetch Weights attributes directly + for f in fields(Weights): if f.name == name: return object.__getattribute__(self.value, name) return super().__getattr__(name) -def get_weight(fn: Callable, weight_name: str) -> Weights: +def get_weight(fn: Callable, weight_name: str) -> WeightsEnum: """ Gets the weight enum of a specific model builder method and weight name combination. @@ -87,32 +87,32 @@ def get_weight(fn: Callable, weight_name: str) -> Weights: weight_name (str): The name of the weight enum entry of the specific model. Returns: - Weights: The requested weight enum. + WeightsEnum: The requested weight enum. """ sig = signature(fn) if "weights" not in sig.parameters: raise ValueError("The method is missing the 'weights' parameter.") ann = signature(fn).parameters["weights"].annotation - weights_class = None - if isinstance(ann, type) and issubclass(ann, Weights): - weights_class = ann + weights_enum = None + if isinstance(ann, type) and issubclass(ann, WeightsEnum): + weights_enum = ann else: # handle cases like Union[Optional, T] # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8 for t in ann.__args__: # type: ignore[union-attr] - if isinstance(t, type) and issubclass(t, Weights): + if isinstance(t, type) and issubclass(t, WeightsEnum): # ensure the name exists. handles builders with multiple types of weights like in quantization try: t.from_str(weight_name) except ValueError: continue - weights_class = t + weights_enum = t break - if weights_class is None: + if weights_enum is None: raise ValueError( "The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct." 
) - return weights_class.from_str(weight_name) + return weights_enum.from_str(weight_name) diff --git a/torchvision/prototype/models/_utils.py b/torchvision/prototype/models/_utils.py index 31d9703db1f..e2ee9034953 100644 --- a/torchvision/prototype/models/_utils.py +++ b/torchvision/prototype/models/_utils.py @@ -1,10 +1,10 @@ import warnings from typing import Any, Dict, Optional, TypeVar -from ._api import Weights +from ._api import WeightsEnum -W = TypeVar("W", bound=Weights) +W = TypeVar("W", bound=WeightsEnum) V = TypeVar("V") diff --git a/torchvision/prototype/models/alexnet.py b/torchvision/prototype/models/alexnet.py index d5e826b4559..d6df916050d 100644 --- a/torchvision/prototype/models/alexnet.py +++ b/torchvision/prototype/models/alexnet.py @@ -5,7 +5,7 @@ from torchvision.transforms.functional import InterpolationMode from ...models.alexnet import AlexNet -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param @@ -13,8 +13,8 @@ __all__ = ["AlexNet", "AlexNetWeights", "alexnet"] -class AlexNetWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class AlexNetWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ diff --git a/torchvision/prototype/models/densenet.py b/torchvision/prototype/models/densenet.py index f70cee3a528..18c091ff786 100644 --- a/torchvision/prototype/models/densenet.py +++ b/torchvision/prototype/models/densenet.py @@ -7,7 +7,7 @@ from torchvision.transforms.functional import InterpolationMode from ...models.densenet import DenseNet -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param @@ -25,7 +25,7 @@ ] -def _load_state_dict(model: nn.Module, weights: Weights, progress: bool) -> None: +def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None: # '.'s are no longer allowed in module names, but previous _DenseLayer # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. # They are also in the checkpoints in model_urls. 
This pattern is used @@ -48,7 +48,7 @@ def _densenet( growth_rate: int, block_config: Tuple[int, int, int, int], num_init_features: int, - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> DenseNet: @@ -71,8 +71,8 @@ def _densenet( } -class DenseNet121Weights(Weights): - ImageNet1K_Community = WeightEntry( +class DenseNet121Weights(WeightsEnum): + ImageNet1K_Community = Weights( url="https://download.pytorch.org/models/densenet121-a639ec97.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -84,8 +84,8 @@ class DenseNet121Weights(Weights): ) -class DenseNet161Weights(Weights): - ImageNet1K_Community = WeightEntry( +class DenseNet161Weights(WeightsEnum): + ImageNet1K_Community = Weights( url="https://download.pytorch.org/models/densenet161-8d451a50.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -97,8 +97,8 @@ class DenseNet161Weights(Weights): ) -class DenseNet169Weights(Weights): - ImageNet1K_Community = WeightEntry( +class DenseNet169Weights(WeightsEnum): + ImageNet1K_Community = Weights( url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -110,8 +110,8 @@ class DenseNet169Weights(Weights): ) -class DenseNet201Weights(Weights): - ImageNet1K_Community = WeightEntry( +class DenseNet201Weights(WeightsEnum): + ImageNet1K_Community = Weights( url="https://download.pytorch.org/models/densenet201-c1103571.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index 685a309beaa..241466b8eba 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -12,7 +12,7 @@ misc_nn_ops, overwrite_eps, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large @@ -36,8 +36,8 @@ } -class FasterRCNNResNet50FPNWeights(Weights): - Coco_RefV1 = WeightEntry( +class FasterRCNNResNet50FPNWeights(WeightsEnum): + Coco_RefV1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", transforms=CocoEval, meta={ @@ -49,8 +49,8 @@ class FasterRCNNResNet50FPNWeights(Weights): ) -class FasterRCNNMobileNetV3LargeFPNWeights(Weights): - Coco_RefV1 = WeightEntry( +class FasterRCNNMobileNetV3LargeFPNWeights(WeightsEnum): + Coco_RefV1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", transforms=CocoEval, meta={ @@ -62,8 +62,8 @@ class FasterRCNNMobileNetV3LargeFPNWeights(Weights): ) -class FasterRCNNMobileNetV3Large320FPNWeights(Weights): - Coco_RefV1 = WeightEntry( +class FasterRCNNMobileNetV3Large320FPNWeights(WeightsEnum): + Coco_RefV1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", transforms=CocoEval, meta={ diff --git a/torchvision/prototype/models/detection/keypoint_rcnn.py b/torchvision/prototype/models/detection/keypoint_rcnn.py index d45d7f93af7..36fd955f73a 100644 --- a/torchvision/prototype/models/detection/keypoint_rcnn.py +++ b/torchvision/prototype/models/detection/keypoint_rcnn.py @@ -9,7 +9,7 @@ misc_nn_ops, overwrite_eps, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights 
from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from ..resnet import ResNet50Weights, resnet50 @@ -25,8 +25,8 @@ _COMMON_META = {"categories": _COCO_PERSON_CATEGORIES, "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES} -class KeypointRCNNResNet50FPNWeights(Weights): - Coco_RefV1_Legacy = WeightEntry( +class KeypointRCNNResNet50FPNWeights(WeightsEnum): + Coco_RefV1_Legacy = Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", transforms=CocoEval, meta={ @@ -37,7 +37,7 @@ class KeypointRCNNResNet50FPNWeights(Weights): }, default=False, ) - Coco_RefV1 = WeightEntry( + Coco_RefV1 = Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", transforms=CocoEval, meta={ diff --git a/torchvision/prototype/models/detection/mask_rcnn.py b/torchvision/prototype/models/detection/mask_rcnn.py index 8aba8ce5041..be6742e4879 100644 --- a/torchvision/prototype/models/detection/mask_rcnn.py +++ b/torchvision/prototype/models/detection/mask_rcnn.py @@ -10,7 +10,7 @@ misc_nn_ops, overwrite_eps, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from ..resnet import ResNet50Weights, resnet50 @@ -23,8 +23,8 @@ ] -class MaskRCNNResNet50FPNWeights(Weights): - Coco_RefV1 = WeightEntry( +class MaskRCNNResNet50FPNWeights(WeightsEnum): + Coco_RefV1 = Weights( url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", transforms=CocoEval, meta={ diff --git a/torchvision/prototype/models/detection/retinanet.py b/torchvision/prototype/models/detection/retinanet.py index c9361934921..490a9b9e0c2 100644 --- a/torchvision/prototype/models/detection/retinanet.py +++ b/torchvision/prototype/models/detection/retinanet.py @@ -11,7 +11,7 @@ misc_nn_ops, overwrite_eps, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from ..resnet import ResNet50Weights, resnet50 @@ -24,8 +24,8 @@ ] -class RetinaNetResNet50FPNWeights(Weights): - Coco_RefV1 = WeightEntry( +class RetinaNetResNet50FPNWeights(WeightsEnum): + Coco_RefV1 = Weights( url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", transforms=CocoEval, meta={ diff --git a/torchvision/prototype/models/detection/ssd.py b/torchvision/prototype/models/detection/ssd.py index 09ce083bf7e..d2a51508366 100644 --- a/torchvision/prototype/models/detection/ssd.py +++ b/torchvision/prototype/models/detection/ssd.py @@ -10,7 +10,7 @@ DefaultBoxGenerator, SSD, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from ..vgg import VGG16Weights, vgg16 @@ -22,8 +22,8 @@ ] -class SSD300VGG16Weights(Weights): - Coco_RefV1 = WeightEntry( +class SSD300VGG16Weights(WeightsEnum): + Coco_RefV1 = Weights( url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", transforms=CocoEval, meta={ diff --git a/torchvision/prototype/models/detection/ssdlite.py b/torchvision/prototype/models/detection/ssdlite.py index 0e2786ea203..5f597ba0951 100644 --- a/torchvision/prototype/models/detection/ssdlite.py 
+++ b/torchvision/prototype/models/detection/ssdlite.py @@ -15,7 +15,7 @@ SSD, SSDLiteHead, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large @@ -27,8 +27,8 @@ ] -class SSDlite320MobileNetV3LargeFPNWeights(Weights): - Coco_RefV1 = WeightEntry( +class SSDlite320MobileNetV3LargeFPNWeights(WeightsEnum): + Coco_RefV1 = Weights( url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth", transforms=CocoEval, meta={ diff --git a/torchvision/prototype/models/efficientnet.py b/torchvision/prototype/models/efficientnet.py index 66dbd8ef5ea..16bf77cbe6f 100644 --- a/torchvision/prototype/models/efficientnet.py +++ b/torchvision/prototype/models/efficientnet.py @@ -6,7 +6,7 @@ from torchvision.transforms.functional import InterpolationMode from ...models.efficientnet import EfficientNet, MBConvConfig -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param @@ -36,7 +36,7 @@ def _efficientnet( width_mult: float, depth_mult: float, dropout: float, - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> EfficientNet: @@ -69,8 +69,8 @@ def _efficientnet( } -class EfficientNetB0Weights(Weights): - ImageNet1K_TimmV1 = WeightEntry( +class EfficientNetB0Weights(WeightsEnum): + ImageNet1K_TimmV1 = Weights( url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC), meta={ @@ -83,8 +83,8 @@ class EfficientNetB0Weights(Weights): ) -class EfficientNetB1Weights(Weights): - ImageNet1K_TimmV1 = WeightEntry( +class EfficientNetB1Weights(WeightsEnum): + ImageNet1K_TimmV1 = Weights( url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC), meta={ @@ -97,8 +97,8 @@ class EfficientNetB1Weights(Weights): ) -class EfficientNetB2Weights(Weights): - ImageNet1K_TimmV1 = WeightEntry( +class EfficientNetB2Weights(WeightsEnum): + ImageNet1K_TimmV1 = Weights( url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC), meta={ @@ -111,8 +111,8 @@ class EfficientNetB2Weights(Weights): ) -class EfficientNetB3Weights(Weights): - ImageNet1K_TimmV1 = WeightEntry( +class EfficientNetB3Weights(WeightsEnum): + ImageNet1K_TimmV1 = Weights( url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC), meta={ @@ -125,8 +125,8 @@ class EfficientNetB3Weights(Weights): ) -class EfficientNetB4Weights(Weights): - ImageNet1K_TimmV1 = WeightEntry( +class EfficientNetB4Weights(WeightsEnum): + ImageNet1K_TimmV1 = Weights( url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC), meta={ @@ -139,8 +139,8 @@ class EfficientNetB4Weights(Weights): ) -class 
EfficientNetB5Weights(Weights): - ImageNet1K_TFV1 = WeightEntry( +class EfficientNetB5Weights(WeightsEnum): + ImageNet1K_TFV1 = Weights( url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC), meta={ @@ -153,8 +153,8 @@ class EfficientNetB5Weights(Weights): ) -class EfficientNetB6Weights(Weights): - ImageNet1K_TFV1 = WeightEntry( +class EfficientNetB6Weights(WeightsEnum): + ImageNet1K_TFV1 = Weights( url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC), meta={ @@ -167,8 +167,8 @@ class EfficientNetB6Weights(Weights): ) -class EfficientNetB7Weights(Weights): - ImageNet1K_TFV1 = WeightEntry( +class EfficientNetB7Weights(WeightsEnum): + ImageNet1K_TFV1 = Weights( url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC), meta={ diff --git a/torchvision/prototype/models/googlenet.py b/torchvision/prototype/models/googlenet.py index 8a65aa08d35..260259af903 100644 --- a/torchvision/prototype/models/googlenet.py +++ b/torchvision/prototype/models/googlenet.py @@ -6,7 +6,7 @@ from torchvision.transforms.functional import InterpolationMode from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param @@ -14,8 +14,8 @@ __all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNetWeights", "googlenet"] -class GoogLeNetWeights(Weights): - ImageNet1K_TFV1 = WeightEntry( +class GoogLeNetWeights(WeightsEnum): + ImageNet1K_TFV1 = Weights( url="https://download.pytorch.org/models/googlenet-1378be20.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ diff --git a/torchvision/prototype/models/inception.py b/torchvision/prototype/models/inception.py index fe5cde184e0..d82d459b1d4 100644 --- a/torchvision/prototype/models/inception.py +++ b/torchvision/prototype/models/inception.py @@ -5,7 +5,7 @@ from torchvision.transforms.functional import InterpolationMode from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param @@ -13,8 +13,8 @@ __all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "InceptionV3Weights", "inception_v3"] -class InceptionV3Weights(Weights): - ImageNet1K_TFV1 = WeightEntry( +class InceptionV3Weights(WeightsEnum): + ImageNet1K_TFV1 = Weights( url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", transforms=partial(ImageNetEval, crop_size=299, resize_size=342), meta={ diff --git a/torchvision/prototype/models/mnasnet.py b/torchvision/prototype/models/mnasnet.py index c5cef01fb98..a109b2ef50b 100644 --- a/torchvision/prototype/models/mnasnet.py +++ b/torchvision/prototype/models/mnasnet.py @@ -5,7 +5,7 @@ from torchvision.transforms.functional import InterpolationMode from ...models.mnasnet import MNASNet -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import 
_IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param @@ -31,8 +31,8 @@ } -class MNASNet0_5Weights(Weights): - ImageNet1K_Community = WeightEntry( +class MNASNet0_5Weights(WeightsEnum): + ImageNet1K_Community = Weights( url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -44,13 +44,13 @@ class MNASNet0_5Weights(Weights): ) -class MNASNet0_75Weights(Weights): +class MNASNet0_75Weights(WeightsEnum): # If a default model is added here the corresponding changes need to be done in mnasnet0_75 pass -class MNASNet1_0Weights(Weights): - ImageNet1K_Community = WeightEntry( +class MNASNet1_0Weights(WeightsEnum): + ImageNet1K_Community = Weights( url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -62,12 +62,12 @@ class MNASNet1_0Weights(Weights): ) -class MNASNet1_3Weights(Weights): +class MNASNet1_3Weights(WeightsEnum): # If a default model is added here the corresponding changes need to be done in mnasnet1_3 pass -def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs: Any) -> MNASNet: +def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet: if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) diff --git a/torchvision/prototype/models/mobilenetv2.py b/torchvision/prototype/models/mobilenetv2.py index 1e33384ad2d..d0649f7a5bc 100644 --- a/torchvision/prototype/models/mobilenetv2.py +++ b/torchvision/prototype/models/mobilenetv2.py @@ -5,7 +5,7 @@ from torchvision.transforms.functional import InterpolationMode from ...models.mobilenetv2 import MobileNetV2 -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param @@ -13,8 +13,8 @@ __all__ = ["MobileNetV2", "MobileNetV2Weights", "mobilenet_v2"] -class MobileNetV2Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class MobileNetV2Weights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ diff --git a/torchvision/prototype/models/mobilenetv3.py b/torchvision/prototype/models/mobilenetv3.py index 1a6a810856b..6b2c312d5cb 100644 --- a/torchvision/prototype/models/mobilenetv3.py +++ b/torchvision/prototype/models/mobilenetv3.py @@ -5,7 +5,7 @@ from torchvision.transforms.functional import InterpolationMode from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param @@ -22,7 +22,7 @@ def _mobilenet_v3( inverted_residual_setting: List[InvertedResidualConfig], last_channel: int, - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> MobileNetV3: @@ -44,8 +44,8 @@ def _mobilenet_v3( } -class MobileNetV3LargeWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class MobileNetV3LargeWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -56,7 +56,7 @@ 
class MobileNetV3LargeWeights(Weights): }, default=False, ) - ImageNet1K_RefV2 = WeightEntry( + ImageNet1K_RefV2 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -69,8 +69,8 @@ class MobileNetV3LargeWeights(Weights): ) -class MobileNetV3SmallWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class MobileNetV3SmallWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ diff --git a/torchvision/prototype/models/quantization/googlenet.py b/torchvision/prototype/models/quantization/googlenet.py index 4769b115ba5..8340606b2d5 100644 --- a/torchvision/prototype/models/quantization/googlenet.py +++ b/torchvision/prototype/models/quantization/googlenet.py @@ -10,7 +10,7 @@ _replace_relu, quantize_model, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ..googlenet import GoogLeNetWeights @@ -23,8 +23,8 @@ ] -class QuantizedGoogLeNetWeights(Weights): - ImageNet1K_FBGEMM_TFV1 = WeightEntry( +class QuantizedGoogLeNetWeights(WeightsEnum): + ImageNet1K_FBGEMM_TFV1 = Weights( url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ diff --git a/torchvision/prototype/models/quantization/inception.py b/torchvision/prototype/models/quantization/inception.py index 4949d0f4b2d..01c310e3271 100644 --- a/torchvision/prototype/models/quantization/inception.py +++ b/torchvision/prototype/models/quantization/inception.py @@ -9,7 +9,7 @@ _replace_relu, quantize_model, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ..inception import InceptionV3Weights @@ -22,8 +22,8 @@ ] -class QuantizedInceptionV3Weights(Weights): - ImageNet1K_FBGEMM_TFV1 = WeightEntry( +class QuantizedInceptionV3Weights(WeightsEnum): + ImageNet1K_FBGEMM_TFV1 = Weights( url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth", transforms=partial(ImageNetEval, crop_size=299, resize_size=342), meta={ diff --git a/torchvision/prototype/models/quantization/mobilenetv2.py b/torchvision/prototype/models/quantization/mobilenetv2.py index 8c5c7fbf5b2..5d41de6a9ff 100644 --- a/torchvision/prototype/models/quantization/mobilenetv2.py +++ b/torchvision/prototype/models/quantization/mobilenetv2.py @@ -10,7 +10,7 @@ _replace_relu, quantize_model, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ..mobilenetv2 import MobileNetV2Weights @@ -23,8 +23,8 @@ ] -class QuantizedMobileNetV2Weights(Weights): - ImageNet1K_QNNPACK_RefV1 = WeightEntry( +class QuantizedMobileNetV2Weights(WeightsEnum): + ImageNet1K_QNNPACK_RefV1 = Weights( url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ diff --git a/torchvision/prototype/models/quantization/mobilenetv3.py b/torchvision/prototype/models/quantization/mobilenetv3.py index 
dd293fd4080..d62fb2b08a9 100644 --- a/torchvision/prototype/models/quantization/mobilenetv3.py +++ b/torchvision/prototype/models/quantization/mobilenetv3.py @@ -11,7 +11,7 @@ QuantizableMobileNetV3, _replace_relu, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ..mobilenetv3 import MobileNetV3LargeWeights, _mobilenet_v3_conf @@ -27,7 +27,7 @@ def _mobilenet_v3_model( inverted_residual_setting: List[InvertedResidualConfig], last_channel: int, - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, quantize: bool, **kwargs: Any, @@ -56,8 +56,8 @@ def _mobilenet_v3_model( return model -class QuantizedMobileNetV3LargeWeights(Weights): - ImageNet1K_QNNPACK_RefV1 = WeightEntry( +class QuantizedMobileNetV3LargeWeights(WeightsEnum): + ImageNet1K_QNNPACK_RefV1 = Weights( url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ diff --git a/torchvision/prototype/models/quantization/resnet.py b/torchvision/prototype/models/quantization/resnet.py index 744b57d706a..095cfdbbd78 100644 --- a/torchvision/prototype/models/quantization/resnet.py +++ b/torchvision/prototype/models/quantization/resnet.py @@ -11,7 +11,7 @@ _replace_relu, quantize_model, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ..resnet import ResNet18Weights, ResNet50Weights, ResNeXt101_32x8dWeights @@ -31,7 +31,7 @@ def _resnet( block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]], layers: List[int], - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, quantize: bool, **kwargs: Any, @@ -63,8 +63,8 @@ def _resnet( } -class QuantizedResNet18Weights(Weights): - ImageNet1K_FBGEMM_RefV1 = WeightEntry( +class QuantizedResNet18Weights(WeightsEnum): + ImageNet1K_FBGEMM_RefV1 = Weights( url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -77,8 +77,8 @@ class QuantizedResNet18Weights(Weights): ) -class QuantizedResNet50Weights(Weights): - ImageNet1K_FBGEMM_RefV1 = WeightEntry( +class QuantizedResNet50Weights(WeightsEnum): + ImageNet1K_FBGEMM_RefV1 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -89,7 +89,7 @@ class QuantizedResNet50Weights(Weights): }, default=False, ) - ImageNet1K_FBGEMM_RefV2 = WeightEntry( + ImageNet1K_FBGEMM_RefV2 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -102,8 +102,8 @@ class QuantizedResNet50Weights(Weights): ) -class QuantizedResNeXt101_32x8dWeights(Weights): - ImageNet1K_FBGEMM_RefV1 = WeightEntry( +class QuantizedResNeXt101_32x8dWeights(WeightsEnum): + ImageNet1K_FBGEMM_RefV1 = Weights( url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -114,7 +114,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights): }, default=False, ) - ImageNet1K_FBGEMM_RefV2 = WeightEntry( + ImageNet1K_FBGEMM_RefV2 = Weights( 
url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ diff --git a/torchvision/prototype/models/quantization/shufflenetv2.py b/torchvision/prototype/models/quantization/shufflenetv2.py index d9aade357b0..dfeb3799eb4 100644 --- a/torchvision/prototype/models/quantization/shufflenetv2.py +++ b/torchvision/prototype/models/quantization/shufflenetv2.py @@ -9,7 +9,7 @@ _replace_relu, quantize_model, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param from ..shufflenetv2 import ShuffleNetV2_x0_5Weights, ShuffleNetV2_x1_0Weights @@ -27,7 +27,7 @@ def _shufflenetv2( stages_repeats: List[int], stages_out_channels: List[int], - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, quantize: bool, **kwargs: Any, @@ -59,8 +59,8 @@ def _shufflenetv2( } -class QuantizedShuffleNetV2_x0_5Weights(Weights): - ImageNet1K_FBGEMM_Community = WeightEntry( +class QuantizedShuffleNetV2_x0_5Weights(WeightsEnum): + ImageNet1K_FBGEMM_Community = Weights( url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -73,8 +73,8 @@ class QuantizedShuffleNetV2_x0_5Weights(Weights): ) -class QuantizedShuffleNetV2_x1_0Weights(Weights): - ImageNet1K_FBGEMM_Community = WeightEntry( +class QuantizedShuffleNetV2_x1_0Weights(WeightsEnum): + ImageNet1K_FBGEMM_Community = Weights( url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ diff --git a/torchvision/prototype/models/regnet.py b/torchvision/prototype/models/regnet.py index b89d882cc0b..35e060afffc 100644 --- a/torchvision/prototype/models/regnet.py +++ b/torchvision/prototype/models/regnet.py @@ -6,7 +6,7 @@ from torchvision.transforms.functional import InterpolationMode from ...models.regnet import RegNet, BlockParams -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param @@ -48,7 +48,7 @@ def _regnet( block_params: BlockParams, - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> RegNet: @@ -64,8 +64,8 @@ def _regnet( return model -class RegNet_y_400mfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_y_400mfWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -78,8 +78,8 @@ class RegNet_y_400mfWeights(Weights): ) -class RegNet_y_800mfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_y_800mfWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -92,8 +92,8 @@ class RegNet_y_800mfWeights(Weights): ) -class RegNet_y_1_6gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_y_1_6gfWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -106,8 +106,8 @@ class RegNet_y_1_6gfWeights(Weights): ) -class 
RegNet_y_3_2gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_y_3_2gfWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -120,8 +120,8 @@ class RegNet_y_3_2gfWeights(Weights): ) -class RegNet_y_8gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_y_8gfWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -134,8 +134,8 @@ class RegNet_y_8gfWeights(Weights): ) -class RegNet_y_16gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_y_16gfWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -148,8 +148,8 @@ class RegNet_y_16gfWeights(Weights): ) -class RegNet_y_32gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_y_32gfWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -162,8 +162,8 @@ class RegNet_y_32gfWeights(Weights): ) -class RegNet_x_400mfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_x_400mfWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -176,8 +176,8 @@ class RegNet_x_400mfWeights(Weights): ) -class RegNet_x_800mfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_x_800mfWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -190,8 +190,8 @@ class RegNet_x_800mfWeights(Weights): ) -class RegNet_x_1_6gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_x_1_6gfWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -204,8 +204,8 @@ class RegNet_x_1_6gfWeights(Weights): ) -class RegNet_x_3_2gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_x_3_2gfWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -218,8 +218,8 @@ class RegNet_x_3_2gfWeights(Weights): ) -class RegNet_x_8gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_x_8gfWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -232,8 +232,8 @@ class RegNet_x_8gfWeights(Weights): ) -class RegNet_x_16gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_x_16gfWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -246,8 +246,8 @@ class RegNet_x_16gfWeights(Weights): ) -class RegNet_x_32gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_x_32gfWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", transforms=partial(ImageNetEval, 
crop_size=224), meta={ diff --git a/torchvision/prototype/models/resnet.py b/torchvision/prototype/models/resnet.py index 0ff4436bf63..b2e1d6c264b 100644 --- a/torchvision/prototype/models/resnet.py +++ b/torchvision/prototype/models/resnet.py @@ -5,7 +5,7 @@ from torchvision.transforms.functional import InterpolationMode from ...models.resnet import BasicBlock, Bottleneck, ResNet -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param @@ -36,7 +36,7 @@ def _resnet( block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> ResNet: @@ -54,8 +54,8 @@ def _resnet( _COMMON_META = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR} -class ResNet18Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class ResNet18Weights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/resnet18-f37072fd.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -68,8 +68,8 @@ class ResNet18Weights(Weights): ) -class ResNet34Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class ResNet34Weights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/resnet34-b627a593.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -82,8 +82,8 @@ class ResNet34Weights(Weights): ) -class ResNet50Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class ResNet50Weights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/resnet50-0676ba61.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -94,7 +94,7 @@ class ResNet50Weights(Weights): }, default=False, ) - ImageNet1K_RefV2 = WeightEntry( + ImageNet1K_RefV2 = Weights( url="https://download.pytorch.org/models/resnet50-f46c3f97.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -107,8 +107,8 @@ class ResNet50Weights(Weights): ) -class ResNet101Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class ResNet101Weights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/resnet101-63fe2227.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -119,7 +119,7 @@ class ResNet101Weights(Weights): }, default=False, ) - ImageNet1K_RefV2 = WeightEntry( + ImageNet1K_RefV2 = Weights( url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -132,8 +132,8 @@ class ResNet101Weights(Weights): ) -class ResNet152Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class ResNet152Weights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/resnet152-394f9c45.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -144,7 +144,7 @@ class ResNet152Weights(Weights): }, default=False, ) - ImageNet1K_RefV2 = WeightEntry( + ImageNet1K_RefV2 = Weights( url="https://download.pytorch.org/models/resnet152-f82ba261.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -157,8 +157,8 @@ class ResNet152Weights(Weights): ) -class ResNeXt50_32x4dWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class ResNeXt50_32x4dWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", 
transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -169,7 +169,7 @@ class ResNeXt50_32x4dWeights(Weights): }, default=False, ) - ImageNet1K_RefV2 = WeightEntry( + ImageNet1K_RefV2 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -182,8 +182,8 @@ class ResNeXt50_32x4dWeights(Weights): ) -class ResNeXt101_32x8dWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class ResNeXt101_32x8dWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -194,7 +194,7 @@ class ResNeXt101_32x8dWeights(Weights): }, default=False, ) - ImageNet1K_RefV2 = WeightEntry( + ImageNet1K_RefV2 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -207,8 +207,8 @@ class ResNeXt101_32x8dWeights(Weights): ) -class WideResNet50_2Weights(Weights): - ImageNet1K_Community = WeightEntry( +class WideResNet50_2Weights(WeightsEnum): + ImageNet1K_Community = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -219,7 +219,7 @@ class WideResNet50_2Weights(Weights): }, default=False, ) - ImageNet1K_RefV2 = WeightEntry( + ImageNet1K_RefV2 = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -232,8 +232,8 @@ class WideResNet50_2Weights(Weights): ) -class WideResNet101_2Weights(Weights): - ImageNet1K_Community = WeightEntry( +class WideResNet101_2Weights(WeightsEnum): + ImageNet1K_Community = Weights( url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -244,7 +244,7 @@ class WideResNet101_2Weights(Weights): }, default=False, ) - ImageNet1K_RefV2 = WeightEntry( + ImageNet1K_RefV2 = Weights( url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py index 93ff73f0032..c0e5c2c9e94 100644 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ b/torchvision/prototype/models/segmentation/deeplabv3.py @@ -5,7 +5,7 @@ from torchvision.transforms.functional import InterpolationMode from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large @@ -30,8 +30,8 @@ } -class DeepLabV3ResNet50Weights(Weights): - CocoWithVocLabels_RefV1 = WeightEntry( +class DeepLabV3ResNet50Weights(WeightsEnum): + CocoWithVocLabels_RefV1 = Weights( url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", transforms=partial(VocEval, resize_size=520), meta={ @@ -44,8 +44,8 @@ class DeepLabV3ResNet50Weights(Weights): ) -class DeepLabV3ResNet101Weights(Weights): - CocoWithVocLabels_RefV1 = WeightEntry( +class DeepLabV3ResNet101Weights(WeightsEnum): + CocoWithVocLabels_RefV1 = Weights( 
url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", transforms=partial(VocEval, resize_size=520), meta={ @@ -58,8 +58,8 @@ class DeepLabV3ResNet101Weights(Weights): ) -class DeepLabV3MobileNetV3LargeWeights(Weights): - CocoWithVocLabels_RefV1 = WeightEntry( +class DeepLabV3MobileNetV3LargeWeights(WeightsEnum): + CocoWithVocLabels_RefV1 = Weights( url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", transforms=partial(VocEval, resize_size=520), meta={ diff --git a/torchvision/prototype/models/segmentation/fcn.py b/torchvision/prototype/models/segmentation/fcn.py index 138552d3aa0..0b04f0ed825 100644 --- a/torchvision/prototype/models/segmentation/fcn.py +++ b/torchvision/prototype/models/segmentation/fcn.py @@ -5,7 +5,7 @@ from torchvision.transforms.functional import InterpolationMode from ....models.segmentation.fcn import FCN, _fcn_resnet -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101 @@ -20,8 +20,8 @@ } -class FCNResNet50Weights(Weights): - CocoWithVocLabels_RefV1 = WeightEntry( +class FCNResNet50Weights(WeightsEnum): + CocoWithVocLabels_RefV1 = Weights( url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", transforms=partial(VocEval, resize_size=520), meta={ @@ -34,8 +34,8 @@ class FCNResNet50Weights(Weights): ) -class FCNResNet101Weights(Weights): - CocoWithVocLabels_RefV1 = WeightEntry( +class FCNResNet101Weights(WeightsEnum): + CocoWithVocLabels_RefV1 = Weights( url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", transforms=partial(VocEval, resize_size=520), meta={ diff --git a/torchvision/prototype/models/segmentation/lraspp.py b/torchvision/prototype/models/segmentation/lraspp.py index 260093efc0f..0891a71528d 100644 --- a/torchvision/prototype/models/segmentation/lraspp.py +++ b/torchvision/prototype/models/segmentation/lraspp.py @@ -5,7 +5,7 @@ from torchvision.transforms.functional import InterpolationMode from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3 -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large @@ -14,8 +14,8 @@ __all__ = ["LRASPP", "LRASPPMobileNetV3LargeWeights", "lraspp_mobilenet_v3_large"] -class LRASPPMobileNetV3LargeWeights(Weights): - CocoWithVocLabels_RefV1 = WeightEntry( +class LRASPPMobileNetV3LargeWeights(WeightsEnum): + CocoWithVocLabels_RefV1 = Weights( url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", transforms=partial(VocEval, resize_size=520), meta={ diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py index 006f4027766..ff332c18ebc 100644 --- a/torchvision/prototype/models/shufflenetv2.py +++ b/torchvision/prototype/models/shufflenetv2.py @@ -5,7 +5,7 @@ from torchvision.transforms.functional import InterpolationMode from ...models.shufflenetv2 import ShuffleNetV2 -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param @@ -24,7 +24,7 @@ def 
_shufflenetv2( - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, *args: Any, **kwargs: Any, @@ -48,8 +48,8 @@ def _shufflenetv2( } -class ShuffleNetV2_x0_5Weights(Weights): - ImageNet1K_Community = WeightEntry( +class ShuffleNetV2_x0_5Weights(WeightsEnum): + ImageNet1K_Community = Weights( url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -61,8 +61,8 @@ class ShuffleNetV2_x0_5Weights(Weights): ) -class ShuffleNetV2_x1_0Weights(Weights): - ImageNet1K_Community = WeightEntry( +class ShuffleNetV2_x1_0Weights(WeightsEnum): + ImageNet1K_Community = Weights( url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -74,11 +74,11 @@ class ShuffleNetV2_x1_0Weights(Weights): ) -class ShuffleNetV2_x1_5Weights(Weights): +class ShuffleNetV2_x1_5Weights(WeightsEnum): pass -class ShuffleNetV2_x2_0Weights(Weights): +class ShuffleNetV2_x2_0Weights(WeightsEnum): pass diff --git a/torchvision/prototype/models/squeezenet.py b/torchvision/prototype/models/squeezenet.py index 34144d15213..606818951ca 100644 --- a/torchvision/prototype/models/squeezenet.py +++ b/torchvision/prototype/models/squeezenet.py @@ -5,7 +5,7 @@ from torchvision.transforms.functional import InterpolationMode from ...models.squeezenet import SqueezeNet -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param @@ -21,8 +21,8 @@ } -class SqueezeNet1_0Weights(Weights): - ImageNet1K_Community = WeightEntry( +class SqueezeNet1_0Weights(WeightsEnum): + ImageNet1K_Community = Weights( url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -34,8 +34,8 @@ class SqueezeNet1_0Weights(Weights): ) -class SqueezeNet1_1Weights(Weights): - ImageNet1K_Community = WeightEntry( +class SqueezeNet1_1Weights(WeightsEnum): + ImageNet1K_Community = Weights( url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ diff --git a/torchvision/prototype/models/vgg.py b/torchvision/prototype/models/vgg.py index 8c5a58cb592..f7ce3a9e604 100644 --- a/torchvision/prototype/models/vgg.py +++ b/torchvision/prototype/models/vgg.py @@ -5,7 +5,7 @@ from torchvision.transforms.functional import InterpolationMode from ...models.vgg import VGG, make_layers, cfgs -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param @@ -31,7 +31,7 @@ ] -def _vgg(cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, **kwargs: Any) -> VGG: +def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG: if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) @@ -48,8 +48,8 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, } -class VGG11Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class VGG11Weights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/vgg11-8a719046.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ 
-61,8 +61,8 @@ class VGG11Weights(Weights): ) -class VGG11BNWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class VGG11BNWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -74,8 +74,8 @@ class VGG11BNWeights(Weights): ) -class VGG13Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class VGG13Weights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/vgg13-19584684.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -87,8 +87,8 @@ class VGG13Weights(Weights): ) -class VGG13BNWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class VGG13BNWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -100,8 +100,8 @@ class VGG13BNWeights(Weights): ) -class VGG16Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class VGG16Weights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/vgg16-397923af.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -114,7 +114,7 @@ class VGG16Weights(Weights): # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the # same input standardization method as the paper. Only the `features` weights have proper values, those on the # `classifier` module are filled with nans. - ImageNet1K_Features = WeightEntry( + ImageNet1K_Features = Weights( url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth", transforms=partial( ImageNetEval, crop_size=224, mean=(0.48235, 0.45882, 0.40784), std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0) @@ -131,8 +131,8 @@ class VGG16Weights(Weights): ) -class VGG16BNWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class VGG16BNWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -144,8 +144,8 @@ class VGG16BNWeights(Weights): ) -class VGG19Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class VGG19Weights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -157,8 +157,8 @@ class VGG19Weights(Weights): ) -class VGG19BNWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class VGG19BNWeights(WeightsEnum): + ImageNet1K_RefV1 = Weights( url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ diff --git a/torchvision/prototype/models/video/resnet.py b/torchvision/prototype/models/video/resnet.py index 2dcfdfba2c0..6faabed493f 100644 --- a/torchvision/prototype/models/video/resnet.py +++ b/torchvision/prototype/models/video/resnet.py @@ -15,7 +15,7 @@ R2Plus1dStem, VideoResNet, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _KINETICS400_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param @@ -36,7 +36,7 @@ def _video_resnet( conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]], layers: List[int], stem: Callable[..., nn.Module], - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> VideoResNet: @@ -59,8 +59,8 @@ def _video_resnet( } -class 
R3D_18Weights(Weights): - Kinetics400_RefV1 = WeightEntry( +class R3D_18Weights(WeightsEnum): + Kinetics400_RefV1 = Weights( url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), meta={ @@ -72,8 +72,8 @@ class R3D_18Weights(Weights): ) -class MC3_18Weights(Weights): - Kinetics400_RefV1 = WeightEntry( +class MC3_18Weights(WeightsEnum): + Kinetics400_RefV1 = Weights( url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), meta={ @@ -85,8 +85,8 @@ class MC3_18Weights(Weights): ) -class R2Plus1D_18Weights(Weights): - Kinetics400_RefV1 = WeightEntry( +class R2Plus1D_18Weights(WeightsEnum): + Kinetics400_RefV1 = Weights( url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), meta={ diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py index 987f3af1bb4..1f8673b6588 100644 --- a/torchvision/prototype/models/vision_transformer.py +++ b/torchvision/prototype/models/vision_transformer.py @@ -11,7 +11,7 @@ import torch.nn as nn from torch import Tensor -from ._api import Weights +from ._api import WeightsEnum from ._utils import _deprecated_param, _deprecated_positional @@ -231,22 +231,22 @@ def forward(self, x: torch.Tensor): return x -class VisionTransformer_B_16Weights(Weights): +class VisionTransformer_B_16Weights(WeightsEnum): # If a default model is added here the corresponding changes need to be done in vit_b_16 pass -class VisionTransformer_B_32Weights(Weights): +class VisionTransformer_B_32Weights(WeightsEnum): # If a default model is added here the corresponding changes need to be done in vit_b_32 pass -class VisionTransformer_L_16Weights(Weights): +class VisionTransformer_L_16Weights(WeightsEnum): # If a default model is added here the corresponding changes need to be done in vit_l_16 pass -class VisionTransformer_L_32Weights(Weights): +class VisionTransformer_L_32Weights(WeightsEnum): # If a default model is added here the corresponding changes need to be done in vit_l_32 pass @@ -257,7 +257,7 @@ def _vision_transformer( num_heads: int, hidden_dim: int, mlp_dim: int, - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> VisionTransformer: From acef04d6ab400317cb276d2504d70702f1ef99fa Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 29 Nov 2021 12:56:10 +0000 Subject: [PATCH 2/4] Make enum values follow the naming convention `_V1`, `_V2` etc --- torchvision/prototype/models/alexnet.py | 4 +- torchvision/prototype/models/densenet.py | 16 +++--- .../prototype/models/detection/faster_rcnn.py | 20 +++---- .../models/detection/keypoint_rcnn.py | 12 ++-- .../prototype/models/detection/mask_rcnn.py | 8 +-- .../prototype/models/detection/retinanet.py | 8 +-- torchvision/prototype/models/detection/ssd.py | 4 +- .../prototype/models/detection/ssdlite.py | 6 +- torchvision/prototype/models/efficientnet.py | 32 +++++------ torchvision/prototype/models/googlenet.py | 4 +- torchvision/prototype/models/inception.py | 4 +- torchvision/prototype/models/mnasnet.py | 8 +-- torchvision/prototype/models/mobilenetv2.py | 4 +- torchvision/prototype/models/mobilenetv3.py | 10 ++-- .../models/quantization/googlenet.py | 8 +-- .../models/quantization/inception.py | 6 +- .../models/quantization/mobilenetv2.py | 6 +- 
.../models/quantization/mobilenetv3.py | 8 +-- .../prototype/models/quantization/resnet.py | 32 +++++------ .../models/quantization/shufflenetv2.py | 16 +++--- torchvision/prototype/models/regnet.py | 56 +++++++++---------- torchvision/prototype/models/resnet.py | 50 ++++++++--------- .../models/segmentation/deeplabv3.py | 18 +++--- .../prototype/models/segmentation/fcn.py | 12 ++-- .../prototype/models/segmentation/lraspp.py | 8 +-- torchvision/prototype/models/shufflenetv2.py | 8 +-- torchvision/prototype/models/squeezenet.py | 8 +-- torchvision/prototype/models/vgg.py | 32 +++++------ torchvision/prototype/models/video/resnet.py | 12 ++-- 29 files changed, 205 insertions(+), 215 deletions(-) diff --git a/torchvision/prototype/models/alexnet.py b/torchvision/prototype/models/alexnet.py index d6df916050d..dea77a0b4f4 100644 --- a/torchvision/prototype/models/alexnet.py +++ b/torchvision/prototype/models/alexnet.py @@ -14,7 +14,7 @@ class AlexNetWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -33,7 +33,7 @@ def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **k if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNetWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNetWeights.ImageNet1K_V1) weights = AlexNetWeights.verify(weights) if weights is not None: diff --git a/torchvision/prototype/models/densenet.py b/torchvision/prototype/models/densenet.py index 18c091ff786..781f996c5b2 100644 --- a/torchvision/prototype/models/densenet.py +++ b/torchvision/prototype/models/densenet.py @@ -72,7 +72,7 @@ def _densenet( class DenseNet121Weights(WeightsEnum): - ImageNet1K_Community = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/densenet121-a639ec97.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -85,7 +85,7 @@ class DenseNet121Weights(WeightsEnum): class DenseNet161Weights(WeightsEnum): - ImageNet1K_Community = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/densenet161-8d451a50.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -98,7 +98,7 @@ class DenseNet161Weights(WeightsEnum): class DenseNet169Weights(WeightsEnum): - ImageNet1K_Community = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -111,7 +111,7 @@ class DenseNet169Weights(WeightsEnum): class DenseNet201Weights(WeightsEnum): - ImageNet1K_Community = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/densenet201-c1103571.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -127,7 +127,7 @@ def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = T if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121Weights.ImageNet1K_Community) + weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121Weights.ImageNet1K_V1) weights = DenseNet121Weights.verify(weights) return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs) @@ -137,7 +137,7 @@ def 
densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = T if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161Weights.ImageNet1K_Community) + weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161Weights.ImageNet1K_V1) weights = DenseNet161Weights.verify(weights) return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs) @@ -147,7 +147,7 @@ def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = T if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169Weights.ImageNet1K_Community) + weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169Weights.ImageNet1K_V1) weights = DenseNet169Weights.verify(weights) return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs) @@ -157,7 +157,7 @@ def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = T if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201Weights.ImageNet1K_Community) + weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201Weights.ImageNet1K_V1) weights = DenseNet201Weights.verify(weights) return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index 241466b8eba..66497a06ea3 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -37,7 +37,7 @@ class FasterRCNNResNet50FPNWeights(WeightsEnum): - Coco_RefV1 = Weights( + Coco_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", transforms=CocoEval, meta={ @@ -50,7 +50,7 @@ class FasterRCNNResNet50FPNWeights(WeightsEnum): class FasterRCNNMobileNetV3LargeFPNWeights(WeightsEnum): - Coco_RefV1 = Weights( + Coco_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", transforms=CocoEval, meta={ @@ -63,7 +63,7 @@ class FasterRCNNMobileNetV3LargeFPNWeights(WeightsEnum): class FasterRCNNMobileNetV3Large320FPNWeights(WeightsEnum): - Coco_RefV1 = Weights( + Coco_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", transforms=CocoEval, meta={ @@ -86,13 +86,13 @@ def fasterrcnn_resnet50_fpn( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNResNet50FPNWeights.Coco_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNResNet50FPNWeights.Coco_V1) weights = FasterRCNNResNet50FPNWeights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_V1 ) weights_backbone = 
ResNet50Weights.verify(weights_backbone) @@ -112,7 +112,7 @@ def fasterrcnn_resnet50_fpn( if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == FasterRCNNResNet50FPNWeights.Coco_RefV1: + if weights == FasterRCNNResNet50FPNWeights.Coco_V1: overwrite_eps(model, 0.0) return model @@ -169,13 +169,13 @@ def fasterrcnn_mobilenet_v3_large_fpn( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNMobileNetV3LargeFPNWeights.Coco_V1) weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_V1 ) weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) @@ -205,13 +205,13 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNMobileNetV3Large320FPNWeights.Coco_V1) weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_V1 ) weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) diff --git a/torchvision/prototype/models/detection/keypoint_rcnn.py b/torchvision/prototype/models/detection/keypoint_rcnn.py index 36fd955f73a..57cfb6522cc 100644 --- a/torchvision/prototype/models/detection/keypoint_rcnn.py +++ b/torchvision/prototype/models/detection/keypoint_rcnn.py @@ -26,7 +26,7 @@ class KeypointRCNNResNet50FPNWeights(WeightsEnum): - Coco_RefV1_Legacy = Weights( + Coco_Legacy = Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", transforms=CocoEval, meta={ @@ -37,7 +37,7 @@ class KeypointRCNNResNet50FPNWeights(WeightsEnum): }, default=False, ) - Coco_RefV1 = Weights( + Coco_V1 = Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", transforms=CocoEval, meta={ @@ -62,9 +62,9 @@ def keypointrcnn_resnet50_fpn( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - default_value = KeypointRCNNResNet50FPNWeights.Coco_RefV1 + default_value = KeypointRCNNResNet50FPNWeights.Coco_V1 if kwargs["pretrained"] == "legacy": - default_value = KeypointRCNNResNet50FPNWeights.Coco_RefV1_Legacy + default_value = KeypointRCNNResNet50FPNWeights.Coco_Legacy kwargs["pretrained"] = True weights = _deprecated_param(kwargs, "pretrained", "weights", 
default_value) weights = KeypointRCNNResNet50FPNWeights.verify(weights) @@ -72,7 +72,7 @@ def keypointrcnn_resnet50_fpn( _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_V1 ) weights_backbone = ResNet50Weights.verify(weights_backbone) @@ -96,7 +96,7 @@ def keypointrcnn_resnet50_fpn( if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == KeypointRCNNResNet50FPNWeights.Coco_RefV1: + if weights == KeypointRCNNResNet50FPNWeights.Coco_V1: overwrite_eps(model, 0.0) return model diff --git a/torchvision/prototype/models/detection/mask_rcnn.py b/torchvision/prototype/models/detection/mask_rcnn.py index be6742e4879..5d9113565d9 100644 --- a/torchvision/prototype/models/detection/mask_rcnn.py +++ b/torchvision/prototype/models/detection/mask_rcnn.py @@ -24,7 +24,7 @@ class MaskRCNNResNet50FPNWeights(WeightsEnum): - Coco_RefV1 = Weights( + Coco_V1 = Weights( url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", transforms=CocoEval, meta={ @@ -49,13 +49,13 @@ def maskrcnn_resnet50_fpn( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MaskRCNNResNet50FPNWeights.Coco_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", MaskRCNNResNet50FPNWeights.Coco_V1) weights = MaskRCNNResNet50FPNWeights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_V1 ) weights_backbone = ResNet50Weights.verify(weights_backbone) @@ -75,7 +75,7 @@ def maskrcnn_resnet50_fpn( if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == MaskRCNNResNet50FPNWeights.Coco_RefV1: + if weights == MaskRCNNResNet50FPNWeights.Coco_V1: overwrite_eps(model, 0.0) return model diff --git a/torchvision/prototype/models/detection/retinanet.py b/torchvision/prototype/models/detection/retinanet.py index 490a9b9e0c2..e74cdc7c1e9 100644 --- a/torchvision/prototype/models/detection/retinanet.py +++ b/torchvision/prototype/models/detection/retinanet.py @@ -25,7 +25,7 @@ class RetinaNetResNet50FPNWeights(WeightsEnum): - Coco_RefV1 = Weights( + Coco_V1 = Weights( url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", transforms=CocoEval, meta={ @@ -49,13 +49,13 @@ def retinanet_resnet50_fpn( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RetinaNetResNet50FPNWeights.Coco_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", RetinaNetResNet50FPNWeights.Coco_V1) weights = RetinaNetResNet50FPNWeights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: 
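
A minimal usage sketch for the detection hunks above (not part of the diff; the `torchvision.prototype.models` import path is assumed). It shows the renamed entries in the new-style call, where detector and backbone weights are passed separately, and the deprecated `pretrained` path that the builders still translate for now:

    from torchvision.prototype import models

    # New-style call: detector weights and backbone weights are separate enum entries.
    model = models.detection.fasterrcnn_resnet50_fpn(
        weights=models.detection.FasterRCNNResNet50FPNWeights.Coco_V1,
        weights_backbone=models.ResNet50Weights.ImageNet1K_V1,
    )

    # Deprecated path during the transition: pretrained=True maps to Coco_V1, and
    # keypointrcnn_resnet50_fpn additionally accepts pretrained="legacy" for Coco_Legacy.
    legacy = models.detection.keypointrcnn_resnet50_fpn(pretrained="legacy")
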
weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_V1 ) weights_backbone = ResNet50Weights.verify(weights_backbone) @@ -78,7 +78,7 @@ def retinanet_resnet50_fpn( if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == RetinaNetResNet50FPNWeights.Coco_RefV1: + if weights == RetinaNetResNet50FPNWeights.Coco_V1: overwrite_eps(model, 0.0) return model diff --git a/torchvision/prototype/models/detection/ssd.py b/torchvision/prototype/models/detection/ssd.py index d2a51508366..3f1c3351912 100644 --- a/torchvision/prototype/models/detection/ssd.py +++ b/torchvision/prototype/models/detection/ssd.py @@ -23,7 +23,7 @@ class SSD300VGG16Weights(WeightsEnum): - Coco_RefV1 = Weights( + Coco_V1 = Weights( url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", transforms=CocoEval, meta={ @@ -48,7 +48,7 @@ def ssd300_vgg16( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", SSD300VGG16Weights.Coco_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", SSD300VGG16Weights.Coco_V1) weights = SSD300VGG16Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) diff --git a/torchvision/prototype/models/detection/ssdlite.py b/torchvision/prototype/models/detection/ssdlite.py index 5f597ba0951..d391dfc67bf 100644 --- a/torchvision/prototype/models/detection/ssdlite.py +++ b/torchvision/prototype/models/detection/ssdlite.py @@ -28,7 +28,7 @@ class SSDlite320MobileNetV3LargeFPNWeights(WeightsEnum): - Coco_RefV1 = Weights( + Coco_V1 = Weights( url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth", transforms=CocoEval, meta={ @@ -54,13 +54,13 @@ def ssdlite320_mobilenet_v3_large( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", SSDlite320MobileNetV3LargeFPNWeights.Coco_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", SSDlite320MobileNetV3LargeFPNWeights.Coco_V1) weights = SSDlite320MobileNetV3LargeFPNWeights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_V1 ) weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) diff --git a/torchvision/prototype/models/efficientnet.py b/torchvision/prototype/models/efficientnet.py index 16bf77cbe6f..a538d2272f5 100644 --- a/torchvision/prototype/models/efficientnet.py +++ b/torchvision/prototype/models/efficientnet.py @@ -70,7 +70,7 @@ def _efficientnet( class EfficientNetB0Weights(WeightsEnum): - ImageNet1K_TimmV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC), meta={ @@ 
-84,7 +84,7 @@ class EfficientNetB0Weights(WeightsEnum): class EfficientNetB1Weights(WeightsEnum): - ImageNet1K_TimmV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC), meta={ @@ -98,7 +98,7 @@ class EfficientNetB1Weights(WeightsEnum): class EfficientNetB2Weights(WeightsEnum): - ImageNet1K_TimmV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC), meta={ @@ -112,7 +112,7 @@ class EfficientNetB2Weights(WeightsEnum): class EfficientNetB3Weights(WeightsEnum): - ImageNet1K_TimmV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC), meta={ @@ -126,7 +126,7 @@ class EfficientNetB3Weights(WeightsEnum): class EfficientNetB4Weights(WeightsEnum): - ImageNet1K_TimmV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC), meta={ @@ -140,7 +140,7 @@ class EfficientNetB4Weights(WeightsEnum): class EfficientNetB5Weights(WeightsEnum): - ImageNet1K_TFV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC), meta={ @@ -154,7 +154,7 @@ class EfficientNetB5Weights(WeightsEnum): class EfficientNetB6Weights(WeightsEnum): - ImageNet1K_TFV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC), meta={ @@ -168,7 +168,7 @@ class EfficientNetB6Weights(WeightsEnum): class EfficientNetB7Weights(WeightsEnum): - ImageNet1K_TFV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC), meta={ @@ -187,7 +187,7 @@ def efficientnet_b0( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB0Weights.ImageNet1K_TimmV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB0Weights.ImageNet1K_V1) weights = EfficientNetB0Weights.verify(weights) return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs) @@ -199,7 +199,7 @@ def efficientnet_b1( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB1Weights.ImageNet1K_TimmV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB1Weights.ImageNet1K_V1) weights = EfficientNetB1Weights.verify(weights) return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, 
**kwargs) @@ -211,7 +211,7 @@ def efficientnet_b2( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB2Weights.ImageNet1K_TimmV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB2Weights.ImageNet1K_V1) weights = EfficientNetB2Weights.verify(weights) return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs) @@ -223,7 +223,7 @@ def efficientnet_b3( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB3Weights.ImageNet1K_TimmV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB3Weights.ImageNet1K_V1) weights = EfficientNetB3Weights.verify(weights) return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs) @@ -235,7 +235,7 @@ def efficientnet_b4( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB4Weights.ImageNet1K_TimmV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB4Weights.ImageNet1K_V1) weights = EfficientNetB4Weights.verify(weights) return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs) @@ -247,7 +247,7 @@ def efficientnet_b5( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB5Weights.ImageNet1K_TFV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB5Weights.ImageNet1K_V1) weights = EfficientNetB5Weights.verify(weights) return _efficientnet( @@ -267,7 +267,7 @@ def efficientnet_b6( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB6Weights.ImageNet1K_TFV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB6Weights.ImageNet1K_V1) weights = EfficientNetB6Weights.verify(weights) return _efficientnet( @@ -287,7 +287,7 @@ def efficientnet_b7( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB7Weights.ImageNet1K_TFV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB7Weights.ImageNet1K_V1) weights = EfficientNetB7Weights.verify(weights) return _efficientnet( diff --git a/torchvision/prototype/models/googlenet.py b/torchvision/prototype/models/googlenet.py index 260259af903..bc0eb151b94 100644 --- a/torchvision/prototype/models/googlenet.py +++ b/torchvision/prototype/models/googlenet.py @@ -15,7 +15,7 @@ class GoogLeNetWeights(WeightsEnum): - ImageNet1K_TFV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/googlenet-1378be20.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -34,7 +34,7 @@ def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, if type(weights) == bool and weights: 
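
For the EfficientNet hunks, a short sketch of the effect of the rename (import path assumed, snippet not part of the diff): the origin of the ported checkpoint (timm or TF) no longer appears in the entry name, so the first released checkpoint of every architecture is addressed the same way:

    from torchvision.prototype import models

    # ImageNet1K_TimmV1 / ImageNet1K_TFV1 both collapse to ImageNet1K_V1.
    weights = models.EfficientNetB0Weights.ImageNet1K_V1
    model = models.efficientnet_b0(weights=weights)
    # The recorded metadata stays attached to the entry.
    print(weights.meta["acc@1"], weights.meta["acc@5"])
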
_deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", GoogLeNetWeights.ImageNet1K_TFV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", GoogLeNetWeights.ImageNet1K_V1) weights = GoogLeNetWeights.verify(weights) original_aux_logits = kwargs.get("aux_logits", False) diff --git a/torchvision/prototype/models/inception.py b/torchvision/prototype/models/inception.py index d82d459b1d4..d19caa5895a 100644 --- a/torchvision/prototype/models/inception.py +++ b/torchvision/prototype/models/inception.py @@ -14,7 +14,7 @@ class InceptionV3Weights(WeightsEnum): - ImageNet1K_TFV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", transforms=partial(ImageNetEval, crop_size=299, resize_size=342), meta={ @@ -33,7 +33,7 @@ def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool = if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", InceptionV3Weights.ImageNet1K_TFV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", InceptionV3Weights.ImageNet1K_V1) weights = InceptionV3Weights.verify(weights) original_aux_logits = kwargs.get("aux_logits", True) diff --git a/torchvision/prototype/models/mnasnet.py b/torchvision/prototype/models/mnasnet.py index a109b2ef50b..96784682dee 100644 --- a/torchvision/prototype/models/mnasnet.py +++ b/torchvision/prototype/models/mnasnet.py @@ -32,7 +32,7 @@ class MNASNet0_5Weights(WeightsEnum): - ImageNet1K_Community = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -50,7 +50,7 @@ class MNASNet0_75Weights(WeightsEnum): class MNASNet1_0Weights(WeightsEnum): - ImageNet1K_Community = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -83,7 +83,7 @@ def mnasnet0_5(weights: Optional[MNASNet0_5Weights] = None, progress: bool = Tru if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet0_5Weights.ImageNet1K_Community) + weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet0_5Weights.ImageNet1K_V1) weights = MNASNet0_5Weights.verify(weights) return _mnasnet(0.5, weights, progress, **kwargs) @@ -103,7 +103,7 @@ def mnasnet1_0(weights: Optional[MNASNet1_0Weights] = None, progress: bool = Tru if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet1_0Weights.ImageNet1K_Community) + weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet1_0Weights.ImageNet1K_V1) weights = MNASNet1_0Weights.verify(weights) return _mnasnet(1.0, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/mobilenetv2.py b/torchvision/prototype/models/mobilenetv2.py index d0649f7a5bc..8d4520bfd5a 100644 --- a/torchvision/prototype/models/mobilenetv2.py +++ b/torchvision/prototype/models/mobilenetv2.py @@ -14,7 +14,7 @@ class MobileNetV2Weights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = 
Weights( url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -33,7 +33,7 @@ def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool = if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV2Weights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV2Weights.ImageNet1K_V1) weights = MobileNetV2Weights.verify(weights) if weights is not None: diff --git a/torchvision/prototype/models/mobilenetv3.py b/torchvision/prototype/models/mobilenetv3.py index 6b2c312d5cb..d69c7a11705 100644 --- a/torchvision/prototype/models/mobilenetv3.py +++ b/torchvision/prototype/models/mobilenetv3.py @@ -45,7 +45,7 @@ def _mobilenet_v3( class MobileNetV3LargeWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -56,7 +56,7 @@ class MobileNetV3LargeWeights(WeightsEnum): }, default=False, ) - ImageNet1K_RefV2 = Weights( + ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -70,7 +70,7 @@ class MobileNetV3LargeWeights(WeightsEnum): class MobileNetV3SmallWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -89,7 +89,7 @@ def mobilenet_v3_large( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV3LargeWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV3LargeWeights.ImageNet1K_V1) weights = MobileNetV3LargeWeights.verify(weights) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) @@ -102,7 +102,7 @@ def mobilenet_v3_small( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV3SmallWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV3SmallWeights.ImageNet1K_V1) weights = MobileNetV3SmallWeights.verify(weights) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs) diff --git a/torchvision/prototype/models/quantization/googlenet.py b/torchvision/prototype/models/quantization/googlenet.py index 8340606b2d5..18c980d2d02 100644 --- a/torchvision/prototype/models/quantization/googlenet.py +++ b/torchvision/prototype/models/quantization/googlenet.py @@ -24,7 +24,7 @@ class QuantizedGoogLeNetWeights(WeightsEnum): - ImageNet1K_FBGEMM_TFV1 = Weights( + ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -34,7 +34,7 @@ class QuantizedGoogLeNetWeights(WeightsEnum): "backend": "fbgemm", "quantization": "ptq", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", - "unquantized": 
GoogLeNetWeights.ImageNet1K_TFV1, + "unquantized": GoogLeNetWeights.ImageNet1K_V1, "acc@1": 69.826, "acc@5": 89.404, }, @@ -51,9 +51,7 @@ def googlenet( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - default_value = ( - QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1 if quantize else GoogLeNetWeights.ImageNet1K_TFV1 - ) + default_value = QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_V1 if quantize else GoogLeNetWeights.ImageNet1K_V1 weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: weights = QuantizedGoogLeNetWeights.verify(weights) diff --git a/torchvision/prototype/models/quantization/inception.py b/torchvision/prototype/models/quantization/inception.py index 01c310e3271..ea0db4d76ed 100644 --- a/torchvision/prototype/models/quantization/inception.py +++ b/torchvision/prototype/models/quantization/inception.py @@ -23,7 +23,7 @@ class QuantizedInceptionV3Weights(WeightsEnum): - ImageNet1K_FBGEMM_TFV1 = Weights( + ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth", transforms=partial(ImageNetEval, crop_size=299, resize_size=342), meta={ @@ -33,7 +33,7 @@ class QuantizedInceptionV3Weights(WeightsEnum): "backend": "fbgemm", "quantization": "ptq", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", - "unquantized": InceptionV3Weights.ImageNet1K_TFV1, + "unquantized": InceptionV3Weights.ImageNet1K_V1, "acc@1": 77.176, "acc@5": 93.354, }, @@ -51,7 +51,7 @@ def inception_v3( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: default_value = ( - QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1 if quantize else InceptionV3Weights.ImageNet1K_TFV1 + QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_V1 if quantize else InceptionV3Weights.ImageNet1K_V1 ) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: diff --git a/torchvision/prototype/models/quantization/mobilenetv2.py b/torchvision/prototype/models/quantization/mobilenetv2.py index 5d41de6a9ff..3c56d1f91ed 100644 --- a/torchvision/prototype/models/quantization/mobilenetv2.py +++ b/torchvision/prototype/models/quantization/mobilenetv2.py @@ -24,7 +24,7 @@ class QuantizedMobileNetV2Weights(WeightsEnum): - ImageNet1K_QNNPACK_RefV1 = Weights( + ImageNet1K_QNNPACK_V1 = Weights( url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -34,7 +34,7 @@ class QuantizedMobileNetV2Weights(WeightsEnum): "backend": "qnnpack", "quantization": "qat", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2", - "unquantized": MobileNetV2Weights.ImageNet1K_RefV1, + "unquantized": MobileNetV2Weights.ImageNet1K_V1, "acc@1": 71.658, "acc@5": 90.150, }, @@ -52,7 +52,7 @@ def mobilenet_v2( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: default_value = ( - QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1 if quantize else MobileNetV2Weights.ImageNet1K_RefV1 + QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_V1 if quantize else MobileNetV2Weights.ImageNet1K_V1 ) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: diff --git 
a/torchvision/prototype/models/quantization/mobilenetv3.py b/torchvision/prototype/models/quantization/mobilenetv3.py index d62fb2b08a9..127a96c8547 100644 --- a/torchvision/prototype/models/quantization/mobilenetv3.py +++ b/torchvision/prototype/models/quantization/mobilenetv3.py @@ -57,7 +57,7 @@ def _mobilenet_v3_model( class QuantizedMobileNetV3LargeWeights(WeightsEnum): - ImageNet1K_QNNPACK_RefV1 = Weights( + ImageNet1K_QNNPACK_V1 = Weights( url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -67,7 +67,7 @@ class QuantizedMobileNetV3LargeWeights(WeightsEnum): "backend": "qnnpack", "quantization": "qat", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3", - "unquantized": MobileNetV3LargeWeights.ImageNet1K_RefV1, + "unquantized": MobileNetV3LargeWeights.ImageNet1K_V1, "acc@1": 73.004, "acc@5": 90.858, }, @@ -85,9 +85,9 @@ def mobilenet_v3_large( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: default_value = ( - QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_RefV1 + QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_V1 if quantize - else MobileNetV3LargeWeights.ImageNet1K_RefV1 + else MobileNetV3LargeWeights.ImageNet1K_V1 ) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: diff --git a/torchvision/prototype/models/quantization/resnet.py b/torchvision/prototype/models/quantization/resnet.py index 095cfdbbd78..d4739ee5497 100644 --- a/torchvision/prototype/models/quantization/resnet.py +++ b/torchvision/prototype/models/quantization/resnet.py @@ -64,12 +64,12 @@ def _resnet( class QuantizedResNet18Weights(WeightsEnum): - ImageNet1K_FBGEMM_RefV1 = Weights( + ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ **_COMMON_META, - "unquantized": ResNet18Weights.ImageNet1K_RefV1, + "unquantized": ResNet18Weights.ImageNet1K_V1, "acc@1": 69.494, "acc@5": 88.882, }, @@ -78,23 +78,23 @@ class QuantizedResNet18Weights(WeightsEnum): class QuantizedResNet50Weights(WeightsEnum): - ImageNet1K_FBGEMM_RefV1 = Weights( + ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ **_COMMON_META, - "unquantized": ResNet50Weights.ImageNet1K_RefV1, + "unquantized": ResNet50Weights.ImageNet1K_V1, "acc@1": 75.920, "acc@5": 92.814, }, default=False, ) - ImageNet1K_FBGEMM_RefV2 = Weights( + ImageNet1K_FBGEMM_V2 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, - "unquantized": ResNet50Weights.ImageNet1K_RefV2, + "unquantized": ResNet50Weights.ImageNet1K_V2, "acc@1": 80.282, "acc@5": 94.976, }, @@ -103,23 +103,23 @@ class QuantizedResNet50Weights(WeightsEnum): class QuantizedResNeXt101_32x8dWeights(WeightsEnum): - ImageNet1K_FBGEMM_RefV1 = Weights( + ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ **_COMMON_META, - "unquantized": ResNeXt101_32x8dWeights.ImageNet1K_RefV1, + "unquantized": ResNeXt101_32x8dWeights.ImageNet1K_V1, "acc@1": 78.986, "acc@5": 94.480, }, default=False, ) 
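
The quantized enums follow the same `_V1`/`_V2` scheme with the backend kept in the name, and their metadata points back to the float checkpoint they were derived from. A brief sketch (module path and the `quantize` flag assumed to match the builders shown in these hunks; not part of the diff):

    from torchvision.prototype import models

    # Backend (FBGEMM/QNNPACK) stays in the entry name; meta["unquantized"] records
    # the float weights the quantized checkpoint was produced from.
    qweights = models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_V2
    qmodel = models.quantization.resnet50(weights=qweights, quantize=True)
    assert qweights.meta["unquantized"] is models.ResNet50Weights.ImageNet1K_V2
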
- ImageNet1K_FBGEMM_RefV2 = Weights( + ImageNet1K_FBGEMM_V2 = Weights( url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, - "unquantized": ResNeXt101_32x8dWeights.ImageNet1K_RefV2, + "unquantized": ResNeXt101_32x8dWeights.ImageNet1K_V2, "acc@1": 82.574, "acc@5": 96.132, }, @@ -136,9 +136,7 @@ def resnet18( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - default_value = ( - QuantizedResNet18Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet18Weights.ImageNet1K_RefV1 - ) + default_value = QuantizedResNet18Weights.ImageNet1K_FBGEMM_V1 if quantize else ResNet18Weights.ImageNet1K_V1 weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: weights = QuantizedResNet18Weights.verify(weights) @@ -157,9 +155,7 @@ def resnet50( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - default_value = ( - QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet50Weights.ImageNet1K_RefV1 - ) + default_value = QuantizedResNet50Weights.ImageNet1K_FBGEMM_V1 if quantize else ResNet50Weights.ImageNet1K_V1 weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: weights = QuantizedResNet50Weights.verify(weights) @@ -179,9 +175,7 @@ def resnext101_32x8d( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: default_value = ( - QuantizedResNeXt101_32x8dWeights.ImageNet1K_FBGEMM_RefV1 - if quantize - else ResNeXt101_32x8dWeights.ImageNet1K_RefV1 + QuantizedResNeXt101_32x8dWeights.ImageNet1K_FBGEMM_V1 if quantize else ResNeXt101_32x8dWeights.ImageNet1K_V1 ) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: diff --git a/torchvision/prototype/models/quantization/shufflenetv2.py b/torchvision/prototype/models/quantization/shufflenetv2.py index dfeb3799eb4..5c5bed5fee6 100644 --- a/torchvision/prototype/models/quantization/shufflenetv2.py +++ b/torchvision/prototype/models/quantization/shufflenetv2.py @@ -60,12 +60,12 @@ def _shufflenetv2( class QuantizedShuffleNetV2_x0_5Weights(WeightsEnum): - ImageNet1K_FBGEMM_Community = Weights( + ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ **_COMMON_META, - "unquantized": ShuffleNetV2_x0_5Weights.ImageNet1K_Community, + "unquantized": ShuffleNetV2_x0_5Weights.ImageNet1K_V1, "acc@1": 57.972, "acc@5": 79.780, }, @@ -74,12 +74,12 @@ class QuantizedShuffleNetV2_x0_5Weights(WeightsEnum): class QuantizedShuffleNetV2_x1_0Weights(WeightsEnum): - ImageNet1K_FBGEMM_Community = Weights( + ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ **_COMMON_META, - "unquantized": ShuffleNetV2_x1_0Weights.ImageNet1K_Community, + "unquantized": ShuffleNetV2_x1_0Weights.ImageNet1K_V1, "acc@1": 68.360, "acc@5": 87.582, }, @@ -97,9 +97,9 @@ def shufflenet_v2_x0_5( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: default_value = ( - QuantizedShuffleNetV2_x0_5Weights.ImageNet1K_FBGEMM_Community + 
QuantizedShuffleNetV2_x0_5Weights.ImageNet1K_FBGEMM_V1 if quantize - else ShuffleNetV2_x0_5Weights.ImageNet1K_Community + else ShuffleNetV2_x0_5Weights.ImageNet1K_V1 ) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: @@ -120,9 +120,9 @@ def shufflenet_v2_x1_0( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: default_value = ( - QuantizedShuffleNetV2_x1_0Weights.ImageNet1K_FBGEMM_Community + QuantizedShuffleNetV2_x1_0Weights.ImageNet1K_FBGEMM_V1 if quantize - else ShuffleNetV2_x1_0Weights.ImageNet1K_Community + else ShuffleNetV2_x1_0Weights.ImageNet1K_V1 ) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: diff --git a/torchvision/prototype/models/regnet.py b/torchvision/prototype/models/regnet.py index 35e060afffc..436433ca185 100644 --- a/torchvision/prototype/models/regnet.py +++ b/torchvision/prototype/models/regnet.py @@ -65,7 +65,7 @@ def _regnet( class RegNet_y_400mfWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -79,7 +79,7 @@ class RegNet_y_400mfWeights(WeightsEnum): class RegNet_y_800mfWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -93,7 +93,7 @@ class RegNet_y_800mfWeights(WeightsEnum): class RegNet_y_1_6gfWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -107,7 +107,7 @@ class RegNet_y_1_6gfWeights(WeightsEnum): class RegNet_y_3_2gfWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -121,7 +121,7 @@ class RegNet_y_3_2gfWeights(WeightsEnum): class RegNet_y_8gfWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -135,7 +135,7 @@ class RegNet_y_8gfWeights(WeightsEnum): class RegNet_y_16gfWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -149,7 +149,7 @@ class RegNet_y_16gfWeights(WeightsEnum): class RegNet_y_32gfWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -163,7 +163,7 @@ class RegNet_y_32gfWeights(WeightsEnum): class RegNet_x_400mfWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -177,7 +177,7 @@ class RegNet_x_400mfWeights(WeightsEnum): class RegNet_x_800mfWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ 
-191,7 +191,7 @@ class RegNet_x_800mfWeights(WeightsEnum): class RegNet_x_1_6gfWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -205,7 +205,7 @@ class RegNet_x_1_6gfWeights(WeightsEnum): class RegNet_x_3_2gfWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -219,7 +219,7 @@ class RegNet_x_3_2gfWeights(WeightsEnum): class RegNet_x_8gfWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -233,7 +233,7 @@ class RegNet_x_8gfWeights(WeightsEnum): class RegNet_x_16gfWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -247,7 +247,7 @@ class RegNet_x_16gfWeights(WeightsEnum): class RegNet_x_32gfWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -264,7 +264,7 @@ def regnet_y_400mf(weights: Optional[RegNet_y_400mfWeights] = None, progress: bo if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_400mfWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_400mfWeights.ImageNet1K_V1) weights = RegNet_y_400mfWeights.verify(weights) params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs) @@ -275,7 +275,7 @@ def regnet_y_800mf(weights: Optional[RegNet_y_800mfWeights] = None, progress: bo if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_800mfWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_800mfWeights.ImageNet1K_V1) weights = RegNet_y_800mfWeights.verify(weights) params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs) @@ -286,7 +286,7 @@ def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bo if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_1_6gfWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_1_6gfWeights.ImageNet1K_V1) weights = RegNet_y_1_6gfWeights.verify(weights) params = BlockParams.from_init_params( @@ -299,7 +299,7 @@ def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bo if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_3_2gfWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", 
RegNet_y_3_2gfWeights.ImageNet1K_V1) weights = RegNet_y_3_2gfWeights.verify(weights) params = BlockParams.from_init_params( @@ -312,7 +312,7 @@ def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool = if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_8gfWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_8gfWeights.ImageNet1K_V1) weights = RegNet_y_8gfWeights.verify(weights) params = BlockParams.from_init_params( @@ -325,7 +325,7 @@ def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_16gfWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_16gfWeights.ImageNet1K_V1) weights = RegNet_y_16gfWeights.verify(weights) params = BlockParams.from_init_params( @@ -338,7 +338,7 @@ def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_32gfWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_32gfWeights.ImageNet1K_V1) weights = RegNet_y_32gfWeights.verify(weights) params = BlockParams.from_init_params( @@ -351,7 +351,7 @@ def regnet_x_400mf(weights: Optional[RegNet_x_400mfWeights] = None, progress: bo if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_400mfWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_400mfWeights.ImageNet1K_V1) weights = RegNet_x_400mfWeights.verify(weights) params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs) @@ -362,7 +362,7 @@ def regnet_x_800mf(weights: Optional[RegNet_x_800mfWeights] = None, progress: bo if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_800mfWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_800mfWeights.ImageNet1K_V1) weights = RegNet_x_800mfWeights.verify(weights) params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs) @@ -373,7 +373,7 @@ def regnet_x_1_6gf(weights: Optional[RegNet_x_1_6gfWeights] = None, progress: bo if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_1_6gfWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_1_6gfWeights.ImageNet1K_V1) weights = RegNet_x_1_6gfWeights.verify(weights) params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs) @@ -384,7 +384,7 @@ def regnet_x_3_2gf(weights: Optional[RegNet_x_3_2gfWeights] = None, progress: bo if type(weights) == bool and weights: 
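
Where an improved training recipe exists, it becomes a `_V2` entry inside the same enum rather than a separately named class, and its evaluation preset travels with it. A sketch using the MobileNetV3-Large entries renamed above (import path assumed, and it is assumed the bound `ImageNetEval` preset needs no further arguments; not part of the diff):

    from torchvision.prototype import models

    # The second recipe lives next to the first one as ImageNet1K_V2.
    v2 = models.MobileNetV3LargeWeights.ImageNet1K_V2
    model = models.mobilenet_v3_large(weights=v2)
    # Evaluation transforms are bundled with the entry (crop_size=224, resize_size=232).
    preprocess = v2.transforms()
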
_deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_3_2gfWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_3_2gfWeights.ImageNet1K_V1) weights = RegNet_x_3_2gfWeights.verify(weights) params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs) @@ -395,7 +395,7 @@ def regnet_x_8gf(weights: Optional[RegNet_x_8gfWeights] = None, progress: bool = if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_8gfWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_8gfWeights.ImageNet1K_V1) weights = RegNet_x_8gfWeights.verify(weights) params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs) @@ -406,7 +406,7 @@ def regnet_x_16gf(weights: Optional[RegNet_x_16gfWeights] = None, progress: bool if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_16gfWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_16gfWeights.ImageNet1K_V1) weights = RegNet_x_16gfWeights.verify(weights) params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs) @@ -417,7 +417,7 @@ def regnet_x_32gf(weights: Optional[RegNet_x_32gfWeights] = None, progress: bool if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_32gfWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_32gfWeights.ImageNet1K_V1) weights = RegNet_x_32gfWeights.verify(weights) params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) diff --git a/torchvision/prototype/models/resnet.py b/torchvision/prototype/models/resnet.py index b2e1d6c264b..47117f383d3 100644 --- a/torchvision/prototype/models/resnet.py +++ b/torchvision/prototype/models/resnet.py @@ -55,7 +55,7 @@ def _resnet( class ResNet18Weights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnet18-f37072fd.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -69,7 +69,7 @@ class ResNet18Weights(WeightsEnum): class ResNet34Weights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnet34-b627a593.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -83,7 +83,7 @@ class ResNet34Weights(WeightsEnum): class ResNet50Weights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnet50-0676ba61.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -94,7 +94,7 @@ class ResNet50Weights(WeightsEnum): }, default=False, ) - ImageNet1K_RefV2 = Weights( + ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnet50-f46c3f97.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -108,7 +108,7 @@ class ResNet50Weights(WeightsEnum): class ResNet101Weights(WeightsEnum): - 
ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnet101-63fe2227.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -119,7 +119,7 @@ class ResNet101Weights(WeightsEnum): }, default=False, ) - ImageNet1K_RefV2 = Weights( + ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -133,7 +133,7 @@ class ResNet101Weights(WeightsEnum): class ResNet152Weights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnet152-394f9c45.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -144,7 +144,7 @@ class ResNet152Weights(WeightsEnum): }, default=False, ) - ImageNet1K_RefV2 = Weights( + ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnet152-f82ba261.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -158,7 +158,7 @@ class ResNet152Weights(WeightsEnum): class ResNeXt50_32x4dWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -169,7 +169,7 @@ class ResNeXt50_32x4dWeights(WeightsEnum): }, default=False, ) - ImageNet1K_RefV2 = Weights( + ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -183,7 +183,7 @@ class ResNeXt50_32x4dWeights(WeightsEnum): class ResNeXt101_32x8dWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -194,7 +194,7 @@ class ResNeXt101_32x8dWeights(WeightsEnum): }, default=False, ) - ImageNet1K_RefV2 = Weights( + ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -208,7 +208,7 @@ class ResNeXt101_32x8dWeights(WeightsEnum): class WideResNet50_2Weights(WeightsEnum): - ImageNet1K_Community = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -219,7 +219,7 @@ class WideResNet50_2Weights(WeightsEnum): }, default=False, ) - ImageNet1K_RefV2 = Weights( + ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -233,7 +233,7 @@ class WideResNet50_2Weights(WeightsEnum): class WideResNet101_2Weights(WeightsEnum): - ImageNet1K_Community = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -244,7 +244,7 @@ class WideResNet101_2Weights(WeightsEnum): }, default=False, ) - ImageNet1K_RefV2 = Weights( + ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -261,7 +261,7 @@ def resnet18(weights: Optional[ResNet18Weights] = None, progress: bool = True, * if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = 
_deprecated_param(kwargs, "pretrained", "weights", ResNet18Weights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet18Weights.ImageNet1K_V1) weights = ResNet18Weights.verify(weights) return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) @@ -271,7 +271,7 @@ def resnet34(weights: Optional[ResNet34Weights] = None, progress: bool = True, * if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet34Weights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet34Weights.ImageNet1K_V1) weights = ResNet34Weights.verify(weights) return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) @@ -281,7 +281,7 @@ def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, * if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet50Weights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet50Weights.ImageNet1K_V1) weights = ResNet50Weights.verify(weights) return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) @@ -291,7 +291,7 @@ def resnet101(weights: Optional[ResNet101Weights] = None, progress: bool = True, if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet101Weights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet101Weights.ImageNet1K_V1) weights = ResNet101Weights.verify(weights) return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) @@ -301,7 +301,7 @@ def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True, if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet152Weights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet152Weights.ImageNet1K_V1) weights = ResNet152Weights.verify(weights) return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs) @@ -311,7 +311,7 @@ def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt50_32x4dWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt50_32x4dWeights.ImageNet1K_V1) weights = ResNeXt50_32x4dWeights.verify(weights) _ovewrite_named_param(kwargs, "groups", 32) @@ -323,7 +323,7 @@ def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt101_32x8dWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt101_32x8dWeights.ImageNet1K_V1) weights = ResNeXt101_32x8dWeights.verify(weights) _ovewrite_named_param(kwargs, "groups", 32) @@ -335,7 +335,7 @@ def wide_resnet50_2(weights: 
Optional[WideResNet50_2Weights] = None, progress: b if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", WideResNet50_2Weights.ImageNet1K_Community) + weights = _deprecated_param(kwargs, "pretrained", "weights", WideResNet50_2Weights.ImageNet1K_V1) weights = WideResNet50_2Weights.verify(weights) _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) @@ -346,7 +346,7 @@ def wide_resnet101_2(weights: Optional[WideResNet101_2Weights] = None, progress: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", WideResNet101_2Weights.ImageNet1K_Community) + weights = _deprecated_param(kwargs, "pretrained", "weights", WideResNet101_2Weights.ImageNet1K_V1) weights = WideResNet101_2Weights.verify(weights) _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py index c0e5c2c9e94..80e0285a293 100644 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ b/torchvision/prototype/models/segmentation/deeplabv3.py @@ -31,7 +31,7 @@ class DeepLabV3ResNet50Weights(WeightsEnum): - CocoWithVocLabels_RefV1 = Weights( + CocoWithVocLabels_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", transforms=partial(VocEval, resize_size=520), meta={ @@ -45,7 +45,7 @@ class DeepLabV3ResNet50Weights(WeightsEnum): class DeepLabV3ResNet101Weights(WeightsEnum): - CocoWithVocLabels_RefV1 = Weights( + CocoWithVocLabels_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", transforms=partial(VocEval, resize_size=520), meta={ @@ -59,7 +59,7 @@ class DeepLabV3ResNet101Weights(WeightsEnum): class DeepLabV3MobileNetV3LargeWeights(WeightsEnum): - CocoWithVocLabels_RefV1 = Weights( + CocoWithVocLabels_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", transforms=partial(VocEval, resize_size=520), meta={ @@ -83,13 +83,13 @@ def deeplabv3_resnet50( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3ResNet50Weights.CocoWithVocLabels_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3ResNet50Weights.CocoWithVocLabels_V1) weights = DeepLabV3ResNet50Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_V1 ) weights_backbone = ResNet50Weights.verify(weights_backbone) @@ -120,13 +120,13 @@ def deeplabv3_resnet101( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3ResNet101Weights.CocoWithVocLabels_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3ResNet101Weights.CocoWithVocLabels_V1) weights = 
DeepLabV3ResNet101Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet101Weights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet101Weights.ImageNet1K_V1 ) weights_backbone = ResNet101Weights.verify(weights_backbone) @@ -158,14 +158,14 @@ def deeplabv3_mobilenet_v3_large( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param( - kwargs, "pretrained", "weights", DeepLabV3MobileNetV3LargeWeights.CocoWithVocLabels_RefV1 + kwargs, "pretrained", "weights", DeepLabV3MobileNetV3LargeWeights.CocoWithVocLabels_V1 ) weights = DeepLabV3MobileNetV3LargeWeights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_V1 ) weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) diff --git a/torchvision/prototype/models/segmentation/fcn.py b/torchvision/prototype/models/segmentation/fcn.py index 0b04f0ed825..d0444d24293 100644 --- a/torchvision/prototype/models/segmentation/fcn.py +++ b/torchvision/prototype/models/segmentation/fcn.py @@ -21,7 +21,7 @@ class FCNResNet50Weights(WeightsEnum): - CocoWithVocLabels_RefV1 = Weights( + CocoWithVocLabels_V1 = Weights( url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", transforms=partial(VocEval, resize_size=520), meta={ @@ -35,7 +35,7 @@ class FCNResNet50Weights(WeightsEnum): class FCNResNet101Weights(WeightsEnum): - CocoWithVocLabels_RefV1 = Weights( + CocoWithVocLabels_V1 = Weights( url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", transforms=partial(VocEval, resize_size=520), meta={ @@ -59,13 +59,13 @@ def fcn_resnet50( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", FCNResNet50Weights.CocoWithVocLabels_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", FCNResNet50Weights.CocoWithVocLabels_V1) weights = FCNResNet50Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_V1 ) weights_backbone = ResNet50Weights.verify(weights_backbone) @@ -96,13 +96,13 @@ def fcn_resnet101( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", FCNResNet101Weights.CocoWithVocLabels_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", FCNResNet101Weights.CocoWithVocLabels_V1) weights = FCNResNet101Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, 
"pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet101Weights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet101Weights.ImageNet1K_V1 ) weights_backbone = ResNet101Weights.verify(weights_backbone) diff --git a/torchvision/prototype/models/segmentation/lraspp.py b/torchvision/prototype/models/segmentation/lraspp.py index 0891a71528d..1fb989e2f62 100644 --- a/torchvision/prototype/models/segmentation/lraspp.py +++ b/torchvision/prototype/models/segmentation/lraspp.py @@ -15,7 +15,7 @@ class LRASPPMobileNetV3LargeWeights(WeightsEnum): - CocoWithVocLabels_RefV1 = Weights( + CocoWithVocLabels_V1 = Weights( url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", transforms=partial(VocEval, resize_size=520), meta={ @@ -42,15 +42,13 @@ def lraspp_mobilenet_v3_large( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param( - kwargs, "pretrained", "weights", LRASPPMobileNetV3LargeWeights.CocoWithVocLabels_RefV1 - ) + weights = _deprecated_param(kwargs, "pretrained", "weights", LRASPPMobileNetV3LargeWeights.CocoWithVocLabels_V1) weights = LRASPPMobileNetV3LargeWeights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_V1 ) weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py index ff332c18ebc..6a0eb790632 100644 --- a/torchvision/prototype/models/shufflenetv2.py +++ b/torchvision/prototype/models/shufflenetv2.py @@ -49,7 +49,7 @@ def _shufflenetv2( class ShuffleNetV2_x0_5Weights(WeightsEnum): - ImageNet1K_Community = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -62,7 +62,7 @@ class ShuffleNetV2_x0_5Weights(WeightsEnum): class ShuffleNetV2_x1_0Weights(WeightsEnum): - ImageNet1K_Community = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -88,7 +88,7 @@ def shufflenet_v2_x0_5( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNetV2_x0_5Weights.ImageNet1K_Community) + weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNetV2_x0_5Weights.ImageNet1K_V1) weights = ShuffleNetV2_x0_5Weights.verify(weights) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) @@ -100,7 +100,7 @@ def shufflenet_v2_x1_0( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNetV2_x1_0Weights.ImageNet1K_Community) + weights = _deprecated_param(kwargs, "pretrained", "weights", 
ShuffleNetV2_x1_0Weights.ImageNet1K_V1) weights = ShuffleNetV2_x1_0Weights.verify(weights) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) diff --git a/torchvision/prototype/models/squeezenet.py b/torchvision/prototype/models/squeezenet.py index 606818951ca..05cf6f811ff 100644 --- a/torchvision/prototype/models/squeezenet.py +++ b/torchvision/prototype/models/squeezenet.py @@ -22,7 +22,7 @@ class SqueezeNet1_0Weights(WeightsEnum): - ImageNet1K_Community = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -35,7 +35,7 @@ class SqueezeNet1_0Weights(WeightsEnum): class SqueezeNet1_1Weights(WeightsEnum): - ImageNet1K_Community = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -51,7 +51,7 @@ def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_0Weights.ImageNet1K_Community) + weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_0Weights.ImageNet1K_V1) weights = SqueezeNet1_0Weights.verify(weights) if weights is not None: @@ -69,7 +69,7 @@ def squeezenet1_1(weights: Optional[SqueezeNet1_1Weights] = None, progress: bool if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_1Weights.ImageNet1K_Community) + weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_1Weights.ImageNet1K_V1) weights = SqueezeNet1_1Weights.verify(weights) if weights is not None: diff --git a/torchvision/prototype/models/vgg.py b/torchvision/prototype/models/vgg.py index f7ce3a9e604..9fdcbf4053e 100644 --- a/torchvision/prototype/models/vgg.py +++ b/torchvision/prototype/models/vgg.py @@ -49,7 +49,7 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b class VGG11Weights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg11-8a719046.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -62,7 +62,7 @@ class VGG11Weights(WeightsEnum): class VGG11BNWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -75,7 +75,7 @@ class VGG11BNWeights(WeightsEnum): class VGG13Weights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg13-19584684.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -88,7 +88,7 @@ class VGG13Weights(WeightsEnum): class VGG13BNWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -101,7 +101,7 @@ class VGG13BNWeights(WeightsEnum): class VGG16Weights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg16-397923af.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -132,7 +132,7 @@ class 
VGG16Weights(WeightsEnum): class VGG16BNWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -145,7 +145,7 @@ class VGG16BNWeights(WeightsEnum): class VGG19Weights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -158,7 +158,7 @@ class VGG19Weights(WeightsEnum): class VGG19BNWeights(WeightsEnum): - ImageNet1K_RefV1 = Weights( + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -174,7 +174,7 @@ def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwarg if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11Weights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11Weights.ImageNet1K_V1) weights = VGG11Weights.verify(weights) return _vgg("A", False, weights, progress, **kwargs) @@ -184,7 +184,7 @@ def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, ** if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11BNWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11BNWeights.ImageNet1K_V1) weights = VGG11BNWeights.verify(weights) return _vgg("A", True, weights, progress, **kwargs) @@ -194,7 +194,7 @@ def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwarg if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13Weights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13Weights.ImageNet1K_V1) weights = VGG13Weights.verify(weights) return _vgg("B", False, weights, progress, **kwargs) @@ -204,7 +204,7 @@ def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, ** if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13BNWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13BNWeights.ImageNet1K_V1) weights = VGG13BNWeights.verify(weights) return _vgg("B", True, weights, progress, **kwargs) @@ -214,7 +214,7 @@ def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwarg if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16Weights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16Weights.ImageNet1K_V1) weights = VGG16Weights.verify(weights) return _vgg("D", False, weights, progress, **kwargs) @@ -224,7 +224,7 @@ def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, ** if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = 
_deprecated_param(kwargs, "pretrained", "weights", VGG16BNWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16BNWeights.ImageNet1K_V1) weights = VGG16BNWeights.verify(weights) return _vgg("D", True, weights, progress, **kwargs) @@ -234,7 +234,7 @@ def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwarg if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19Weights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19Weights.ImageNet1K_V1) weights = VGG19Weights.verify(weights) return _vgg("E", False, weights, progress, **kwargs) @@ -244,7 +244,7 @@ def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, ** if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19BNWeights.ImageNet1K_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19BNWeights.ImageNet1K_V1) weights = VGG19BNWeights.verify(weights) return _vgg("E", True, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/video/resnet.py b/torchvision/prototype/models/video/resnet.py index 6faabed493f..32fc66d7d2a 100644 --- a/torchvision/prototype/models/video/resnet.py +++ b/torchvision/prototype/models/video/resnet.py @@ -60,7 +60,7 @@ def _video_resnet( class R3D_18Weights(WeightsEnum): - Kinetics400_RefV1 = Weights( + Kinetics400_V1 = Weights( url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), meta={ @@ -73,7 +73,7 @@ class R3D_18Weights(WeightsEnum): class MC3_18Weights(WeightsEnum): - Kinetics400_RefV1 = Weights( + Kinetics400_V1 = Weights( url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), meta={ @@ -86,7 +86,7 @@ class MC3_18Weights(WeightsEnum): class R2Plus1D_18Weights(WeightsEnum): - Kinetics400_RefV1 = Weights( + Kinetics400_V1 = Weights( url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), meta={ @@ -102,7 +102,7 @@ def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwa if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", R3D_18Weights.Kinetics400_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", R3D_18Weights.Kinetics400_V1) weights = R3D_18Weights.verify(weights) return _video_resnet( @@ -120,7 +120,7 @@ def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwa if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MC3_18Weights.Kinetics400_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", MC3_18Weights.Kinetics400_V1) weights = MC3_18Weights.verify(weights) return _video_resnet( @@ -138,7 +138,7 @@ def r2plus1d_18(weights: Optional[R2Plus1D_18Weights] = None, progress: bool = T if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", 
"weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", R2Plus1D_18Weights.Kinetics400_RefV1) + weights = _deprecated_param(kwargs, "pretrained", "weights", R2Plus1D_18Weights.Kinetics400_V1) weights = R2Plus1D_18Weights.verify(weights) return _video_resnet( From 5d3500e11a46423dfcc5947745a591dad799e367 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 29 Nov 2021 14:28:18 +0000 Subject: [PATCH 3/4] Cleanup the Enum class naming conventions. --- test/test_prototype_models.py | 14 +- torchvision/prototype/models/alexnet.py | 10 +- torchvision/prototype/models/densenet.py | 40 ++--- .../prototype/models/detection/faster_rcnn.py | 60 ++++---- .../models/detection/keypoint_rcnn.py | 22 +-- .../prototype/models/detection/mask_rcnn.py | 20 +-- .../prototype/models/detection/retinanet.py | 20 +-- torchvision/prototype/models/detection/ssd.py | 18 +-- .../prototype/models/detection/ssdlite.py | 18 +-- torchvision/prototype/models/efficientnet.py | 80 +++++----- torchvision/prototype/models/googlenet.py | 10 +- torchvision/prototype/models/inception.py | 10 +- torchvision/prototype/models/mnasnet.py | 36 ++--- torchvision/prototype/models/mobilenetv2.py | 10 +- torchvision/prototype/models/mobilenetv3.py | 20 +-- .../models/quantization/googlenet.py | 16 +- .../models/quantization/inception.py | 16 +- .../models/quantization/mobilenetv2.py | 16 +- .../models/quantization/mobilenetv3.py | 18 +-- .../prototype/models/quantization/resnet.py | 50 ++++--- .../models/quantization/shufflenetv2.py | 34 ++--- torchvision/prototype/models/regnet.py | 140 +++++++++--------- torchvision/prototype/models/resnet.py | 94 ++++++------ .../models/segmentation/deeplabv3.py | 52 +++---- .../prototype/models/segmentation/fcn.py | 32 ++-- .../prototype/models/segmentation/lraspp.py | 20 +-- torchvision/prototype/models/shufflenetv2.py | 36 ++--- torchvision/prototype/models/squeezenet.py | 18 +-- torchvision/prototype/models/vgg.py | 80 +++++----- torchvision/prototype/models/video/resnet.py | 30 ++-- .../prototype/models/vision_transformer.py | 48 +++--- 31 files changed, 547 insertions(+), 541 deletions(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 81d7e30863a..14522094b41 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -37,13 +37,17 @@ def get_models_with_module_names(module): @pytest.mark.parametrize( "model_fn, name, weight", [ - (models.resnet50, "ImageNet1K_RefV1", models.ResNet50Weights.ImageNet1K_RefV1), - (models.resnet50, "default", models.ResNet50Weights.ImageNet1K_RefV2), - (models.quantization.resnet50, "default", models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV2), + (models.resnet50, "ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1), + (models.resnet50, "default", models.ResNet50_Weights.ImageNet1K_V2), ( models.quantization.resnet50, - "ImageNet1K_FBGEMM_RefV1", - models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1, + "default", + models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2, + ), + ( + models.quantization.resnet50, + "ImageNet1K_FBGEMM_V1", + models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1, ), ], ) diff --git a/torchvision/prototype/models/alexnet.py b/torchvision/prototype/models/alexnet.py index dea77a0b4f4..b45ca1e7085 100644 --- a/torchvision/prototype/models/alexnet.py +++ b/torchvision/prototype/models/alexnet.py @@ -10,10 +10,10 @@ from ._utils import _deprecated_param, 
_deprecated_positional, _ovewrite_named_param -__all__ = ["AlexNet", "AlexNetWeights", "alexnet"] +__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"] -class AlexNetWeights(WeightsEnum): +class AlexNet_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -29,12 +29,12 @@ class AlexNetWeights(WeightsEnum): ) -def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: +def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNetWeights.ImageNet1K_V1) - weights = AlexNetWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNet_Weights.ImageNet1K_V1) + weights = AlexNet_Weights.verify(weights) if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) diff --git a/torchvision/prototype/models/densenet.py b/torchvision/prototype/models/densenet.py index 781f996c5b2..e779a2cd239 100644 --- a/torchvision/prototype/models/densenet.py +++ b/torchvision/prototype/models/densenet.py @@ -14,10 +14,10 @@ __all__ = [ "DenseNet", - "DenseNet121Weights", - "DenseNet161Weights", - "DenseNet169Weights", - "DenseNet201Weights", + "DenseNet121_Weights", + "DenseNet161_Weights", + "DenseNet169_Weights", + "DenseNet201_Weights", "densenet121", "densenet161", "densenet169", @@ -71,7 +71,7 @@ def _densenet( } -class DenseNet121Weights(WeightsEnum): +class DenseNet121_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/densenet121-a639ec97.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -84,7 +84,7 @@ class DenseNet121Weights(WeightsEnum): ) -class DenseNet161Weights(WeightsEnum): +class DenseNet161_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/densenet161-8d451a50.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -97,7 +97,7 @@ class DenseNet161Weights(WeightsEnum): ) -class DenseNet169Weights(WeightsEnum): +class DenseNet169_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -110,7 +110,7 @@ class DenseNet169Weights(WeightsEnum): ) -class DenseNet201Weights(WeightsEnum): +class DenseNet201_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/densenet201-c1103571.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -123,41 +123,41 @@ class DenseNet201Weights(WeightsEnum): ) -def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: +def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121Weights.ImageNet1K_V1) - weights = DenseNet121Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121_Weights.ImageNet1K_V1) + weights = DenseNet121_Weights.verify(weights) return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs) -def 
densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: +def densenet161(weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161Weights.ImageNet1K_V1) - weights = DenseNet161Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161_Weights.ImageNet1K_V1) + weights = DenseNet161_Weights.verify(weights) return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs) -def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: +def densenet169(weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169Weights.ImageNet1K_V1) - weights = DenseNet169Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169_Weights.ImageNet1K_V1) + weights = DenseNet169_Weights.verify(weights) return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs) -def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: +def densenet201(weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201Weights.ImageNet1K_V1) - weights = DenseNet201Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201_Weights.ImageNet1K_V1) + weights = DenseNet201_Weights.verify(weights) return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index 66497a06ea3..c83aaf222fb 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -15,15 +15,15 @@ from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param -from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large -from ..resnet import ResNet50Weights, resnet50 +from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from ..resnet import ResNet50_Weights, resnet50 __all__ = [ "FasterRCNN", - "FasterRCNNResNet50FPNWeights", - "FasterRCNNMobileNetV3LargeFPNWeights", - "FasterRCNNMobileNetV3Large320FPNWeights", + "FasterRCNN_ResNet50_FPN_Weights", + "FasterRCNN_MobileNet_V3_Large_FPN_Weights", + "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn", "fasterrcnn_mobilenet_v3_large_320_fpn", @@ -36,7 +36,7 @@ } -class FasterRCNNResNet50FPNWeights(WeightsEnum): +class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): Coco_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", transforms=CocoEval, @@ -49,7 +49,7 @@ class FasterRCNNResNet50FPNWeights(WeightsEnum): 
) -class FasterRCNNMobileNetV3LargeFPNWeights(WeightsEnum): +class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): Coco_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", transforms=CocoEval, @@ -62,7 +62,7 @@ class FasterRCNNMobileNetV3LargeFPNWeights(WeightsEnum): ) -class FasterRCNNMobileNetV3Large320FPNWeights(WeightsEnum): +class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): Coco_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", transforms=CocoEval, @@ -76,25 +76,25 @@ class FasterRCNNMobileNetV3Large320FPNWeights(WeightsEnum): def fasterrcnn_resnet50_fpn( - weights: Optional[FasterRCNNResNet50FPNWeights] = None, + weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[ResNet50Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = None, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> FasterRCNN: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNResNet50FPNWeights.Coco_V1) - weights = FasterRCNNResNet50FPNWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNN_ResNet50_FPN_Weights.Coco_V1) + weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_V1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1 ) - weights_backbone = ResNet50Weights.verify(weights_backbone) + weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None @@ -112,17 +112,17 @@ def fasterrcnn_resnet50_fpn( if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == FasterRCNNResNet50FPNWeights.Coco_V1: + if weights == FasterRCNN_ResNet50_FPN_Weights.Coco_V1: overwrite_eps(model, 0.0) return model def _fasterrcnn_mobilenet_v3_large_fpn( - weights: Optional[Union[FasterRCNNMobileNetV3LargeFPNWeights, FasterRCNNMobileNetV3Large320FPNWeights]], + weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]], progress: bool, num_classes: Optional[int], - weights_backbone: Optional[MobileNetV3LargeWeights], + weights_backbone: Optional[MobileNet_V3_Large_Weights], trainable_backbone_layers: Optional[int], **kwargs: Any, ) -> FasterRCNN: @@ -159,25 +159,25 @@ def _fasterrcnn_mobilenet_v3_large_fpn( def fasterrcnn_mobilenet_v3_large_fpn( - weights: Optional[FasterRCNNMobileNetV3LargeFPNWeights] = None, + weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNetV3LargeWeights] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> FasterRCNN: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", 
FasterRCNNMobileNetV3LargeFPNWeights.Coco_V1) - weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNN_MobileNet_V3_Large_FPN_Weights.Coco_V1) + weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_V1 + kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1 ) - weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) defaults = { "rpn_score_thresh": 0.05, @@ -195,25 +195,27 @@ def fasterrcnn_mobilenet_v3_large_fpn( def fasterrcnn_mobilenet_v3_large_320_fpn( - weights: Optional[FasterRCNNMobileNetV3Large320FPNWeights] = None, + weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNetV3LargeWeights] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> FasterRCNN: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNMobileNetV3Large320FPNWeights.Coco_V1) - weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights) + weights = _deprecated_param( + kwargs, "pretrained", "weights", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.Coco_V1 + ) + weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_V1 + kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1 ) - weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) defaults = { "min_size": 320, diff --git a/torchvision/prototype/models/detection/keypoint_rcnn.py b/torchvision/prototype/models/detection/keypoint_rcnn.py index 57cfb6522cc..85250ac2a33 100644 --- a/torchvision/prototype/models/detection/keypoint_rcnn.py +++ b/torchvision/prototype/models/detection/keypoint_rcnn.py @@ -12,12 +12,12 @@ from .._api import WeightsEnum, Weights from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param -from ..resnet import ResNet50Weights, resnet50 +from ..resnet import ResNet50_Weights, resnet50 __all__ = [ "KeypointRCNN", - "KeypointRCNNResNet50FPNWeights", + "KeypointRCNN_ResNet50_FPN_Weights", "keypointrcnn_resnet50_fpn", ] @@ -25,7 +25,7 @@ _COMMON_META = {"categories": _COCO_PERSON_CATEGORIES, "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES} -class KeypointRCNNResNet50FPNWeights(WeightsEnum): +class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): Coco_Legacy = Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", 
transforms=CocoEval, @@ -51,30 +51,30 @@ class KeypointRCNNResNet50FPNWeights(WeightsEnum): def keypointrcnn_resnet50_fpn( - weights: Optional[KeypointRCNNResNet50FPNWeights] = None, + weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, num_keypoints: Optional[int] = None, - weights_backbone: Optional[ResNet50Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = None, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> KeypointRCNN: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - default_value = KeypointRCNNResNet50FPNWeights.Coco_V1 + default_value = KeypointRCNN_ResNet50_FPN_Weights.Coco_V1 if kwargs["pretrained"] == "legacy": - default_value = KeypointRCNNResNet50FPNWeights.Coco_Legacy + default_value = KeypointRCNN_ResNet50_FPN_Weights.Coco_Legacy kwargs["pretrained"] = True weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) - weights = KeypointRCNNResNet50FPNWeights.verify(weights) + weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_V1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1 ) - weights_backbone = ResNet50Weights.verify(weights_backbone) + weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None @@ -96,7 +96,7 @@ def keypointrcnn_resnet50_fpn( if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == KeypointRCNNResNet50FPNWeights.Coco_V1: + if weights == KeypointRCNN_ResNet50_FPN_Weights.Coco_V1: overwrite_eps(model, 0.0) return model diff --git a/torchvision/prototype/models/detection/mask_rcnn.py b/torchvision/prototype/models/detection/mask_rcnn.py index 5d9113565d9..ea7ab4f5fc7 100644 --- a/torchvision/prototype/models/detection/mask_rcnn.py +++ b/torchvision/prototype/models/detection/mask_rcnn.py @@ -13,17 +13,17 @@ from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param -from ..resnet import ResNet50Weights, resnet50 +from ..resnet import ResNet50_Weights, resnet50 __all__ = [ "MaskRCNN", - "MaskRCNNResNet50FPNWeights", + "MaskRCNN_ResNet50_FPN_Weights", "maskrcnn_resnet50_fpn", ] -class MaskRCNNResNet50FPNWeights(WeightsEnum): +class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): Coco_V1 = Weights( url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", transforms=CocoEval, @@ -39,25 +39,25 @@ class MaskRCNNResNet50FPNWeights(WeightsEnum): def maskrcnn_resnet50_fpn( - weights: Optional[MaskRCNNResNet50FPNWeights] = None, + weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[ResNet50Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = None, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> MaskRCNN: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", 
MaskRCNNResNet50FPNWeights.Coco_V1) - weights = MaskRCNNResNet50FPNWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", MaskRCNN_ResNet50_FPN_Weights.Coco_V1) + weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_V1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1 ) - weights_backbone = ResNet50Weights.verify(weights_backbone) + weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None @@ -75,7 +75,7 @@ def maskrcnn_resnet50_fpn( if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == MaskRCNNResNet50FPNWeights.Coco_V1: + if weights == MaskRCNN_ResNet50_FPN_Weights.Coco_V1: overwrite_eps(model, 0.0) return model diff --git a/torchvision/prototype/models/detection/retinanet.py b/torchvision/prototype/models/detection/retinanet.py index e74cdc7c1e9..d442c79d5b6 100644 --- a/torchvision/prototype/models/detection/retinanet.py +++ b/torchvision/prototype/models/detection/retinanet.py @@ -14,17 +14,17 @@ from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param -from ..resnet import ResNet50Weights, resnet50 +from ..resnet import ResNet50_Weights, resnet50 __all__ = [ "RetinaNet", - "RetinaNetResNet50FPNWeights", + "RetinaNet_ResNet50_FPN_Weights", "retinanet_resnet50_fpn", ] -class RetinaNetResNet50FPNWeights(WeightsEnum): +class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): Coco_V1 = Weights( url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", transforms=CocoEval, @@ -39,25 +39,25 @@ class RetinaNetResNet50FPNWeights(WeightsEnum): def retinanet_resnet50_fpn( - weights: Optional[RetinaNetResNet50FPNWeights] = None, + weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[ResNet50Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = None, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> RetinaNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RetinaNetResNet50FPNWeights.Coco_V1) - weights = RetinaNetResNet50FPNWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RetinaNet_ResNet50_FPN_Weights.Coco_V1) + weights = RetinaNet_ResNet50_FPN_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_V1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1 ) - weights_backbone = ResNet50Weights.verify(weights_backbone) + weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None @@ -78,7 +78,7 @@ def retinanet_resnet50_fpn( if weights is not None: 
model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == RetinaNetResNet50FPNWeights.Coco_V1: + if weights == RetinaNet_ResNet50_FPN_Weights.Coco_V1: overwrite_eps(model, 0.0) return model diff --git a/torchvision/prototype/models/detection/ssd.py b/torchvision/prototype/models/detection/ssd.py index 3f1c3351912..37f5c2a6944 100644 --- a/torchvision/prototype/models/detection/ssd.py +++ b/torchvision/prototype/models/detection/ssd.py @@ -13,16 +13,16 @@ from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param -from ..vgg import VGG16Weights, vgg16 +from ..vgg import VGG16_Weights, vgg16 __all__ = [ - "SSD300VGG16Weights", + "SSD300_VGG16_Weights", "ssd300_vgg16", ] -class SSD300VGG16Weights(WeightsEnum): +class SSD300_VGG16_Weights(WeightsEnum): Coco_V1 = Weights( url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", transforms=CocoEval, @@ -38,25 +38,25 @@ class SSD300VGG16Weights(WeightsEnum): def ssd300_vgg16( - weights: Optional[SSD300VGG16Weights] = None, + weights: Optional[SSD300_VGG16_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[VGG16Weights] = None, + weights_backbone: Optional[VGG16_Weights] = None, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> SSD: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", SSD300VGG16Weights.Coco_V1) - weights = SSD300VGG16Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", SSD300_VGG16_Weights.Coco_V1) + weights = SSD300_VGG16_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", VGG16Weights.ImageNet1K_Features + kwargs, "pretrained_backbone", "weights_backbone", VGG16_Weights.ImageNet1K_Features ) - weights_backbone = VGG16Weights.verify(weights_backbone) + weights_backbone = VGG16_Weights.verify(weights_backbone) if "size" in kwargs: warnings.warn("The size of the model is already fixed; ignoring the parameter.") diff --git a/torchvision/prototype/models/detection/ssdlite.py b/torchvision/prototype/models/detection/ssdlite.py index d391dfc67bf..309362f2f11 100644 --- a/torchvision/prototype/models/detection/ssdlite.py +++ b/torchvision/prototype/models/detection/ssdlite.py @@ -18,16 +18,16 @@ from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param -from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large +from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large __all__ = [ - "SSDlite320MobileNetV3LargeFPNWeights", + "SSDLite320_MobileNet_V3_Large_Weights", "ssdlite320_mobilenet_v3_large", ] -class SSDlite320MobileNetV3LargeFPNWeights(WeightsEnum): +class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): Coco_V1 = Weights( url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth", transforms=CocoEval, @@ -43,10 +43,10 @@ class SSDlite320MobileNetV3LargeFPNWeights(WeightsEnum): def ssdlite320_mobilenet_v3_large( - weights: 
Optional[SSDlite320MobileNetV3LargeFPNWeights] = None, + weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNetV3LargeWeights] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, trainable_backbone_layers: Optional[int] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, **kwargs: Any, @@ -54,15 +54,15 @@ def ssdlite320_mobilenet_v3_large( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", SSDlite320MobileNetV3LargeFPNWeights.Coco_V1) - weights = SSDlite320MobileNetV3LargeFPNWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", SSDLite320_MobileNet_V3_Large_Weights.Coco_V1) + weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_V1 + kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1 ) - weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) if "size" in kwargs: warnings.warn("The size of the model is already fixed; ignoring the parameter.") diff --git a/torchvision/prototype/models/efficientnet.py b/torchvision/prototype/models/efficientnet.py index a538d2272f5..74ca6ccc71d 100644 --- a/torchvision/prototype/models/efficientnet.py +++ b/torchvision/prototype/models/efficientnet.py @@ -13,14 +13,14 @@ __all__ = [ "EfficientNet", - "EfficientNetB0Weights", - "EfficientNetB1Weights", - "EfficientNetB2Weights", - "EfficientNetB3Weights", - "EfficientNetB4Weights", - "EfficientNetB5Weights", - "EfficientNetB6Weights", - "EfficientNetB7Weights", + "EfficientNet_B0_Weights", + "EfficientNet_B1_Weights", + "EfficientNet_B2_Weights", + "EfficientNet_B3_Weights", + "EfficientNet_B4_Weights", + "EfficientNet_B5_Weights", + "EfficientNet_B6_Weights", + "EfficientNet_B7_Weights", "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", @@ -69,7 +69,7 @@ def _efficientnet( } -class EfficientNetB0Weights(WeightsEnum): +class EfficientNet_B0_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC), @@ -83,7 +83,7 @@ class EfficientNetB0Weights(WeightsEnum): ) -class EfficientNetB1Weights(WeightsEnum): +class EfficientNet_B1_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC), @@ -97,7 +97,7 @@ class EfficientNetB1Weights(WeightsEnum): ) -class EfficientNetB2Weights(WeightsEnum): +class EfficientNet_B2_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC), @@ -111,7 +111,7 @@ class EfficientNetB2Weights(WeightsEnum): ) -class 
EfficientNetB3Weights(WeightsEnum): +class EfficientNet_B3_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC), @@ -125,7 +125,7 @@ class EfficientNetB3Weights(WeightsEnum): ) -class EfficientNetB4Weights(WeightsEnum): +class EfficientNet_B4_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC), @@ -139,7 +139,7 @@ class EfficientNetB4Weights(WeightsEnum): ) -class EfficientNetB5Weights(WeightsEnum): +class EfficientNet_B5_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC), @@ -153,7 +153,7 @@ class EfficientNetB5Weights(WeightsEnum): ) -class EfficientNetB6Weights(WeightsEnum): +class EfficientNet_B6_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC), @@ -167,7 +167,7 @@ class EfficientNetB6Weights(WeightsEnum): ) -class EfficientNetB7Weights(WeightsEnum): +class EfficientNet_B7_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC), @@ -182,73 +182,73 @@ class EfficientNetB7Weights(WeightsEnum): def efficientnet_b0( - weights: Optional[EfficientNetB0Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB0Weights.ImageNet1K_V1) - weights = EfficientNetB0Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B0_Weights.ImageNet1K_V1) + weights = EfficientNet_B0_Weights.verify(weights) return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs) def efficientnet_b1( - weights: Optional[EfficientNetB1Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB1Weights.ImageNet1K_V1) - weights = EfficientNetB1Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B1_Weights.ImageNet1K_V1) + weights = EfficientNet_B1_Weights.verify(weights) return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs) def efficientnet_b2( - weights: Optional[EfficientNetB2Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any ) -> 
EfficientNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB2Weights.ImageNet1K_V1) - weights = EfficientNetB2Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B2_Weights.ImageNet1K_V1) + weights = EfficientNet_B2_Weights.verify(weights) return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs) def efficientnet_b3( - weights: Optional[EfficientNetB3Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB3Weights.ImageNet1K_V1) - weights = EfficientNetB3Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B3_Weights.ImageNet1K_V1) + weights = EfficientNet_B3_Weights.verify(weights) return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs) def efficientnet_b4( - weights: Optional[EfficientNetB4Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB4Weights.ImageNet1K_V1) - weights = EfficientNetB4Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B4_Weights.ImageNet1K_V1) + weights = EfficientNet_B4_Weights.verify(weights) return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs) def efficientnet_b5( - weights: Optional[EfficientNetB5Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB5Weights.ImageNet1K_V1) - weights = EfficientNetB5Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B5_Weights.ImageNet1K_V1) + weights = EfficientNet_B5_Weights.verify(weights) return _efficientnet( width_mult=1.6, @@ -262,13 +262,13 @@ def efficientnet_b5( def efficientnet_b6( - weights: Optional[EfficientNetB6Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB6Weights.ImageNet1K_V1) - weights = EfficientNetB6Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B6_Weights.ImageNet1K_V1) + weights = EfficientNet_B6_Weights.verify(weights) return _efficientnet( width_mult=1.8, @@ -282,13 +282,13 @@ def efficientnet_b6( def 
efficientnet_b7( - weights: Optional[EfficientNetB7Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB7Weights.ImageNet1K_V1) - weights = EfficientNetB7Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B7_Weights.ImageNet1K_V1) + weights = EfficientNet_B7_Weights.verify(weights) return _efficientnet( width_mult=2.0, diff --git a/torchvision/prototype/models/googlenet.py b/torchvision/prototype/models/googlenet.py index bc0eb151b94..352c49d1a2e 100644 --- a/torchvision/prototype/models/googlenet.py +++ b/torchvision/prototype/models/googlenet.py @@ -11,10 +11,10 @@ from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNetWeights", "googlenet"] +__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"] -class GoogLeNetWeights(WeightsEnum): +class GoogLeNet_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/googlenet-1378be20.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -30,12 +30,12 @@ class GoogLeNetWeights(WeightsEnum): ) -def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: +def googlenet(weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", GoogLeNetWeights.ImageNet1K_V1) - weights = GoogLeNetWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", GoogLeNet_Weights.ImageNet1K_V1) + weights = GoogLeNet_Weights.verify(weights) original_aux_logits = kwargs.get("aux_logits", False) if weights is not None: diff --git a/torchvision/prototype/models/inception.py b/torchvision/prototype/models/inception.py index d19caa5895a..9837b1fc4a6 100644 --- a/torchvision/prototype/models/inception.py +++ b/torchvision/prototype/models/inception.py @@ -10,10 +10,10 @@ from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "InceptionV3Weights", "inception_v3"] +__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"] -class InceptionV3Weights(WeightsEnum): +class Inception_V3_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", transforms=partial(ImageNetEval, crop_size=299, resize_size=342), @@ -29,12 +29,12 @@ class InceptionV3Weights(WeightsEnum): ) -def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: +def inception_v3(weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", InceptionV3Weights.ImageNet1K_V1) - weights = 
InceptionV3Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", Inception_V3_Weights.ImageNet1K_V1) + weights = Inception_V3_Weights.verify(weights) original_aux_logits = kwargs.get("aux_logits", True) if weights is not None: diff --git a/torchvision/prototype/models/mnasnet.py b/torchvision/prototype/models/mnasnet.py index 96784682dee..73aaea0beca 100644 --- a/torchvision/prototype/models/mnasnet.py +++ b/torchvision/prototype/models/mnasnet.py @@ -12,10 +12,10 @@ __all__ = [ "MNASNet", - "MNASNet0_5Weights", - "MNASNet0_75Weights", - "MNASNet1_0Weights", - "MNASNet1_3Weights", + "MNASNet0_5_Weights", + "MNASNet0_75_Weights", + "MNASNet1_0_Weights", + "MNASNet1_3_Weights", "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", @@ -31,7 +31,7 @@ } -class MNASNet0_5Weights(WeightsEnum): +class MNASNet0_5_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -44,12 +44,12 @@ class MNASNet0_5Weights(WeightsEnum): ) -class MNASNet0_75Weights(WeightsEnum): +class MNASNet0_75_Weights(WeightsEnum): # If a default model is added here the corresponding changes need to be done in mnasnet0_75 pass -class MNASNet1_0Weights(WeightsEnum): +class MNASNet1_0_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -62,7 +62,7 @@ class MNASNet1_0Weights(WeightsEnum): ) -class MNASNet1_3Weights(WeightsEnum): +class MNASNet1_3_Weights(WeightsEnum): # If a default model is added here the corresponding changes need to be done in mnasnet1_3 pass @@ -79,41 +79,41 @@ def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwa return model -def mnasnet0_5(weights: Optional[MNASNet0_5Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: +def mnasnet0_5(weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet0_5Weights.ImageNet1K_V1) - weights = MNASNet0_5Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet0_5_Weights.ImageNet1K_V1) + weights = MNASNet0_5_Weights.verify(weights) return _mnasnet(0.5, weights, progress, **kwargs) -def mnasnet0_75(weights: Optional[MNASNet0_75Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: +def mnasnet0_75(weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param(kwargs, "pretrained", "weights", None) - weights = MNASNet0_75Weights.verify(weights) + weights = MNASNet0_75_Weights.verify(weights) return _mnasnet(0.75, weights, progress, **kwargs) -def mnasnet1_0(weights: Optional[MNASNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: +def mnasnet1_0(weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet1_0Weights.ImageNet1K_V1) - weights = 
MNASNet1_0Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet1_0_Weights.ImageNet1K_V1) + weights = MNASNet1_0_Weights.verify(weights) return _mnasnet(1.0, weights, progress, **kwargs) -def mnasnet1_3(weights: Optional[MNASNet1_3Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: +def mnasnet1_3(weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param(kwargs, "pretrained", "weights", None) - weights = MNASNet1_3Weights.verify(weights) + weights = MNASNet1_3_Weights.verify(weights) return _mnasnet(1.3, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/mobilenetv2.py b/torchvision/prototype/models/mobilenetv2.py index 8d4520bfd5a..0c0f80d081a 100644 --- a/torchvision/prototype/models/mobilenetv2.py +++ b/torchvision/prototype/models/mobilenetv2.py @@ -10,10 +10,10 @@ from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -__all__ = ["MobileNetV2", "MobileNetV2Weights", "mobilenet_v2"] +__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"] -class MobileNetV2Weights(WeightsEnum): +class MobileNet_V2_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -29,12 +29,12 @@ class MobileNetV2Weights(WeightsEnum): ) -def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2: +def mobilenet_v2(weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV2Weights.ImageNet1K_V1) - weights = MobileNetV2Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V2_Weights.ImageNet1K_V1) + weights = MobileNet_V2_Weights.verify(weights) if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) diff --git a/torchvision/prototype/models/mobilenetv3.py b/torchvision/prototype/models/mobilenetv3.py index d69c7a11705..e014fb5acb2 100644 --- a/torchvision/prototype/models/mobilenetv3.py +++ b/torchvision/prototype/models/mobilenetv3.py @@ -12,8 +12,8 @@ __all__ = [ "MobileNetV3", - "MobileNetV3LargeWeights", - "MobileNetV3SmallWeights", + "MobileNet_V3_Large_Weights", + "MobileNet_V3_Small_Weights", "mobilenet_v3_large", "mobilenet_v3_small", ] @@ -44,7 +44,7 @@ def _mobilenet_v3( } -class MobileNetV3LargeWeights(WeightsEnum): +class MobileNet_V3_Large_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -69,7 +69,7 @@ class MobileNetV3LargeWeights(WeightsEnum): ) -class MobileNetV3SmallWeights(WeightsEnum): +class MobileNet_V3_Small_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -84,26 +84,26 @@ class MobileNetV3SmallWeights(WeightsEnum): def mobilenet_v3_large( - weights: Optional[MobileNetV3LargeWeights] = None, progress: bool = True, **kwargs: Any + weights: 
Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any ) -> MobileNetV3: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV3LargeWeights.ImageNet1K_V1) - weights = MobileNetV3LargeWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V3_Large_Weights.ImageNet1K_V1) + weights = MobileNet_V3_Large_Weights.verify(weights) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) def mobilenet_v3_small( - weights: Optional[MobileNetV3SmallWeights] = None, progress: bool = True, **kwargs: Any + weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any ) -> MobileNetV3: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV3SmallWeights.ImageNet1K_V1) - weights = MobileNetV3SmallWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V3_Small_Weights.ImageNet1K_V1) + weights = MobileNet_V3_Small_Weights.verify(weights) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs) return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/quantization/googlenet.py b/torchvision/prototype/models/quantization/googlenet.py index 18c980d2d02..3d26fd7d607 100644 --- a/torchvision/prototype/models/quantization/googlenet.py +++ b/torchvision/prototype/models/quantization/googlenet.py @@ -13,17 +13,17 @@ from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -from ..googlenet import GoogLeNetWeights +from ..googlenet import GoogLeNet_Weights __all__ = [ "QuantizableGoogLeNet", - "QuantizedGoogLeNetWeights", + "GoogLeNet_QuantizedWeights", "googlenet", ] -class QuantizedGoogLeNetWeights(WeightsEnum): +class GoogLeNet_QuantizedWeights(WeightsEnum): ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -34,7 +34,7 @@ class QuantizedGoogLeNetWeights(WeightsEnum): "backend": "fbgemm", "quantization": "ptq", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", - "unquantized": GoogLeNetWeights.ImageNet1K_V1, + "unquantized": GoogLeNet_Weights.ImageNet1K_V1, "acc@1": 69.826, "acc@5": 89.404, }, @@ -43,7 +43,7 @@ class QuantizedGoogLeNetWeights(WeightsEnum): def googlenet( - weights: Optional[Union[QuantizedGoogLeNetWeights, GoogLeNetWeights]] = None, + weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -51,12 +51,12 @@ def googlenet( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - default_value = QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_V1 if quantize else GoogLeNetWeights.ImageNet1K_V1 + default_value = GoogLeNet_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else 
GoogLeNet_Weights.ImageNet1K_V1 weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: - weights = QuantizedGoogLeNetWeights.verify(weights) + weights = GoogLeNet_QuantizedWeights.verify(weights) else: - weights = GoogLeNetWeights.verify(weights) + weights = GoogLeNet_Weights.verify(weights) original_aux_logits = kwargs.get("aux_logits", False) if weights is not None: diff --git a/torchvision/prototype/models/quantization/inception.py b/torchvision/prototype/models/quantization/inception.py index ea0db4d76ed..ff779076df6 100644 --- a/torchvision/prototype/models/quantization/inception.py +++ b/torchvision/prototype/models/quantization/inception.py @@ -12,17 +12,17 @@ from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -from ..inception import InceptionV3Weights +from ..inception import Inception_V3_Weights __all__ = [ "QuantizableInception3", - "QuantizedInceptionV3Weights", + "Inception_V3_QuantizedWeights", "inception_v3", ] -class QuantizedInceptionV3Weights(WeightsEnum): +class Inception_V3_QuantizedWeights(WeightsEnum): ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth", transforms=partial(ImageNetEval, crop_size=299, resize_size=342), @@ -33,7 +33,7 @@ class QuantizedInceptionV3Weights(WeightsEnum): "backend": "fbgemm", "quantization": "ptq", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", - "unquantized": InceptionV3Weights.ImageNet1K_V1, + "unquantized": Inception_V3_Weights.ImageNet1K_V1, "acc@1": 77.176, "acc@5": 93.354, }, @@ -42,7 +42,7 @@ class QuantizedInceptionV3Weights(WeightsEnum): def inception_v3( - weights: Optional[Union[QuantizedInceptionV3Weights, InceptionV3Weights]] = None, + weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -51,13 +51,13 @@ def inception_v3( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: default_value = ( - QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_V1 if quantize else InceptionV3Weights.ImageNet1K_V1 + Inception_V3_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else Inception_V3_Weights.ImageNet1K_V1 ) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: - weights = QuantizedInceptionV3Weights.verify(weights) + weights = Inception_V3_QuantizedWeights.verify(weights) else: - weights = InceptionV3Weights.verify(weights) + weights = Inception_V3_Weights.verify(weights) original_aux_logits = kwargs.get("aux_logits", False) if weights is not None: diff --git a/torchvision/prototype/models/quantization/mobilenetv2.py b/torchvision/prototype/models/quantization/mobilenetv2.py index 3c56d1f91ed..c5afd731fad 100644 --- a/torchvision/prototype/models/quantization/mobilenetv2.py +++ b/torchvision/prototype/models/quantization/mobilenetv2.py @@ -13,17 +13,17 @@ from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -from ..mobilenetv2 import MobileNetV2Weights +from ..mobilenetv2 import MobileNet_V2_Weights __all__ = [ "QuantizableMobileNetV2", - "QuantizedMobileNetV2Weights", + "MobileNet_V2_QuantizedWeights", "mobilenet_v2", ] 
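
Illustrative sketch (not part of the patch): how the renamed quantized enum is meant to be consumed after this change. The package-level import path is assumed; the entry name and the meta fields are taken from the hunks of this file.

from torchvision.prototype.models.quantization import MobileNet_V2_QuantizedWeights

# Pick the QNNPACK QAT checkpoint and inspect its metadata.
weights = MobileNet_V2_QuantizedWeights.ImageNet1K_QNNPACK_V1
print(weights.meta["backend"])      # "qnnpack"
print(weights.meta["unquantized"])  # MobileNet_V2_Weights.ImageNet1K_V1, the float counterpart
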
-class QuantizedMobileNetV2Weights(WeightsEnum): +class MobileNet_V2_QuantizedWeights(WeightsEnum): ImageNet1K_QNNPACK_V1 = Weights( url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -34,7 +34,7 @@ class QuantizedMobileNetV2Weights(WeightsEnum): "backend": "qnnpack", "quantization": "qat", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2", - "unquantized": MobileNetV2Weights.ImageNet1K_V1, + "unquantized": MobileNet_V2_Weights.ImageNet1K_V1, "acc@1": 71.658, "acc@5": 90.150, }, @@ -43,7 +43,7 @@ class QuantizedMobileNetV2Weights(WeightsEnum): def mobilenet_v2( - weights: Optional[Union[QuantizedMobileNetV2Weights, MobileNetV2Weights]] = None, + weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -52,13 +52,13 @@ def mobilenet_v2( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: default_value = ( - QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_V1 if quantize else MobileNetV2Weights.ImageNet1K_V1 + MobileNet_V2_QuantizedWeights.ImageNet1K_QNNPACK_V1 if quantize else MobileNet_V2_Weights.ImageNet1K_V1 ) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: - weights = QuantizedMobileNetV2Weights.verify(weights) + weights = MobileNet_V2_QuantizedWeights.verify(weights) else: - weights = MobileNetV2Weights.verify(weights) + weights = MobileNet_V2_Weights.verify(weights) if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) diff --git a/torchvision/prototype/models/quantization/mobilenetv3.py b/torchvision/prototype/models/quantization/mobilenetv3.py index 127a96c8547..a29e3f44697 100644 --- a/torchvision/prototype/models/quantization/mobilenetv3.py +++ b/torchvision/prototype/models/quantization/mobilenetv3.py @@ -14,12 +14,12 @@ from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -from ..mobilenetv3 import MobileNetV3LargeWeights, _mobilenet_v3_conf +from ..mobilenetv3 import MobileNet_V3_Large_Weights, _mobilenet_v3_conf __all__ = [ "QuantizableMobileNetV3", - "QuantizedMobileNetV3LargeWeights", + "MobileNet_V3_Large_QuantizedWeights", "mobilenet_v3_large", ] @@ -56,7 +56,7 @@ def _mobilenet_v3_model( return model -class QuantizedMobileNetV3LargeWeights(WeightsEnum): +class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): ImageNet1K_QNNPACK_V1 = Weights( url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -67,7 +67,7 @@ class QuantizedMobileNetV3LargeWeights(WeightsEnum): "backend": "qnnpack", "quantization": "qat", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3", - "unquantized": MobileNetV3LargeWeights.ImageNet1K_V1, + "unquantized": MobileNet_V3_Large_Weights.ImageNet1K_V1, "acc@1": 73.004, "acc@5": 90.858, }, @@ -76,7 +76,7 @@ class QuantizedMobileNetV3LargeWeights(WeightsEnum): def mobilenet_v3_large( - weights: Optional[Union[QuantizedMobileNetV3LargeWeights, MobileNetV3LargeWeights]] = None, + weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, 
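
Illustrative sketch (not part of the patch): with the renamed enum, the quantized builder is called as below. The import path is assumed; the entry name and the `quantize` flag come from this file's hunks.

from torchvision.prototype.models.quantization import (
    MobileNet_V3_Large_QuantizedWeights,
    mobilenet_v3_large,
)

# Build the quantized model with the renamed weights enum.
model = mobilenet_v3_large(
    weights=MobileNet_V3_Large_QuantizedWeights.ImageNet1K_QNNPACK_V1,
    quantize=True,
)
model.eval()
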
@@ -85,15 +85,15 @@ def mobilenet_v3_large( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: default_value = ( - QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_V1 + MobileNet_V3_Large_QuantizedWeights.ImageNet1K_QNNPACK_V1 if quantize - else MobileNetV3LargeWeights.ImageNet1K_V1 + else MobileNet_V3_Large_Weights.ImageNet1K_V1 ) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: - weights = QuantizedMobileNetV3LargeWeights.verify(weights) + weights = MobileNet_V3_Large_QuantizedWeights.verify(weights) else: - weights = MobileNetV3LargeWeights.verify(weights) + weights = MobileNet_V3_Large_Weights.verify(weights) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs) diff --git a/torchvision/prototype/models/quantization/resnet.py b/torchvision/prototype/models/quantization/resnet.py index d4739ee5497..0de4eb5557b 100644 --- a/torchvision/prototype/models/quantization/resnet.py +++ b/torchvision/prototype/models/quantization/resnet.py @@ -14,14 +14,14 @@ from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -from ..resnet import ResNet18Weights, ResNet50Weights, ResNeXt101_32x8dWeights +from ..resnet import ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights __all__ = [ "QuantizableResNet", - "QuantizedResNet18Weights", - "QuantizedResNet50Weights", - "QuantizedResNeXt101_32x8dWeights", + "ResNet18_QuantizedWeights", + "ResNet50_QuantizedWeights", + "ResNeXt101_32X8D_QuantizedWeights", "resnet18", "resnet50", "resnext101_32x8d", @@ -63,13 +63,13 @@ def _resnet( } -class QuantizedResNet18Weights(WeightsEnum): +class ResNet18_QuantizedWeights(WeightsEnum): ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ **_COMMON_META, - "unquantized": ResNet18Weights.ImageNet1K_V1, + "unquantized": ResNet18_Weights.ImageNet1K_V1, "acc@1": 69.494, "acc@5": 88.882, }, @@ -77,13 +77,13 @@ class QuantizedResNet18Weights(WeightsEnum): ) -class QuantizedResNet50Weights(WeightsEnum): +class ResNet50_QuantizedWeights(WeightsEnum): ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ **_COMMON_META, - "unquantized": ResNet50Weights.ImageNet1K_V1, + "unquantized": ResNet50_Weights.ImageNet1K_V1, "acc@1": 75.920, "acc@5": 92.814, }, @@ -94,7 +94,7 @@ class QuantizedResNet50Weights(WeightsEnum): transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, - "unquantized": ResNet50Weights.ImageNet1K_V2, + "unquantized": ResNet50_Weights.ImageNet1K_V2, "acc@1": 80.282, "acc@5": 94.976, }, @@ -102,13 +102,13 @@ class QuantizedResNet50Weights(WeightsEnum): ) -class QuantizedResNeXt101_32x8dWeights(WeightsEnum): +class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ **_COMMON_META, - "unquantized": ResNeXt101_32x8dWeights.ImageNet1K_V1, + "unquantized": ResNeXt101_32X8D_Weights.ImageNet1K_V1, "acc@1": 78.986, "acc@5": 
94.480, }, @@ -119,7 +119,7 @@ class QuantizedResNeXt101_32x8dWeights(WeightsEnum): transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, - "unquantized": ResNeXt101_32x8dWeights.ImageNet1K_V2, + "unquantized": ResNeXt101_32X8D_Weights.ImageNet1K_V2, "acc@1": 82.574, "acc@5": 96.132, }, @@ -128,7 +128,7 @@ class QuantizedResNeXt101_32x8dWeights(WeightsEnum): def resnet18( - weights: Optional[Union[QuantizedResNet18Weights, ResNet18Weights]] = None, + weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -136,18 +136,18 @@ def resnet18( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - default_value = QuantizedResNet18Weights.ImageNet1K_FBGEMM_V1 if quantize else ResNet18Weights.ImageNet1K_V1 + default_value = ResNet18_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else ResNet18_Weights.ImageNet1K_V1 weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: - weights = QuantizedResNet18Weights.verify(weights) + weights = ResNet18_QuantizedWeights.verify(weights) else: - weights = ResNet18Weights.verify(weights) + weights = ResNet18_Weights.verify(weights) return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs) def resnet50( - weights: Optional[Union[QuantizedResNet50Weights, ResNet50Weights]] = None, + weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -155,18 +155,18 @@ def resnet50( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - default_value = QuantizedResNet50Weights.ImageNet1K_FBGEMM_V1 if quantize else ResNet50Weights.ImageNet1K_V1 + default_value = ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else ResNet50_Weights.ImageNet1K_V1 weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: - weights = QuantizedResNet50Weights.verify(weights) + weights = ResNet50_QuantizedWeights.verify(weights) else: - weights = ResNet50Weights.verify(weights) + weights = ResNet50_Weights.verify(weights) return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs) def resnext101_32x8d( - weights: Optional[Union[QuantizedResNeXt101_32x8dWeights, ResNeXt101_32x8dWeights]] = None, + weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -175,13 +175,15 @@ def resnext101_32x8d( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: default_value = ( - QuantizedResNeXt101_32x8dWeights.ImageNet1K_FBGEMM_V1 if quantize else ResNeXt101_32x8dWeights.ImageNet1K_V1 + ResNeXt101_32X8D_QuantizedWeights.ImageNet1K_FBGEMM_V1 + if quantize + else ResNeXt101_32X8D_Weights.ImageNet1K_V1 ) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: - weights = QuantizedResNeXt101_32x8dWeights.verify(weights) + weights = ResNeXt101_32X8D_QuantizedWeights.verify(weights) else: - weights = ResNeXt101_32x8dWeights.verify(weights) + weights = ResNeXt101_32X8D_Weights.verify(weights) _ovewrite_named_param(kwargs, "groups", 32) _ovewrite_named_param(kwargs, "width_per_group", 
8) diff --git a/torchvision/prototype/models/quantization/shufflenetv2.py b/torchvision/prototype/models/quantization/shufflenetv2.py index 5c5bed5fee6..6677983a1d9 100644 --- a/torchvision/prototype/models/quantization/shufflenetv2.py +++ b/torchvision/prototype/models/quantization/shufflenetv2.py @@ -12,13 +12,13 @@ from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -from ..shufflenetv2 import ShuffleNetV2_x0_5Weights, ShuffleNetV2_x1_0Weights +from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights __all__ = [ "QuantizableShuffleNetV2", - "QuantizedShuffleNetV2_x0_5Weights", - "QuantizedShuffleNetV2_x1_0Weights", + "ShuffleNet_V2_X0_5_QuantizedWeights", + "ShuffleNet_V2_X1_0_QuantizedWeights", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", ] @@ -59,13 +59,13 @@ def _shufflenetv2( } -class QuantizedShuffleNetV2_x0_5Weights(WeightsEnum): +class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ **_COMMON_META, - "unquantized": ShuffleNetV2_x0_5Weights.ImageNet1K_V1, + "unquantized": ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1, "acc@1": 57.972, "acc@5": 79.780, }, @@ -73,13 +73,13 @@ class QuantizedShuffleNetV2_x0_5Weights(WeightsEnum): ) -class QuantizedShuffleNetV2_x1_0Weights(WeightsEnum): +class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ **_COMMON_META, - "unquantized": ShuffleNetV2_x1_0Weights.ImageNet1K_V1, + "unquantized": ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1, "acc@1": 68.360, "acc@5": 87.582, }, @@ -88,7 +88,7 @@ class QuantizedShuffleNetV2_x1_0Weights(WeightsEnum): def shufflenet_v2_x0_5( - weights: Optional[Union[QuantizedShuffleNetV2_x0_5Weights, ShuffleNetV2_x0_5Weights]] = None, + weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -97,21 +97,21 @@ def shufflenet_v2_x0_5( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: default_value = ( - QuantizedShuffleNetV2_x0_5Weights.ImageNet1K_FBGEMM_V1 + ShuffleNet_V2_X0_5_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize - else ShuffleNetV2_x0_5Weights.ImageNet1K_V1 + else ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1 ) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: - weights = QuantizedShuffleNetV2_x0_5Weights.verify(weights) + weights = ShuffleNet_V2_X0_5_QuantizedWeights.verify(weights) else: - weights = ShuffleNetV2_x0_5Weights.verify(weights) + weights = ShuffleNet_V2_X0_5_Weights.verify(weights) return _shufflenetv2([4, 8, 4], [24, 48, 96, 192, 1024], weights, progress, quantize, **kwargs) def shufflenet_v2_x1_0( - weights: Optional[Union[QuantizedShuffleNetV2_x1_0Weights, ShuffleNetV2_x1_0Weights]] = None, + weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -120,14 +120,14 @@ def shufflenet_v2_x1_0( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: default_value = ( - 
QuantizedShuffleNetV2_x1_0Weights.ImageNet1K_FBGEMM_V1 + ShuffleNet_V2_X1_0_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize - else ShuffleNetV2_x1_0Weights.ImageNet1K_V1 + else ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1 ) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: - weights = QuantizedShuffleNetV2_x1_0Weights.verify(weights) + weights = ShuffleNet_V2_X1_0_QuantizedWeights.verify(weights) else: - weights = ShuffleNetV2_x1_0Weights.verify(weights) + weights = ShuffleNet_V2_X1_0_Weights.verify(weights) return _shufflenetv2([4, 8, 4], [24, 116, 232, 464, 1024], weights, progress, quantize, **kwargs) diff --git a/torchvision/prototype/models/regnet.py b/torchvision/prototype/models/regnet.py index 436433ca185..1e12ae7bbd2 100644 --- a/torchvision/prototype/models/regnet.py +++ b/torchvision/prototype/models/regnet.py @@ -13,20 +13,20 @@ __all__ = [ "RegNet", - "RegNet_y_400mfWeights", - "RegNet_y_800mfWeights", - "RegNet_y_1_6gfWeights", - "RegNet_y_3_2gfWeights", - "RegNet_y_8gfWeights", - "RegNet_y_16gfWeights", - "RegNet_y_32gfWeights", - "RegNet_x_400mfWeights", - "RegNet_x_800mfWeights", - "RegNet_x_1_6gfWeights", - "RegNet_x_3_2gfWeights", - "RegNet_x_8gfWeights", - "RegNet_x_16gfWeights", - "RegNet_x_32gfWeights", + "RegNet_Y_400MF_Weights", + "RegNet_Y_800MF_Weights", + "RegNet_Y_1_6GF_Weights", + "RegNet_Y_3_2GF_Weights", + "RegNet_Y_8GF_Weights", + "RegNet_Y_16GF_Weights", + "RegNet_Y_32GF_Weights", + "RegNet_X_400MF_Weights", + "RegNet_X_800MF_Weights", + "RegNet_X_1_6GF_Weights", + "RegNet_X_3_2GF_Weights", + "RegNet_X_8GF_Weights", + "RegNet_X_16GF_Weights", + "RegNet_X_32GF_Weights", "regnet_y_400mf", "regnet_y_800mf", "regnet_y_1_6gf", @@ -64,7 +64,7 @@ def _regnet( return model -class RegNet_y_400mfWeights(WeightsEnum): +class RegNet_Y_400MF_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -78,7 +78,7 @@ class RegNet_y_400mfWeights(WeightsEnum): ) -class RegNet_y_800mfWeights(WeightsEnum): +class RegNet_Y_800MF_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -92,7 +92,7 @@ class RegNet_y_800mfWeights(WeightsEnum): ) -class RegNet_y_1_6gfWeights(WeightsEnum): +class RegNet_Y_1_6GF_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -106,7 +106,7 @@ class RegNet_y_1_6gfWeights(WeightsEnum): ) -class RegNet_y_3_2gfWeights(WeightsEnum): +class RegNet_Y_3_2GF_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -120,7 +120,7 @@ class RegNet_y_3_2gfWeights(WeightsEnum): ) -class RegNet_y_8gfWeights(WeightsEnum): +class RegNet_Y_8GF_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -134,7 +134,7 @@ class RegNet_y_8gfWeights(WeightsEnum): ) -class RegNet_y_16gfWeights(WeightsEnum): +class RegNet_Y_16GF_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -148,7 +148,7 @@ class RegNet_y_16gfWeights(WeightsEnum): ) 
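
Illustrative sketch (not part of the patch): the renamed RegNet enums are used the same way as before, only under the new capitalised names. The import path is assumed; the entry name and its `transforms` field are taken from the surrounding hunks.

from torchvision.prototype.models import RegNet_Y_16GF_Weights, regnet_y_16gf

weights = RegNet_Y_16GF_Weights.ImageNet1K_V1
model = regnet_y_16gf(weights=weights)
preprocess = weights.transforms()  # instantiates ImageNetEval(crop_size=224), the matching eval preprocessing
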
-class RegNet_y_32gfWeights(WeightsEnum): +class RegNet_Y_32GF_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -162,7 +162,7 @@ class RegNet_y_32gfWeights(WeightsEnum): ) -class RegNet_x_400mfWeights(WeightsEnum): +class RegNet_X_400MF_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -176,7 +176,7 @@ class RegNet_x_400mfWeights(WeightsEnum): ) -class RegNet_x_800mfWeights(WeightsEnum): +class RegNet_X_800MF_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -190,7 +190,7 @@ class RegNet_x_800mfWeights(WeightsEnum): ) -class RegNet_x_1_6gfWeights(WeightsEnum): +class RegNet_X_1_6GF_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -204,7 +204,7 @@ class RegNet_x_1_6gfWeights(WeightsEnum): ) -class RegNet_x_3_2gfWeights(WeightsEnum): +class RegNet_X_3_2GF_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -218,7 +218,7 @@ class RegNet_x_3_2gfWeights(WeightsEnum): ) -class RegNet_x_8gfWeights(WeightsEnum): +class RegNet_X_8GF_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -232,7 +232,7 @@ class RegNet_x_8gfWeights(WeightsEnum): ) -class RegNet_x_16gfWeights(WeightsEnum): +class RegNet_X_16GF_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -246,7 +246,7 @@ class RegNet_x_16gfWeights(WeightsEnum): ) -class RegNet_x_32gfWeights(WeightsEnum): +class RegNet_X_32GF_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -260,34 +260,34 @@ class RegNet_x_32gfWeights(WeightsEnum): ) -def regnet_y_400mf(weights: Optional[RegNet_y_400mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_y_400mf(weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_400mfWeights.ImageNet1K_V1) - weights = RegNet_y_400mfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_400MF_Weights.ImageNet1K_V1) + weights = RegNet_Y_400MF_Weights.verify(weights) params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_y_800mf(weights: Optional[RegNet_y_800mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_y_800mf(weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - 
weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_800mfWeights.ImageNet1K_V1) - weights = RegNet_y_800mfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_800MF_Weights.ImageNet1K_V1) + weights = RegNet_Y_800MF_Weights.verify(weights) params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_y_1_6gf(weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_1_6gfWeights.ImageNet1K_V1) - weights = RegNet_y_1_6gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_1_6GF_Weights.ImageNet1K_V1) + weights = RegNet_Y_1_6GF_Weights.verify(weights) params = BlockParams.from_init_params( depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs @@ -295,12 +295,12 @@ def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bo return _regnet(params, weights, progress, **kwargs) -def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_y_3_2gf(weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_3_2gfWeights.ImageNet1K_V1) - weights = RegNet_y_3_2gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_3_2GF_Weights.ImageNet1K_V1) + weights = RegNet_Y_3_2GF_Weights.verify(weights) params = BlockParams.from_init_params( depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs @@ -308,12 +308,12 @@ def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bo return _regnet(params, weights, progress, **kwargs) -def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_y_8gf(weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_8gfWeights.ImageNet1K_V1) - weights = RegNet_y_8gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_8GF_Weights.ImageNet1K_V1) + weights = RegNet_Y_8GF_Weights.verify(weights) params = BlockParams.from_init_params( depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs @@ -321,12 +321,12 @@ def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool = return _regnet(params, weights, progress, **kwargs) -def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_y_16gf(weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: 
_deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_16gfWeights.ImageNet1K_V1) - weights = RegNet_y_16gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_16GF_Weights.ImageNet1K_V1) + weights = RegNet_Y_16GF_Weights.verify(weights) params = BlockParams.from_init_params( depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs @@ -334,12 +334,12 @@ def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool return _regnet(params, weights, progress, **kwargs) -def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_y_32gf(weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_32gfWeights.ImageNet1K_V1) - weights = RegNet_y_32gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_32GF_Weights.ImageNet1K_V1) + weights = RegNet_Y_32GF_Weights.verify(weights) params = BlockParams.from_init_params( depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs @@ -347,78 +347,78 @@ def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool return _regnet(params, weights, progress, **kwargs) -def regnet_x_400mf(weights: Optional[RegNet_x_400mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_x_400mf(weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_400mfWeights.ImageNet1K_V1) - weights = RegNet_x_400mfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_400MF_Weights.ImageNet1K_V1) + weights = RegNet_X_400MF_Weights.verify(weights) params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_x_800mf(weights: Optional[RegNet_x_800mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_x_800mf(weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_800mfWeights.ImageNet1K_V1) - weights = RegNet_x_800mfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_800MF_Weights.ImageNet1K_V1) + weights = RegNet_X_800MF_Weights.verify(weights) params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_x_1_6gf(weights: Optional[RegNet_x_1_6gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_x_1_6gf(weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, 
"pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_1_6gfWeights.ImageNet1K_V1) - weights = RegNet_x_1_6gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_1_6GF_Weights.ImageNet1K_V1) + weights = RegNet_X_1_6GF_Weights.verify(weights) params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_x_3_2gf(weights: Optional[RegNet_x_3_2gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_x_3_2gf(weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_3_2gfWeights.ImageNet1K_V1) - weights = RegNet_x_3_2gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_3_2GF_Weights.ImageNet1K_V1) + weights = RegNet_X_3_2GF_Weights.verify(weights) params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_x_8gf(weights: Optional[RegNet_x_8gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_x_8gf(weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_8gfWeights.ImageNet1K_V1) - weights = RegNet_x_8gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_8GF_Weights.ImageNet1K_V1) + weights = RegNet_X_8GF_Weights.verify(weights) params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_x_16gf(weights: Optional[RegNet_x_16gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_x_16gf(weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_16gfWeights.ImageNet1K_V1) - weights = RegNet_x_16gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_16GF_Weights.ImageNet1K_V1) + weights = RegNet_X_16GF_Weights.verify(weights) params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_x_32gf(weights: Optional[RegNet_x_32gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_x_32gf(weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_32gfWeights.ImageNet1K_V1) - weights = RegNet_x_32gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", 
RegNet_X_32GF_Weights.ImageNet1K_V1) + weights = RegNet_X_32GF_Weights.verify(weights) params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) return _regnet(params, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/resnet.py b/torchvision/prototype/models/resnet.py index 47117f383d3..e213864acbe 100644 --- a/torchvision/prototype/models/resnet.py +++ b/torchvision/prototype/models/resnet.py @@ -12,15 +12,15 @@ __all__ = [ "ResNet", - "ResNet18Weights", - "ResNet34Weights", - "ResNet50Weights", - "ResNet101Weights", - "ResNet152Weights", - "ResNeXt50_32x4dWeights", - "ResNeXt101_32x8dWeights", - "WideResNet50_2Weights", - "WideResNet101_2Weights", + "ResNet18_Weights", + "ResNet34_Weights", + "ResNet50_Weights", + "ResNet101_Weights", + "ResNet152_Weights", + "ResNeXt50_32X4D_Weights", + "ResNeXt101_32X8D_Weights", + "Wide_ResNet50_2_Weights", + "Wide_ResNet101_2_Weights", "resnet18", "resnet34", "resnet50", @@ -54,7 +54,7 @@ def _resnet( _COMMON_META = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR} -class ResNet18Weights(WeightsEnum): +class ResNet18_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnet18-f37072fd.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -68,7 +68,7 @@ class ResNet18Weights(WeightsEnum): ) -class ResNet34Weights(WeightsEnum): +class ResNet34_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnet34-b627a593.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -82,7 +82,7 @@ class ResNet34Weights(WeightsEnum): ) -class ResNet50Weights(WeightsEnum): +class ResNet50_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnet50-0676ba61.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -107,7 +107,7 @@ class ResNet50Weights(WeightsEnum): ) -class ResNet101Weights(WeightsEnum): +class ResNet101_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnet101-63fe2227.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -132,7 +132,7 @@ class ResNet101Weights(WeightsEnum): ) -class ResNet152Weights(WeightsEnum): +class ResNet152_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnet152-394f9c45.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -157,7 +157,7 @@ class ResNet152Weights(WeightsEnum): ) -class ResNeXt50_32x4dWeights(WeightsEnum): +class ResNeXt50_32X4D_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -182,7 +182,7 @@ class ResNeXt50_32x4dWeights(WeightsEnum): ) -class ResNeXt101_32x8dWeights(WeightsEnum): +class ResNeXt101_32X8D_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -207,7 +207,7 @@ class ResNeXt101_32x8dWeights(WeightsEnum): ) -class WideResNet50_2Weights(WeightsEnum): +class Wide_ResNet50_2_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -232,7 +232,7 @@ class WideResNet50_2Weights(WeightsEnum): ) -class WideResNet101_2Weights(WeightsEnum): +class Wide_ResNet101_2_Weights(WeightsEnum): ImageNet1K_V1 = Weights( 
url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -257,97 +257,101 @@ class WideResNet101_2Weights(WeightsEnum): ) -def resnet18(weights: Optional[ResNet18Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: +def resnet18(weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet18Weights.ImageNet1K_V1) - weights = ResNet18Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet18_Weights.ImageNet1K_V1) + weights = ResNet18_Weights.verify(weights) return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) -def resnet34(weights: Optional[ResNet34Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: +def resnet34(weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet34Weights.ImageNet1K_V1) - weights = ResNet34Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet34_Weights.ImageNet1K_V1) + weights = ResNet34_Weights.verify(weights) return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) -def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: +def resnet50(weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet50Weights.ImageNet1K_V1) - weights = ResNet50Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet50_Weights.ImageNet1K_V1) + weights = ResNet50_Weights.verify(weights) return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) -def resnet101(weights: Optional[ResNet101Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: +def resnet101(weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet101Weights.ImageNet1K_V1) - weights = ResNet101Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet101_Weights.ImageNet1K_V1) + weights = ResNet101_Weights.verify(weights) return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) -def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: +def resnet152(weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet152Weights.ImageNet1K_V1) - weights = ResNet152Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet152_Weights.ImageNet1K_V1) + weights = 
ResNet152_Weights.verify(weights) return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs) -def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet: +def resnext50_32x4d(weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt50_32x4dWeights.ImageNet1K_V1) - weights = ResNeXt50_32x4dWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt50_32X4D_Weights.ImageNet1K_V1) + weights = ResNeXt50_32X4D_Weights.verify(weights) _ovewrite_named_param(kwargs, "groups", 32) _ovewrite_named_param(kwargs, "width_per_group", 4) return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) -def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet: +def resnext101_32x8d( + weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt101_32x8dWeights.ImageNet1K_V1) - weights = ResNeXt101_32x8dWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt101_32X8D_Weights.ImageNet1K_V1) + weights = ResNeXt101_32X8D_Weights.verify(weights) _ovewrite_named_param(kwargs, "groups", 32) _ovewrite_named_param(kwargs, "width_per_group", 8) return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) -def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: +def wide_resnet50_2(weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", WideResNet50_2Weights.ImageNet1K_V1) - weights = WideResNet50_2Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", Wide_ResNet50_2_Weights.ImageNet1K_V1) + weights = Wide_ResNet50_2_Weights.verify(weights) _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) -def wide_resnet101_2(weights: Optional[WideResNet101_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: +def wide_resnet101_2( + weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", WideResNet101_2Weights.ImageNet1K_V1) - weights = WideResNet101_2Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", Wide_ResNet101_2_Weights.ImageNet1K_V1) + weights = Wide_ResNet101_2_Weights.verify(weights) _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py index 
80e0285a293..638b771c333 100644 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ b/torchvision/prototype/models/segmentation/deeplabv3.py @@ -8,16 +8,16 @@ from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param -from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large +from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large from ..resnet import resnet50, resnet101 -from ..resnet import ResNet50Weights, ResNet101Weights +from ..resnet import ResNet50_Weights, ResNet101_Weights __all__ = [ "DeepLabV3", - "DeepLabV3ResNet50Weights", - "DeepLabV3ResNet101Weights", - "DeepLabV3MobileNetV3LargeWeights", + "DeepLabV3_ResNet50_Weights", + "DeepLabV3_ResNet101_Weights", + "DeepLabV3_MobileNet_V3_Large_Weights", "deeplabv3_mobilenet_v3_large", "deeplabv3_resnet50", "deeplabv3_resnet101", @@ -30,7 +30,7 @@ } -class DeepLabV3ResNet50Weights(WeightsEnum): +class DeepLabV3_ResNet50_Weights(WeightsEnum): CocoWithVocLabels_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", transforms=partial(VocEval, resize_size=520), @@ -44,7 +44,7 @@ class DeepLabV3ResNet50Weights(WeightsEnum): ) -class DeepLabV3ResNet101Weights(WeightsEnum): +class DeepLabV3_ResNet101_Weights(WeightsEnum): CocoWithVocLabels_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", transforms=partial(VocEval, resize_size=520), @@ -58,7 +58,7 @@ class DeepLabV3ResNet101Weights(WeightsEnum): ) -class DeepLabV3MobileNetV3LargeWeights(WeightsEnum): +class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): CocoWithVocLabels_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", transforms=partial(VocEval, resize_size=520), @@ -73,25 +73,25 @@ class DeepLabV3MobileNetV3LargeWeights(WeightsEnum): def deeplabv3_resnet50( - weights: Optional[DeepLabV3ResNet50Weights] = None, + weights: Optional[DeepLabV3_ResNet50_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet50Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = None, **kwargs: Any, ) -> DeepLabV3: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3ResNet50Weights.CocoWithVocLabels_V1) - weights = DeepLabV3ResNet50Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3_ResNet50_Weights.CocoWithVocLabels_V1) + weights = DeepLabV3_ResNet50_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_V1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1 ) - weights_backbone = ResNet50Weights.verify(weights_backbone) + weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None @@ -110,25 +110,25 @@ def deeplabv3_resnet50( def deeplabv3_resnet101( - weights: Optional[DeepLabV3ResNet101Weights] = None, + weights: Optional[DeepLabV3_ResNet101_Weights] = None, progress: bool = True, num_classes: Optional[int] 
= None, aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet101Weights] = None, + weights_backbone: Optional[ResNet101_Weights] = None, **kwargs: Any, ) -> DeepLabV3: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3ResNet101Weights.CocoWithVocLabels_V1) - weights = DeepLabV3ResNet101Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3_ResNet101_Weights.CocoWithVocLabels_V1) + weights = DeepLabV3_ResNet101_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet101Weights.ImageNet1K_V1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet101_Weights.ImageNet1K_V1 ) - weights_backbone = ResNet101Weights.verify(weights_backbone) + weights_backbone = ResNet101_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None @@ -147,27 +147,27 @@ def deeplabv3_resnet101( def deeplabv3_mobilenet_v3_large( - weights: Optional[DeepLabV3MobileNetV3LargeWeights] = None, + weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - weights_backbone: Optional[MobileNetV3LargeWeights] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, **kwargs: Any, ) -> DeepLabV3: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param( - kwargs, "pretrained", "weights", DeepLabV3MobileNetV3LargeWeights.CocoWithVocLabels_V1 + kwargs, "pretrained", "weights", DeepLabV3_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1 ) - weights = DeepLabV3MobileNetV3LargeWeights.verify(weights) + weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_V1 + kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1 ) - weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None diff --git a/torchvision/prototype/models/segmentation/fcn.py b/torchvision/prototype/models/segmentation/fcn.py index d0444d24293..841e2ea95c5 100644 --- a/torchvision/prototype/models/segmentation/fcn.py +++ b/torchvision/prototype/models/segmentation/fcn.py @@ -8,10 +8,10 @@ from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param -from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101 +from ..resnet import ResNet50_Weights, ResNet101_Weights, resnet50, resnet101 -__all__ = ["FCN", "FCNResNet50Weights", "FCNResNet101Weights", "fcn_resnet50", "fcn_resnet101"] +__all__ = ["FCN", "FCN_ResNet50_Weights", "FCN_ResNet101_Weights", "fcn_resnet50", "fcn_resnet101"] _COMMON_META = { @@ -20,7 +20,7 @@ } -class 
FCNResNet50Weights(WeightsEnum): +class FCN_ResNet50_Weights(WeightsEnum): CocoWithVocLabels_V1 = Weights( url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", transforms=partial(VocEval, resize_size=520), @@ -34,7 +34,7 @@ class FCNResNet50Weights(WeightsEnum): ) -class FCNResNet101Weights(WeightsEnum): +class FCN_ResNet101_Weights(WeightsEnum): CocoWithVocLabels_V1 = Weights( url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", transforms=partial(VocEval, resize_size=520), @@ -49,25 +49,25 @@ class FCNResNet101Weights(WeightsEnum): def fcn_resnet50( - weights: Optional[FCNResNet50Weights] = None, + weights: Optional[FCN_ResNet50_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet50Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = None, **kwargs: Any, ) -> FCN: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", FCNResNet50Weights.CocoWithVocLabels_V1) - weights = FCNResNet50Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", FCN_ResNet50_Weights.CocoWithVocLabels_V1) + weights = FCN_ResNet50_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_V1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1 ) - weights_backbone = ResNet50Weights.verify(weights_backbone) + weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None @@ -86,25 +86,25 @@ def fcn_resnet50( def fcn_resnet101( - weights: Optional[FCNResNet101Weights] = None, + weights: Optional[FCN_ResNet101_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet101Weights] = None, + weights_backbone: Optional[ResNet101_Weights] = None, **kwargs: Any, ) -> FCN: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", FCNResNet101Weights.CocoWithVocLabels_V1) - weights = FCNResNet101Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", FCN_ResNet101_Weights.CocoWithVocLabels_V1) + weights = FCN_ResNet101_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet101Weights.ImageNet1K_V1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet101_Weights.ImageNet1K_V1 ) - weights_backbone = ResNet101Weights.verify(weights_backbone) + weights_backbone = ResNet101_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None diff --git a/torchvision/prototype/models/segmentation/lraspp.py b/torchvision/prototype/models/segmentation/lraspp.py index 1fb989e2f62..9743e02fa16 100644 --- a/torchvision/prototype/models/segmentation/lraspp.py +++ 
b/torchvision/prototype/models/segmentation/lraspp.py @@ -8,13 +8,13 @@ from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param -from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large +from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large -__all__ = ["LRASPP", "LRASPPMobileNetV3LargeWeights", "lraspp_mobilenet_v3_large"] +__all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"] -class LRASPPMobileNetV3LargeWeights(WeightsEnum): +class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): CocoWithVocLabels_V1 = Weights( url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", transforms=partial(VocEval, resize_size=520), @@ -30,10 +30,10 @@ class LRASPPMobileNetV3LargeWeights(WeightsEnum): def lraspp_mobilenet_v3_large( - weights: Optional[LRASPPMobileNetV3LargeWeights] = None, + weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNetV3LargeWeights] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, **kwargs: Any, ) -> LRASPP: if kwargs.pop("aux_loss", False): @@ -42,15 +42,17 @@ def lraspp_mobilenet_v3_large( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", LRASPPMobileNetV3LargeWeights.CocoWithVocLabels_V1) - weights = LRASPPMobileNetV3LargeWeights.verify(weights) + weights = _deprecated_param( + kwargs, "pretrained", "weights", LRASPP_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1 + ) + weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_V1 + kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1 ) - weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py index 6a0eb790632..9fa98c44223 100644 --- a/torchvision/prototype/models/shufflenetv2.py +++ b/torchvision/prototype/models/shufflenetv2.py @@ -12,10 +12,10 @@ __all__ = [ "ShuffleNetV2", - "ShuffleNetV2_x0_5Weights", - "ShuffleNetV2_x1_0Weights", - "ShuffleNetV2_x1_5Weights", - "ShuffleNetV2_x2_0Weights", + "ShuffleNet_V2_X0_5_Weights", + "ShuffleNet_V2_X1_0_Weights", + "ShuffleNet_V2_X1_5_Weights", + "ShuffleNet_V2_X2_0_Weights", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", "shufflenet_v2_x1_5", @@ -48,7 +48,7 @@ def _shufflenetv2( } -class ShuffleNetV2_x0_5Weights(WeightsEnum): +class ShuffleNet_V2_X0_5_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -61,7 +61,7 @@ class ShuffleNetV2_x0_5Weights(WeightsEnum): ) -class ShuffleNetV2_x1_0Weights(WeightsEnum): +class ShuffleNet_V2_X1_0_Weights(WeightsEnum): ImageNet1K_V1 = Weights( 
url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -74,57 +74,57 @@ class ShuffleNetV2_x1_0Weights(WeightsEnum): ) -class ShuffleNetV2_x1_5Weights(WeightsEnum): +class ShuffleNet_V2_X1_5_Weights(WeightsEnum): pass -class ShuffleNetV2_x2_0Weights(WeightsEnum): +class ShuffleNet_V2_X2_0_Weights(WeightsEnum): pass def shufflenet_v2_x0_5( - weights: Optional[ShuffleNetV2_x0_5Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any ) -> ShuffleNetV2: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNetV2_x0_5Weights.ImageNet1K_V1) - weights = ShuffleNetV2_x0_5Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1) + weights = ShuffleNet_V2_X0_5_Weights.verify(weights) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) def shufflenet_v2_x1_0( - weights: Optional[ShuffleNetV2_x1_0Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any ) -> ShuffleNetV2: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNetV2_x1_0Weights.ImageNet1K_V1) - weights = ShuffleNetV2_x1_0Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1) + weights = ShuffleNet_V2_X1_0_Weights.verify(weights) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) def shufflenet_v2_x1_5( - weights: Optional[ShuffleNetV2_x1_5Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any ) -> ShuffleNetV2: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param(kwargs, "pretrained", "weights", None) - weights = ShuffleNetV2_x1_5Weights.verify(weights) + weights = ShuffleNet_V2_X1_5_Weights.verify(weights) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) def shufflenet_v2_x2_0( - weights: Optional[ShuffleNetV2_x2_0Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any ) -> ShuffleNetV2: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param(kwargs, "pretrained", "weights", None) - weights = ShuffleNetV2_x2_0Weights.verify(weights) + weights = ShuffleNet_V2_X2_0_Weights.verify(weights) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) diff --git a/torchvision/prototype/models/squeezenet.py b/torchvision/prototype/models/squeezenet.py index 05cf6f811ff..fdfaa01e8be 100644 --- a/torchvision/prototype/models/squeezenet.py +++ b/torchvision/prototype/models/squeezenet.py @@ -10,7 +10,7 @@ from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -__all__ = ["SqueezeNet", "SqueezeNet1_0Weights", "SqueezeNet1_1Weights", 
"squeezenet1_0", "squeezenet1_1"] +__all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"] _COMMON_META = { @@ -21,7 +21,7 @@ } -class SqueezeNet1_0Weights(WeightsEnum): +class SqueezeNet1_0_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -34,7 +34,7 @@ class SqueezeNet1_0Weights(WeightsEnum): ) -class SqueezeNet1_1Weights(WeightsEnum): +class SqueezeNet1_1_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -47,12 +47,12 @@ class SqueezeNet1_1Weights(WeightsEnum): ) -def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: +def squeezenet1_0(weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_0Weights.ImageNet1K_V1) - weights = SqueezeNet1_0Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_0_Weights.ImageNet1K_V1) + weights = SqueezeNet1_0_Weights.verify(weights) if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) @@ -65,12 +65,12 @@ def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool return model -def squeezenet1_1(weights: Optional[SqueezeNet1_1Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: +def squeezenet1_1(weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_1Weights.ImageNet1K_V1) - weights = SqueezeNet1_1Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_1_Weights.ImageNet1K_V1) + weights = SqueezeNet1_1_Weights.verify(weights) if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) diff --git a/torchvision/prototype/models/vgg.py b/torchvision/prototype/models/vgg.py index 9fdcbf4053e..a357426693d 100644 --- a/torchvision/prototype/models/vgg.py +++ b/torchvision/prototype/models/vgg.py @@ -12,14 +12,14 @@ __all__ = [ "VGG", - "VGG11Weights", - "VGG11BNWeights", - "VGG13Weights", - "VGG13BNWeights", - "VGG16Weights", - "VGG16BNWeights", - "VGG19Weights", - "VGG19BNWeights", + "VGG11_Weights", + "VGG11_BN_Weights", + "VGG13_Weights", + "VGG13_BN_Weights", + "VGG16_Weights", + "VGG16_BN_Weights", + "VGG19_Weights", + "VGG19_BN_Weights", "vgg11", "vgg11_bn", "vgg13", @@ -48,7 +48,7 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b } -class VGG11Weights(WeightsEnum): +class VGG11_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg11-8a719046.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -61,7 +61,7 @@ class VGG11Weights(WeightsEnum): ) -class VGG11BNWeights(WeightsEnum): +class VGG11_BN_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth", 
transforms=partial(ImageNetEval, crop_size=224), @@ -74,7 +74,7 @@ class VGG11BNWeights(WeightsEnum): ) -class VGG13Weights(WeightsEnum): +class VGG13_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg13-19584684.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -87,7 +87,7 @@ class VGG13Weights(WeightsEnum): ) -class VGG13BNWeights(WeightsEnum): +class VGG13_BN_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -100,7 +100,7 @@ class VGG13BNWeights(WeightsEnum): ) -class VGG16Weights(WeightsEnum): +class VGG16_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg16-397923af.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -131,7 +131,7 @@ class VGG16Weights(WeightsEnum): ) -class VGG16BNWeights(WeightsEnum): +class VGG16_BN_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -144,7 +144,7 @@ class VGG16BNWeights(WeightsEnum): ) -class VGG19Weights(WeightsEnum): +class VGG19_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -157,7 +157,7 @@ class VGG19Weights(WeightsEnum): ) -class VGG19BNWeights(WeightsEnum): +class VGG19_BN_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", transforms=partial(ImageNetEval, crop_size=224), @@ -170,81 +170,81 @@ class VGG19BNWeights(WeightsEnum): ) -def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: +def vgg11(weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11Weights.ImageNet1K_V1) - weights = VGG11Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11_Weights.ImageNet1K_V1) + weights = VGG11_Weights.verify(weights) return _vgg("A", False, weights, progress, **kwargs) -def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: +def vgg11_bn(weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11BNWeights.ImageNet1K_V1) - weights = VGG11BNWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11_BN_Weights.ImageNet1K_V1) + weights = VGG11_BN_Weights.verify(weights) return _vgg("A", True, weights, progress, **kwargs) -def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: +def vgg13(weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13Weights.ImageNet1K_V1) - weights = VGG13Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13_Weights.ImageNet1K_V1) + weights = 
VGG13_Weights.verify(weights) return _vgg("B", False, weights, progress, **kwargs) -def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: +def vgg13_bn(weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13BNWeights.ImageNet1K_V1) - weights = VGG13BNWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13_BN_Weights.ImageNet1K_V1) + weights = VGG13_BN_Weights.verify(weights) return _vgg("B", True, weights, progress, **kwargs) -def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: +def vgg16(weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16Weights.ImageNet1K_V1) - weights = VGG16Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16_Weights.ImageNet1K_V1) + weights = VGG16_Weights.verify(weights) return _vgg("D", False, weights, progress, **kwargs) -def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: +def vgg16_bn(weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16BNWeights.ImageNet1K_V1) - weights = VGG16BNWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16_BN_Weights.ImageNet1K_V1) + weights = VGG16_BN_Weights.verify(weights) return _vgg("D", True, weights, progress, **kwargs) -def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: +def vgg19(weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19Weights.ImageNet1K_V1) - weights = VGG19Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19_Weights.ImageNet1K_V1) + weights = VGG19_Weights.verify(weights) return _vgg("E", False, weights, progress, **kwargs) -def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: +def vgg19_bn(weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19BNWeights.ImageNet1K_V1) - weights = VGG19BNWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19_BN_Weights.ImageNet1K_V1) + weights = VGG19_BN_Weights.verify(weights) return _vgg("E", True, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/video/resnet.py b/torchvision/prototype/models/video/resnet.py index 32fc66d7d2a..c75f618a8b1 100644 --- a/torchvision/prototype/models/video/resnet.py +++ 
b/torchvision/prototype/models/video/resnet.py @@ -22,9 +22,9 @@ __all__ = [ "VideoResNet", - "R3D_18Weights", - "MC3_18Weights", - "R2Plus1D_18Weights", + "R3D_18_Weights", + "MC3_18_Weights", + "R2Plus1D_18_Weights", "r3d_18", "mc3_18", "r2plus1d_18", @@ -59,7 +59,7 @@ def _video_resnet( } -class R3D_18Weights(WeightsEnum): +class R3D_18_Weights(WeightsEnum): Kinetics400_V1 = Weights( url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), @@ -72,7 +72,7 @@ class R3D_18Weights(WeightsEnum): ) -class MC3_18Weights(WeightsEnum): +class MC3_18_Weights(WeightsEnum): Kinetics400_V1 = Weights( url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), @@ -85,7 +85,7 @@ class MC3_18Weights(WeightsEnum): ) -class R2Plus1D_18Weights(WeightsEnum): +class R2Plus1D_18_Weights(WeightsEnum): Kinetics400_V1 = Weights( url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), @@ -98,12 +98,12 @@ class R2Plus1D_18Weights(WeightsEnum): ) -def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: +def r3d_18(weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", R3D_18Weights.Kinetics400_V1) - weights = R3D_18Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", R3D_18_Weights.Kinetics400_V1) + weights = R3D_18_Weights.verify(weights) return _video_resnet( BasicBlock, @@ -116,12 +116,12 @@ def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwa ) -def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: +def mc3_18(weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MC3_18Weights.Kinetics400_V1) - weights = MC3_18Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", MC3_18_Weights.Kinetics400_V1) + weights = MC3_18_Weights.verify(weights) return _video_resnet( BasicBlock, @@ -134,12 +134,12 @@ def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwa ) -def r2plus1d_18(weights: Optional[R2Plus1D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: +def r2plus1d_18(weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", R2Plus1D_18Weights.Kinetics400_V1) - weights = R2Plus1D_18Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", R2Plus1D_18_Weights.Kinetics400_V1) + weights = R2Plus1D_18_Weights.verify(weights) return _video_resnet( BasicBlock, diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py index 1f8673b6588..bbe5aba262c 100644 --- 
a/torchvision/prototype/models/vision_transformer.py +++ b/torchvision/prototype/models/vision_transformer.py @@ -17,10 +17,10 @@ __all__ = [ "VisionTransformer", - "VisionTransformer_B_16Weights", - "VisionTransformer_B_32Weights", - "VisionTransformer_L_16Weights", - "VisionTransformer_L_32Weights", + "ViT_B_16_Weights", + "ViT_B_32_Weights", + "ViT_L_16_Weights", + "ViT_L_32_Weights", "vit_b_16", "vit_b_32", "vit_l_16", @@ -231,22 +231,22 @@ def forward(self, x: torch.Tensor): return x -class VisionTransformer_B_16Weights(WeightsEnum): +class ViT_B_16_Weights(WeightsEnum): # If a default model is added here the corresponding changes need to be done in vit_b_16 pass -class VisionTransformer_B_32Weights(WeightsEnum): +class ViT_B_32_Weights(WeightsEnum): # If a default model is added here the corresponding changes need to be done in vit_b_32 pass -class VisionTransformer_L_16Weights(WeightsEnum): +class ViT_L_16_Weights(WeightsEnum): # If a default model is added here the corresponding changes need to be done in vit_l_16 pass -class VisionTransformer_L_32Weights(WeightsEnum): +class ViT_L_32_Weights(WeightsEnum): # If a default model is added here the corresponding changes need to be done in vit_l_32 pass @@ -279,15 +279,13 @@ def _vision_transformer( return model -def vit_b_16( - weights: Optional[VisionTransformer_B_16Weights] = None, progress: bool = True, **kwargs: Any -) -> VisionTransformer: +def vit_b_16(weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_b_16 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_. Args: - weights (VisionTransformer_B_16Weights, optional): If not None, returns a model pre-trained on ImageNet. + weights (ViT_B_16_Weights, optional): If not None, returns a model pre-trained on ImageNet. Default: None. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. """ @@ -295,7 +293,7 @@ def vit_b_16( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param(kwargs, "pretrained", "weights", None) - weights = VisionTransformer_B_16Weights.verify(weights) + weights = ViT_B_16_Weights.verify(weights) return _vision_transformer( patch_size=16, @@ -309,15 +307,13 @@ def vit_b_16( ) -def vit_b_32( - weights: Optional[VisionTransformer_B_32Weights] = None, progress: bool = True, **kwargs: Any -) -> VisionTransformer: +def vit_b_32(weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_b_32 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_. Args: - weights (VisionTransformer_B_32Weights, optional): If not None, returns a model pre-trained on ImageNet. + weights (ViT_B_32_Weights, optional): If not None, returns a model pre-trained on ImageNet. Default: None. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True.
""" @@ -325,7 +321,7 @@ def vit_b_32( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param(kwargs, "pretrained", "weights", None) - weights = VisionTransformer_B_32Weights.verify(weights) + weights = ViT_B_32_Weights.verify(weights) return _vision_transformer( patch_size=32, @@ -339,15 +335,13 @@ def vit_b_32( ) -def vit_l_16( - weights: Optional[VisionTransformer_L_16Weights] = None, progress: bool = True, **kwargs: Any -) -> VisionTransformer: +def vit_l_16(weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_l_16 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. Args: - weights (VisionTransformer_L_16Weights, optional): If not None, returns a model pre-trained on ImageNet. + weights (ViT_L_16Weights, optional): If not None, returns a model pre-trained on ImageNet. Default: None. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. """ @@ -355,7 +349,7 @@ def vit_l_16( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param(kwargs, "pretrained", "weights", None) - weights = VisionTransformer_L_16Weights.verify(weights) + weights = ViT_L_16_Weights.verify(weights) return _vision_transformer( patch_size=16, @@ -369,15 +363,13 @@ def vit_l_16( ) -def vit_l_32( - weights: Optional[VisionTransformer_B_32Weights] = None, progress: bool = True, **kwargs: Any -) -> VisionTransformer: +def vit_l_32(weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_l_32 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. Args: - weights (VisionTransformer_L_16Weights, optional): If not None, returns a model pre-trained on ImageNet. + weights (ViT_L_32Weights, optional): If not None, returns a model pre-trained on ImageNet. Default: None. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. """ @@ -385,7 +377,7 @@ def vit_l_32( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param(kwargs, "pretrained", "weights", None) - weights = VisionTransformer_L_32Weights.verify(weights) + weights = ViT_L_32_Weights.verify(weights) return _vision_transformer( patch_size=32, From 4b1715bbd91b2bd63c8ab6ba8d7ecf15b36b44a1 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 29 Nov 2021 15:13:42 +0000 Subject: [PATCH 4/4] Add a test to check naming conventions. 
--- test/test_prototype_models.py | 41 +++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 14522094b41..92a88342534 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -18,6 +18,12 @@ def _get_original_model(model_fn): return module.__dict__[model_fn.__name__] +def _get_parent_module(model_fn): + parent_module_name = ".".join(model_fn.__module__.split(".")[:-1]) + module = importlib.import_module(parent_module_name) + return module + + def _build_model(fn, **kwargs): try: model = fn(**kwargs) @@ -29,11 +35,6 @@ def _build_model(fn, **kwargs): return model.eval() -def get_models_with_module_names(module): - module_name = module.__name__.split(".")[-1] - return [(fn, module_name) for fn in TM.get_models_from_module(module)] - - @pytest.mark.parametrize( "model_fn, name, weight", [ @@ -55,6 +56,21 @@ def test_get_weight(model_fn, name, weight): assert models._api.get_weight(model_fn, name) == weight +@pytest.mark.parametrize( + "model_fn", + TM.get_models_from_module(models) + + TM.get_models_from_module(models.detection) + + TM.get_models_from_module(models.quantization) + + TM.get_models_from_module(models.segmentation) + + TM.get_models_from_module(models.video), +) +def test_naming_conventions(model_fn): + model_name = model_fn.__name__ + module = _get_parent_module(model_fn) + weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights" + assert model_name in set(x.replace(weights_name, "").lower() for x in module.__dict__ if x.endswith(weights_name)) + + @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models)) @pytest.mark.parametrize("dev", cpu_and_gpu()) @run_if_test_with_prototype @@ -90,16 +106,16 @@ def test_video_model(model_fn, dev): @pytest.mark.parametrize( - "model_fn, module_name", - get_models_with_module_names(models) - + get_models_with_module_names(models.detection) - + get_models_with_module_names(models.quantization) - + get_models_with_module_names(models.segmentation) - + get_models_with_module_names(models.video), + "model_fn", + TM.get_models_from_module(models) + + TM.get_models_from_module(models.detection) + + TM.get_models_from_module(models.quantization) + + TM.get_models_from_module(models.segmentation) + + TM.get_models_from_module(models.video), ) @pytest.mark.parametrize("dev", cpu_and_gpu()) @run_if_test_with_prototype -def test_old_vs_new_factory(model_fn, module_name, dev): +def test_old_vs_new_factory(model_fn, dev): defaults = { "models": { "input_shape": (1, 3, 224, 224), @@ -119,6 +135,7 @@ def test_old_vs_new_factory(model_fn, module_name, dev): }, } model_name = model_fn.__name__ + module_name = model_fn.__module__.split(".")[-2] kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})} input_shape = kwargs.pop("input_shape") kwargs.pop("num_classes", None) # ignore this as it's an incompatible speed optimization for pre-trained models
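
For reference, the renamed enums are drop-in replacements for the old class names in the builders above; only the identifiers change. A short usage sketch, assuming the prototype namespace import used by these tests and the ImageNet1K_V1 entry shown in the resnet diff:

    from torchvision.prototype import models

    # Select weights explicitly through the renamed enum (previously ResNet50Weights).
    weights = models.ResNet50_Weights.ImageNet1K_V1
    model = models.resnet50(weights=weights).eval()

    # The string-based lookup exercised by test_get_weight resolves to the same enum member.
    assert models._api.get_weight(models.resnet50, "ImageNet1K_V1") is weights
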