diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index f53299bcf51..92a88342534 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -18,6 +18,12 @@ def _get_original_model(model_fn): return module.__dict__[model_fn.__name__] +def _get_parent_module(model_fn): + parent_module_name = ".".join(model_fn.__module__.split(".")[:-1]) + module = importlib.import_module(parent_module_name) + return module + + def _build_model(fn, **kwargs): try: model = fn(**kwargs) @@ -29,20 +35,20 @@ def _build_model(fn, **kwargs): return model.eval() -def get_models_with_module_names(module): - module_name = module.__name__.split(".")[-1] - return [(fn, module_name) for fn in TM.get_models_from_module(module)] - - @pytest.mark.parametrize( "model_fn, name, weight", [ - (models.resnet50, "ImageNet1K_RefV1", models.ResNet50Weights.ImageNet1K_RefV1), - (models.resnet50, "default", models.ResNet50Weights.ImageNet1K_RefV2), + (models.resnet50, "ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1), + (models.resnet50, "default", models.ResNet50_Weights.ImageNet1K_V2), + ( + models.quantization.resnet50, + "default", + models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2, + ), ( models.quantization.resnet50, - "ImageNet1K_FBGEMM_RefV1", - models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1, + "ImageNet1K_FBGEMM_V1", + models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1, ), ], ) @@ -50,6 +56,21 @@ def test_get_weight(model_fn, name, weight): assert models._api.get_weight(model_fn, name) == weight +@pytest.mark.parametrize( + "model_fn", + TM.get_models_from_module(models) + + TM.get_models_from_module(models.detection) + + TM.get_models_from_module(models.quantization) + + TM.get_models_from_module(models.segmentation) + + TM.get_models_from_module(models.video), +) +def test_naming_conventions(model_fn): + model_name = model_fn.__name__ + module = _get_parent_module(model_fn) + weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights" + assert model_name in set(x.replace(weights_name, "").lower() for x in module.__dict__ if x.endswith(weights_name)) + + @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models)) @pytest.mark.parametrize("dev", cpu_and_gpu()) @run_if_test_with_prototype @@ -85,16 +106,16 @@ def test_video_model(model_fn, dev): @pytest.mark.parametrize( - "model_fn, module_name", - get_models_with_module_names(models) - + get_models_with_module_names(models.detection) - + get_models_with_module_names(models.quantization) - + get_models_with_module_names(models.segmentation) - + get_models_with_module_names(models.video), + "model_fn", + TM.get_models_from_module(models) + + TM.get_models_from_module(models.detection) + + TM.get_models_from_module(models.quantization) + + TM.get_models_from_module(models.segmentation) + + TM.get_models_from_module(models.video), ) @pytest.mark.parametrize("dev", cpu_and_gpu()) @run_if_test_with_prototype -def test_old_vs_new_factory(model_fn, module_name, dev): +def test_old_vs_new_factory(model_fn, dev): defaults = { "models": { "input_shape": (1, 3, 224, 224), @@ -114,6 +135,7 @@ def test_old_vs_new_factory(model_fn, module_name, dev): }, } model_name = model_fn.__name__ + module_name = model_fn.__module__.split(".")[-2] kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})} input_shape = kwargs.pop("input_shape") kwargs.pop("num_classes", None) # ignore this as it's an incompatible speed optimization for pre-trained models diff --git a/torchvision/prototype/models/_api.py b/torchvision/prototype/models/_api.py index 6291f8eedcf..2935039e087 100644 --- a/torchvision/prototype/models/_api.py +++ b/torchvision/prototype/models/_api.py @@ -7,11 +7,11 @@ from ..._internally_replaced_utils import load_state_dict_from_url -__all__ = ["Weights", "WeightEntry", "get_weight"] +__all__ = ["WeightsEnum", "Weights", "get_weight"] @dataclass -class WeightEntry: +class Weights: """ This class is used to group important attributes associated with the pre-trained weights. @@ -33,17 +33,17 @@ class WeightEntry: default: bool -class Weights(Enum): +class WeightsEnum(Enum): """ This class is the parent class of all model weights. Each model building method receives an optional `weights` parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type - `WeightEntry`. + `Weights`. Args: - value (WeightEntry): The data class entry with the weight information. + value (Weights): The data class entry with the weight information. """ - def __init__(self, value: WeightEntry): + def __init__(self, value: Weights): self._value_ = value @classmethod @@ -58,7 +58,7 @@ def verify(cls, obj: Any) -> Any: return obj @classmethod - def from_str(cls, value: str) -> "Weights": + def from_str(cls, value: str) -> "WeightsEnum": for v in cls: if v._name_ == value or (value == "default" and v.default): return v @@ -71,14 +71,14 @@ def __repr__(self): return f"{self.__class__.__name__}.{self._name_}" def __getattr__(self, name): - # Be able to fetch WeightEntry attributes directly - for f in fields(WeightEntry): + # Be able to fetch Weights attributes directly + for f in fields(Weights): if f.name == name: return object.__getattribute__(self.value, name) return super().__getattr__(name) -def get_weight(fn: Callable, weight_name: str) -> Weights: +def get_weight(fn: Callable, weight_name: str) -> WeightsEnum: """ Gets the weight enum of a specific model builder method and weight name combination. @@ -87,32 +87,32 @@ def get_weight(fn: Callable, weight_name: str) -> Weights: weight_name (str): The name of the weight enum entry of the specific model. Returns: - Weights: The requested weight enum. + WeightsEnum: The requested weight enum. """ sig = signature(fn) if "weights" not in sig.parameters: raise ValueError("The method is missing the 'weights' parameter.") ann = signature(fn).parameters["weights"].annotation - weights_class = None - if isinstance(ann, type) and issubclass(ann, Weights): - weights_class = ann + weights_enum = None + if isinstance(ann, type) and issubclass(ann, WeightsEnum): + weights_enum = ann else: # handle cases like Union[Optional, T] # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8 for t in ann.__args__: # type: ignore[union-attr] - if isinstance(t, type) and issubclass(t, Weights): + if isinstance(t, type) and issubclass(t, WeightsEnum): # ensure the name exists. handles builders with multiple types of weights like in quantization try: t.from_str(weight_name) except ValueError: continue - weights_class = t + weights_enum = t break - if weights_class is None: + if weights_enum is None: raise ValueError( "The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct." ) - return weights_class.from_str(weight_name) + return weights_enum.from_str(weight_name) diff --git a/torchvision/prototype/models/_utils.py b/torchvision/prototype/models/_utils.py index 31d9703db1f..e2ee9034953 100644 --- a/torchvision/prototype/models/_utils.py +++ b/torchvision/prototype/models/_utils.py @@ -1,10 +1,10 @@ import warnings from typing import Any, Dict, Optional, TypeVar -from ._api import Weights +from ._api import WeightsEnum -W = TypeVar("W", bound=Weights) +W = TypeVar("W", bound=WeightsEnum) V = TypeVar("V") diff --git a/torchvision/prototype/models/alexnet.py b/torchvision/prototype/models/alexnet.py index d5e826b4559..b45ca1e7085 100644 --- a/torchvision/prototype/models/alexnet.py +++ b/torchvision/prototype/models/alexnet.py @@ -5,16 +5,16 @@ from torchvision.transforms.functional import InterpolationMode from ...models.alexnet import AlexNet -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -__all__ = ["AlexNet", "AlexNetWeights", "alexnet"] +__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"] -class AlexNetWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class AlexNet_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -29,12 +29,12 @@ class AlexNetWeights(Weights): ) -def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: +def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNetWeights.ImageNet1K_RefV1) - weights = AlexNetWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNet_Weights.ImageNet1K_V1) + weights = AlexNet_Weights.verify(weights) if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) diff --git a/torchvision/prototype/models/densenet.py b/torchvision/prototype/models/densenet.py index f70cee3a528..e779a2cd239 100644 --- a/torchvision/prototype/models/densenet.py +++ b/torchvision/prototype/models/densenet.py @@ -7,17 +7,17 @@ from torchvision.transforms.functional import InterpolationMode from ...models.densenet import DenseNet -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param __all__ = [ "DenseNet", - "DenseNet121Weights", - "DenseNet161Weights", - "DenseNet169Weights", - "DenseNet201Weights", + "DenseNet121_Weights", + "DenseNet161_Weights", + "DenseNet169_Weights", + "DenseNet201_Weights", "densenet121", "densenet161", "densenet169", @@ -25,7 +25,7 @@ ] -def _load_state_dict(model: nn.Module, weights: Weights, progress: bool) -> None: +def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None: # '.'s are no longer allowed in module names, but previous _DenseLayer # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. # They are also in the checkpoints in model_urls. This pattern is used @@ -48,7 +48,7 @@ def _densenet( growth_rate: int, block_config: Tuple[int, int, int, int], num_init_features: int, - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> DenseNet: @@ -71,8 +71,8 @@ def _densenet( } -class DenseNet121Weights(Weights): - ImageNet1K_Community = WeightEntry( +class DenseNet121_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/densenet121-a639ec97.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -84,8 +84,8 @@ class DenseNet121Weights(Weights): ) -class DenseNet161Weights(Weights): - ImageNet1K_Community = WeightEntry( +class DenseNet161_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/densenet161-8d451a50.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -97,8 +97,8 @@ class DenseNet161Weights(Weights): ) -class DenseNet169Weights(Weights): - ImageNet1K_Community = WeightEntry( +class DenseNet169_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -110,8 +110,8 @@ class DenseNet169Weights(Weights): ) -class DenseNet201Weights(Weights): - ImageNet1K_Community = WeightEntry( +class DenseNet201_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/densenet201-c1103571.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -123,41 +123,41 @@ class DenseNet201Weights(Weights): ) -def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: +def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121Weights.ImageNet1K_Community) - weights = DenseNet121Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121_Weights.ImageNet1K_V1) + weights = DenseNet121_Weights.verify(weights) return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs) -def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: +def densenet161(weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161Weights.ImageNet1K_Community) - weights = DenseNet161Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161_Weights.ImageNet1K_V1) + weights = DenseNet161_Weights.verify(weights) return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs) -def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: +def densenet169(weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169Weights.ImageNet1K_Community) - weights = DenseNet169Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169_Weights.ImageNet1K_V1) + weights = DenseNet169_Weights.verify(weights) return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs) -def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: +def densenet201(weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201Weights.ImageNet1K_Community) - weights = DenseNet201Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201_Weights.ImageNet1K_V1) + weights = DenseNet201_Weights.verify(weights) return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index 685a309beaa..c83aaf222fb 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -12,18 +12,18 @@ misc_nn_ops, overwrite_eps, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param -from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large -from ..resnet import ResNet50Weights, resnet50 +from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from ..resnet import ResNet50_Weights, resnet50 __all__ = [ "FasterRCNN", - "FasterRCNNResNet50FPNWeights", - "FasterRCNNMobileNetV3LargeFPNWeights", - "FasterRCNNMobileNetV3Large320FPNWeights", + "FasterRCNN_ResNet50_FPN_Weights", + "FasterRCNN_MobileNet_V3_Large_FPN_Weights", + "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn", "fasterrcnn_mobilenet_v3_large_320_fpn", @@ -36,8 +36,8 @@ } -class FasterRCNNResNet50FPNWeights(Weights): - Coco_RefV1 = WeightEntry( +class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): + Coco_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", transforms=CocoEval, meta={ @@ -49,8 +49,8 @@ class FasterRCNNResNet50FPNWeights(Weights): ) -class FasterRCNNMobileNetV3LargeFPNWeights(Weights): - Coco_RefV1 = WeightEntry( +class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): + Coco_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", transforms=CocoEval, meta={ @@ -62,8 +62,8 @@ class FasterRCNNMobileNetV3LargeFPNWeights(Weights): ) -class FasterRCNNMobileNetV3Large320FPNWeights(Weights): - Coco_RefV1 = WeightEntry( +class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): + Coco_V1 = Weights( url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", transforms=CocoEval, meta={ @@ -76,25 +76,25 @@ class FasterRCNNMobileNetV3Large320FPNWeights(Weights): def fasterrcnn_resnet50_fpn( - weights: Optional[FasterRCNNResNet50FPNWeights] = None, + weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[ResNet50Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = None, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> FasterRCNN: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNResNet50FPNWeights.Coco_RefV1) - weights = FasterRCNNResNet50FPNWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNN_ResNet50_FPN_Weights.Coco_V1) + weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1 ) - weights_backbone = ResNet50Weights.verify(weights_backbone) + weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None @@ -112,17 +112,17 @@ def fasterrcnn_resnet50_fpn( if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == FasterRCNNResNet50FPNWeights.Coco_RefV1: + if weights == FasterRCNN_ResNet50_FPN_Weights.Coco_V1: overwrite_eps(model, 0.0) return model def _fasterrcnn_mobilenet_v3_large_fpn( - weights: Optional[Union[FasterRCNNMobileNetV3LargeFPNWeights, FasterRCNNMobileNetV3Large320FPNWeights]], + weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]], progress: bool, num_classes: Optional[int], - weights_backbone: Optional[MobileNetV3LargeWeights], + weights_backbone: Optional[MobileNet_V3_Large_Weights], trainable_backbone_layers: Optional[int], **kwargs: Any, ) -> FasterRCNN: @@ -159,25 +159,25 @@ def _fasterrcnn_mobilenet_v3_large_fpn( def fasterrcnn_mobilenet_v3_large_fpn( - weights: Optional[FasterRCNNMobileNetV3LargeFPNWeights] = None, + weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNetV3LargeWeights] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> FasterRCNN: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1) - weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNN_MobileNet_V3_Large_FPN_Weights.Coco_V1) + weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1 ) - weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) defaults = { "rpn_score_thresh": 0.05, @@ -195,25 +195,27 @@ def fasterrcnn_mobilenet_v3_large_fpn( def fasterrcnn_mobilenet_v3_large_320_fpn( - weights: Optional[FasterRCNNMobileNetV3Large320FPNWeights] = None, + weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNetV3LargeWeights] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> FasterRCNN: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1) - weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights) + weights = _deprecated_param( + kwargs, "pretrained", "weights", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.Coco_V1 + ) + weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1 ) - weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) defaults = { "min_size": 320, diff --git a/torchvision/prototype/models/detection/keypoint_rcnn.py b/torchvision/prototype/models/detection/keypoint_rcnn.py index d45d7f93af7..85250ac2a33 100644 --- a/torchvision/prototype/models/detection/keypoint_rcnn.py +++ b/torchvision/prototype/models/detection/keypoint_rcnn.py @@ -9,15 +9,15 @@ misc_nn_ops, overwrite_eps, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param -from ..resnet import ResNet50Weights, resnet50 +from ..resnet import ResNet50_Weights, resnet50 __all__ = [ "KeypointRCNN", - "KeypointRCNNResNet50FPNWeights", + "KeypointRCNN_ResNet50_FPN_Weights", "keypointrcnn_resnet50_fpn", ] @@ -25,8 +25,8 @@ _COMMON_META = {"categories": _COCO_PERSON_CATEGORIES, "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES} -class KeypointRCNNResNet50FPNWeights(Weights): - Coco_RefV1_Legacy = WeightEntry( +class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): + Coco_Legacy = Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", transforms=CocoEval, meta={ @@ -37,7 +37,7 @@ class KeypointRCNNResNet50FPNWeights(Weights): }, default=False, ) - Coco_RefV1 = WeightEntry( + Coco_V1 = Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", transforms=CocoEval, meta={ @@ -51,30 +51,30 @@ class KeypointRCNNResNet50FPNWeights(Weights): def keypointrcnn_resnet50_fpn( - weights: Optional[KeypointRCNNResNet50FPNWeights] = None, + weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, num_keypoints: Optional[int] = None, - weights_backbone: Optional[ResNet50Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = None, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> KeypointRCNN: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - default_value = KeypointRCNNResNet50FPNWeights.Coco_RefV1 + default_value = KeypointRCNN_ResNet50_FPN_Weights.Coco_V1 if kwargs["pretrained"] == "legacy": - default_value = KeypointRCNNResNet50FPNWeights.Coco_RefV1_Legacy + default_value = KeypointRCNN_ResNet50_FPN_Weights.Coco_Legacy kwargs["pretrained"] = True weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) - weights = KeypointRCNNResNet50FPNWeights.verify(weights) + weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1 ) - weights_backbone = ResNet50Weights.verify(weights_backbone) + weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None @@ -96,7 +96,7 @@ def keypointrcnn_resnet50_fpn( if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == KeypointRCNNResNet50FPNWeights.Coco_RefV1: + if weights == KeypointRCNN_ResNet50_FPN_Weights.Coco_V1: overwrite_eps(model, 0.0) return model diff --git a/torchvision/prototype/models/detection/mask_rcnn.py b/torchvision/prototype/models/detection/mask_rcnn.py index 8aba8ce5041..ea7ab4f5fc7 100644 --- a/torchvision/prototype/models/detection/mask_rcnn.py +++ b/torchvision/prototype/models/detection/mask_rcnn.py @@ -10,21 +10,21 @@ misc_nn_ops, overwrite_eps, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param -from ..resnet import ResNet50Weights, resnet50 +from ..resnet import ResNet50_Weights, resnet50 __all__ = [ "MaskRCNN", - "MaskRCNNResNet50FPNWeights", + "MaskRCNN_ResNet50_FPN_Weights", "maskrcnn_resnet50_fpn", ] -class MaskRCNNResNet50FPNWeights(Weights): - Coco_RefV1 = WeightEntry( +class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): + Coco_V1 = Weights( url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", transforms=CocoEval, meta={ @@ -39,25 +39,25 @@ class MaskRCNNResNet50FPNWeights(Weights): def maskrcnn_resnet50_fpn( - weights: Optional[MaskRCNNResNet50FPNWeights] = None, + weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[ResNet50Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = None, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> MaskRCNN: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MaskRCNNResNet50FPNWeights.Coco_RefV1) - weights = MaskRCNNResNet50FPNWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", MaskRCNN_ResNet50_FPN_Weights.Coco_V1) + weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1 ) - weights_backbone = ResNet50Weights.verify(weights_backbone) + weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None @@ -75,7 +75,7 @@ def maskrcnn_resnet50_fpn( if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == MaskRCNNResNet50FPNWeights.Coco_RefV1: + if weights == MaskRCNN_ResNet50_FPN_Weights.Coco_V1: overwrite_eps(model, 0.0) return model diff --git a/torchvision/prototype/models/detection/retinanet.py b/torchvision/prototype/models/detection/retinanet.py index c9361934921..d442c79d5b6 100644 --- a/torchvision/prototype/models/detection/retinanet.py +++ b/torchvision/prototype/models/detection/retinanet.py @@ -11,21 +11,21 @@ misc_nn_ops, overwrite_eps, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param -from ..resnet import ResNet50Weights, resnet50 +from ..resnet import ResNet50_Weights, resnet50 __all__ = [ "RetinaNet", - "RetinaNetResNet50FPNWeights", + "RetinaNet_ResNet50_FPN_Weights", "retinanet_resnet50_fpn", ] -class RetinaNetResNet50FPNWeights(Weights): - Coco_RefV1 = WeightEntry( +class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): + Coco_V1 = Weights( url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", transforms=CocoEval, meta={ @@ -39,25 +39,25 @@ class RetinaNetResNet50FPNWeights(Weights): def retinanet_resnet50_fpn( - weights: Optional[RetinaNetResNet50FPNWeights] = None, + weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[ResNet50Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = None, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> RetinaNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RetinaNetResNet50FPNWeights.Coco_RefV1) - weights = RetinaNetResNet50FPNWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RetinaNet_ResNet50_FPN_Weights.Coco_V1) + weights = RetinaNet_ResNet50_FPN_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1 ) - weights_backbone = ResNet50Weights.verify(weights_backbone) + weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None @@ -78,7 +78,7 @@ def retinanet_resnet50_fpn( if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) - if weights == RetinaNetResNet50FPNWeights.Coco_RefV1: + if weights == RetinaNet_ResNet50_FPN_Weights.Coco_V1: overwrite_eps(model, 0.0) return model diff --git a/torchvision/prototype/models/detection/ssd.py b/torchvision/prototype/models/detection/ssd.py index 09ce083bf7e..37f5c2a6944 100644 --- a/torchvision/prototype/models/detection/ssd.py +++ b/torchvision/prototype/models/detection/ssd.py @@ -10,20 +10,20 @@ DefaultBoxGenerator, SSD, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param -from ..vgg import VGG16Weights, vgg16 +from ..vgg import VGG16_Weights, vgg16 __all__ = [ - "SSD300VGG16Weights", + "SSD300_VGG16_Weights", "ssd300_vgg16", ] -class SSD300VGG16Weights(Weights): - Coco_RefV1 = WeightEntry( +class SSD300_VGG16_Weights(WeightsEnum): + Coco_V1 = Weights( url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", transforms=CocoEval, meta={ @@ -38,25 +38,25 @@ class SSD300VGG16Weights(Weights): def ssd300_vgg16( - weights: Optional[SSD300VGG16Weights] = None, + weights: Optional[SSD300_VGG16_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[VGG16Weights] = None, + weights_backbone: Optional[VGG16_Weights] = None, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, ) -> SSD: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", SSD300VGG16Weights.Coco_RefV1) - weights = SSD300VGG16Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", SSD300_VGG16_Weights.Coco_V1) + weights = SSD300_VGG16_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", VGG16Weights.ImageNet1K_Features + kwargs, "pretrained_backbone", "weights_backbone", VGG16_Weights.ImageNet1K_Features ) - weights_backbone = VGG16Weights.verify(weights_backbone) + weights_backbone = VGG16_Weights.verify(weights_backbone) if "size" in kwargs: warnings.warn("The size of the model is already fixed; ignoring the parameter.") diff --git a/torchvision/prototype/models/detection/ssdlite.py b/torchvision/prototype/models/detection/ssdlite.py index 0e2786ea203..309362f2f11 100644 --- a/torchvision/prototype/models/detection/ssdlite.py +++ b/torchvision/prototype/models/detection/ssdlite.py @@ -15,20 +15,20 @@ SSD, SSDLiteHead, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _COCO_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param -from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large +from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large __all__ = [ - "SSDlite320MobileNetV3LargeFPNWeights", + "SSDLite320_MobileNet_V3_Large_Weights", "ssdlite320_mobilenet_v3_large", ] -class SSDlite320MobileNetV3LargeFPNWeights(Weights): - Coco_RefV1 = WeightEntry( +class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): + Coco_V1 = Weights( url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth", transforms=CocoEval, meta={ @@ -43,10 +43,10 @@ class SSDlite320MobileNetV3LargeFPNWeights(Weights): def ssdlite320_mobilenet_v3_large( - weights: Optional[SSDlite320MobileNetV3LargeFPNWeights] = None, + weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNetV3LargeWeights] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, trainable_backbone_layers: Optional[int] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, **kwargs: Any, @@ -54,15 +54,15 @@ def ssdlite320_mobilenet_v3_large( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", SSDlite320MobileNetV3LargeFPNWeights.Coco_RefV1) - weights = SSDlite320MobileNetV3LargeFPNWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", SSDLite320_MobileNet_V3_Large_Weights.Coco_V1) + weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1 ) - weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) if "size" in kwargs: warnings.warn("The size of the model is already fixed; ignoring the parameter.") diff --git a/torchvision/prototype/models/efficientnet.py b/torchvision/prototype/models/efficientnet.py index 66dbd8ef5ea..74ca6ccc71d 100644 --- a/torchvision/prototype/models/efficientnet.py +++ b/torchvision/prototype/models/efficientnet.py @@ -6,21 +6,21 @@ from torchvision.transforms.functional import InterpolationMode from ...models.efficientnet import EfficientNet, MBConvConfig -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param __all__ = [ "EfficientNet", - "EfficientNetB0Weights", - "EfficientNetB1Weights", - "EfficientNetB2Weights", - "EfficientNetB3Weights", - "EfficientNetB4Weights", - "EfficientNetB5Weights", - "EfficientNetB6Weights", - "EfficientNetB7Weights", + "EfficientNet_B0_Weights", + "EfficientNet_B1_Weights", + "EfficientNet_B2_Weights", + "EfficientNet_B3_Weights", + "EfficientNet_B4_Weights", + "EfficientNet_B5_Weights", + "EfficientNet_B6_Weights", + "EfficientNet_B7_Weights", "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", @@ -36,7 +36,7 @@ def _efficientnet( width_mult: float, depth_mult: float, dropout: float, - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> EfficientNet: @@ -69,8 +69,8 @@ def _efficientnet( } -class EfficientNetB0Weights(Weights): - ImageNet1K_TimmV1 = WeightEntry( +class EfficientNet_B0_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC), meta={ @@ -83,8 +83,8 @@ class EfficientNetB0Weights(Weights): ) -class EfficientNetB1Weights(Weights): - ImageNet1K_TimmV1 = WeightEntry( +class EfficientNet_B1_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC), meta={ @@ -97,8 +97,8 @@ class EfficientNetB1Weights(Weights): ) -class EfficientNetB2Weights(Weights): - ImageNet1K_TimmV1 = WeightEntry( +class EfficientNet_B2_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC), meta={ @@ -111,8 +111,8 @@ class EfficientNetB2Weights(Weights): ) -class EfficientNetB3Weights(Weights): - ImageNet1K_TimmV1 = WeightEntry( +class EfficientNet_B3_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC), meta={ @@ -125,8 +125,8 @@ class EfficientNetB3Weights(Weights): ) -class EfficientNetB4Weights(Weights): - ImageNet1K_TimmV1 = WeightEntry( +class EfficientNet_B4_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC), meta={ @@ -139,8 +139,8 @@ class EfficientNetB4Weights(Weights): ) -class EfficientNetB5Weights(Weights): - ImageNet1K_TFV1 = WeightEntry( +class EfficientNet_B5_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC), meta={ @@ -153,8 +153,8 @@ class EfficientNetB5Weights(Weights): ) -class EfficientNetB6Weights(Weights): - ImageNet1K_TFV1 = WeightEntry( +class EfficientNet_B6_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC), meta={ @@ -167,8 +167,8 @@ class EfficientNetB6Weights(Weights): ) -class EfficientNetB7Weights(Weights): - ImageNet1K_TFV1 = WeightEntry( +class EfficientNet_B7_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC), meta={ @@ -182,73 +182,73 @@ class EfficientNetB7Weights(Weights): def efficientnet_b0( - weights: Optional[EfficientNetB0Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB0Weights.ImageNet1K_TimmV1) - weights = EfficientNetB0Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B0_Weights.ImageNet1K_V1) + weights = EfficientNet_B0_Weights.verify(weights) return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs) def efficientnet_b1( - weights: Optional[EfficientNetB1Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB1Weights.ImageNet1K_TimmV1) - weights = EfficientNetB1Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B1_Weights.ImageNet1K_V1) + weights = EfficientNet_B1_Weights.verify(weights) return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs) def efficientnet_b2( - weights: Optional[EfficientNetB2Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB2Weights.ImageNet1K_TimmV1) - weights = EfficientNetB2Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B2_Weights.ImageNet1K_V1) + weights = EfficientNet_B2_Weights.verify(weights) return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs) def efficientnet_b3( - weights: Optional[EfficientNetB3Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB3Weights.ImageNet1K_TimmV1) - weights = EfficientNetB3Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B3_Weights.ImageNet1K_V1) + weights = EfficientNet_B3_Weights.verify(weights) return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs) def efficientnet_b4( - weights: Optional[EfficientNetB4Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB4Weights.ImageNet1K_TimmV1) - weights = EfficientNetB4Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B4_Weights.ImageNet1K_V1) + weights = EfficientNet_B4_Weights.verify(weights) return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs) def efficientnet_b5( - weights: Optional[EfficientNetB5Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB5Weights.ImageNet1K_TFV1) - weights = EfficientNetB5Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B5_Weights.ImageNet1K_V1) + weights = EfficientNet_B5_Weights.verify(weights) return _efficientnet( width_mult=1.6, @@ -262,13 +262,13 @@ def efficientnet_b5( def efficientnet_b6( - weights: Optional[EfficientNetB6Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB6Weights.ImageNet1K_TFV1) - weights = EfficientNetB6Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B6_Weights.ImageNet1K_V1) + weights = EfficientNet_B6_Weights.verify(weights) return _efficientnet( width_mult=1.8, @@ -282,13 +282,13 @@ def efficientnet_b6( def efficientnet_b7( - weights: Optional[EfficientNetB7Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any ) -> EfficientNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNetB7Weights.ImageNet1K_TFV1) - weights = EfficientNetB7Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", EfficientNet_B7_Weights.ImageNet1K_V1) + weights = EfficientNet_B7_Weights.verify(weights) return _efficientnet( width_mult=2.0, diff --git a/torchvision/prototype/models/googlenet.py b/torchvision/prototype/models/googlenet.py index 8a65aa08d35..352c49d1a2e 100644 --- a/torchvision/prototype/models/googlenet.py +++ b/torchvision/prototype/models/googlenet.py @@ -6,16 +6,16 @@ from torchvision.transforms.functional import InterpolationMode from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNetWeights", "googlenet"] +__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"] -class GoogLeNetWeights(Weights): - ImageNet1K_TFV1 = WeightEntry( +class GoogLeNet_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/googlenet-1378be20.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -30,12 +30,12 @@ class GoogLeNetWeights(Weights): ) -def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: +def googlenet(weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", GoogLeNetWeights.ImageNet1K_TFV1) - weights = GoogLeNetWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", GoogLeNet_Weights.ImageNet1K_V1) + weights = GoogLeNet_Weights.verify(weights) original_aux_logits = kwargs.get("aux_logits", False) if weights is not None: diff --git a/torchvision/prototype/models/inception.py b/torchvision/prototype/models/inception.py index fe5cde184e0..9837b1fc4a6 100644 --- a/torchvision/prototype/models/inception.py +++ b/torchvision/prototype/models/inception.py @@ -5,16 +5,16 @@ from torchvision.transforms.functional import InterpolationMode from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "InceptionV3Weights", "inception_v3"] +__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"] -class InceptionV3Weights(Weights): - ImageNet1K_TFV1 = WeightEntry( +class Inception_V3_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", transforms=partial(ImageNetEval, crop_size=299, resize_size=342), meta={ @@ -29,12 +29,12 @@ class InceptionV3Weights(Weights): ) -def inception_v3(weights: Optional[InceptionV3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: +def inception_v3(weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", InceptionV3Weights.ImageNet1K_TFV1) - weights = InceptionV3Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", Inception_V3_Weights.ImageNet1K_V1) + weights = Inception_V3_Weights.verify(weights) original_aux_logits = kwargs.get("aux_logits", True) if weights is not None: diff --git a/torchvision/prototype/models/mnasnet.py b/torchvision/prototype/models/mnasnet.py index c5cef01fb98..73aaea0beca 100644 --- a/torchvision/prototype/models/mnasnet.py +++ b/torchvision/prototype/models/mnasnet.py @@ -5,17 +5,17 @@ from torchvision.transforms.functional import InterpolationMode from ...models.mnasnet import MNASNet -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param __all__ = [ "MNASNet", - "MNASNet0_5Weights", - "MNASNet0_75Weights", - "MNASNet1_0Weights", - "MNASNet1_3Weights", + "MNASNet0_5_Weights", + "MNASNet0_75_Weights", + "MNASNet1_0_Weights", + "MNASNet1_3_Weights", "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", @@ -31,8 +31,8 @@ } -class MNASNet0_5Weights(Weights): - ImageNet1K_Community = WeightEntry( +class MNASNet0_5_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -44,13 +44,13 @@ class MNASNet0_5Weights(Weights): ) -class MNASNet0_75Weights(Weights): +class MNASNet0_75_Weights(WeightsEnum): # If a default model is added here the corresponding changes need to be done in mnasnet0_75 pass -class MNASNet1_0Weights(Weights): - ImageNet1K_Community = WeightEntry( +class MNASNet1_0_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -62,12 +62,12 @@ class MNASNet1_0Weights(Weights): ) -class MNASNet1_3Weights(Weights): +class MNASNet1_3_Weights(WeightsEnum): # If a default model is added here the corresponding changes need to be done in mnasnet1_3 pass -def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs: Any) -> MNASNet: +def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet: if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) @@ -79,41 +79,41 @@ def _mnasnet(alpha: float, weights: Optional[Weights], progress: bool, **kwargs: return model -def mnasnet0_5(weights: Optional[MNASNet0_5Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: +def mnasnet0_5(weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet0_5Weights.ImageNet1K_Community) - weights = MNASNet0_5Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet0_5_Weights.ImageNet1K_V1) + weights = MNASNet0_5_Weights.verify(weights) return _mnasnet(0.5, weights, progress, **kwargs) -def mnasnet0_75(weights: Optional[MNASNet0_75Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: +def mnasnet0_75(weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param(kwargs, "pretrained", "weights", None) - weights = MNASNet0_75Weights.verify(weights) + weights = MNASNet0_75_Weights.verify(weights) return _mnasnet(0.75, weights, progress, **kwargs) -def mnasnet1_0(weights: Optional[MNASNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: +def mnasnet1_0(weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet1_0Weights.ImageNet1K_Community) - weights = MNASNet1_0Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", MNASNet1_0_Weights.ImageNet1K_V1) + weights = MNASNet1_0_Weights.verify(weights) return _mnasnet(1.0, weights, progress, **kwargs) -def mnasnet1_3(weights: Optional[MNASNet1_3Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: +def mnasnet1_3(weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param(kwargs, "pretrained", "weights", None) - weights = MNASNet1_3Weights.verify(weights) + weights = MNASNet1_3_Weights.verify(weights) return _mnasnet(1.3, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/mobilenetv2.py b/torchvision/prototype/models/mobilenetv2.py index 1e33384ad2d..0c0f80d081a 100644 --- a/torchvision/prototype/models/mobilenetv2.py +++ b/torchvision/prototype/models/mobilenetv2.py @@ -5,16 +5,16 @@ from torchvision.transforms.functional import InterpolationMode from ...models.mobilenetv2 import MobileNetV2 -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -__all__ = ["MobileNetV2", "MobileNetV2Weights", "mobilenet_v2"] +__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"] -class MobileNetV2Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class MobileNet_V2_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -29,12 +29,12 @@ class MobileNetV2Weights(Weights): ) -def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2: +def mobilenet_v2(weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV2Weights.ImageNet1K_RefV1) - weights = MobileNetV2Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V2_Weights.ImageNet1K_V1) + weights = MobileNet_V2_Weights.verify(weights) if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) diff --git a/torchvision/prototype/models/mobilenetv3.py b/torchvision/prototype/models/mobilenetv3.py index 1a6a810856b..e014fb5acb2 100644 --- a/torchvision/prototype/models/mobilenetv3.py +++ b/torchvision/prototype/models/mobilenetv3.py @@ -5,15 +5,15 @@ from torchvision.transforms.functional import InterpolationMode from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param __all__ = [ "MobileNetV3", - "MobileNetV3LargeWeights", - "MobileNetV3SmallWeights", + "MobileNet_V3_Large_Weights", + "MobileNet_V3_Small_Weights", "mobilenet_v3_large", "mobilenet_v3_small", ] @@ -22,7 +22,7 @@ def _mobilenet_v3( inverted_residual_setting: List[InvertedResidualConfig], last_channel: int, - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> MobileNetV3: @@ -44,8 +44,8 @@ def _mobilenet_v3( } -class MobileNetV3LargeWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class MobileNet_V3_Large_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -56,7 +56,7 @@ class MobileNetV3LargeWeights(Weights): }, default=False, ) - ImageNet1K_RefV2 = WeightEntry( + ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -69,8 +69,8 @@ class MobileNetV3LargeWeights(Weights): ) -class MobileNetV3SmallWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class MobileNet_V3_Small_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -84,26 +84,26 @@ class MobileNetV3SmallWeights(Weights): def mobilenet_v3_large( - weights: Optional[MobileNetV3LargeWeights] = None, progress: bool = True, **kwargs: Any + weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any ) -> MobileNetV3: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV3LargeWeights.ImageNet1K_RefV1) - weights = MobileNetV3LargeWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V3_Large_Weights.ImageNet1K_V1) + weights = MobileNet_V3_Large_Weights.verify(weights) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) def mobilenet_v3_small( - weights: Optional[MobileNetV3SmallWeights] = None, progress: bool = True, **kwargs: Any + weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any ) -> MobileNetV3: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNetV3SmallWeights.ImageNet1K_RefV1) - weights = MobileNetV3SmallWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", MobileNet_V3_Small_Weights.ImageNet1K_V1) + weights = MobileNet_V3_Small_Weights.verify(weights) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs) return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/quantization/googlenet.py b/torchvision/prototype/models/quantization/googlenet.py index 4769b115ba5..3d26fd7d607 100644 --- a/torchvision/prototype/models/quantization/googlenet.py +++ b/torchvision/prototype/models/quantization/googlenet.py @@ -10,21 +10,21 @@ _replace_relu, quantize_model, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -from ..googlenet import GoogLeNetWeights +from ..googlenet import GoogLeNet_Weights __all__ = [ "QuantizableGoogLeNet", - "QuantizedGoogLeNetWeights", + "GoogLeNet_QuantizedWeights", "googlenet", ] -class QuantizedGoogLeNetWeights(Weights): - ImageNet1K_FBGEMM_TFV1 = WeightEntry( +class GoogLeNet_QuantizedWeights(WeightsEnum): + ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -34,7 +34,7 @@ class QuantizedGoogLeNetWeights(Weights): "backend": "fbgemm", "quantization": "ptq", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", - "unquantized": GoogLeNetWeights.ImageNet1K_TFV1, + "unquantized": GoogLeNet_Weights.ImageNet1K_V1, "acc@1": 69.826, "acc@5": 89.404, }, @@ -43,7 +43,7 @@ class QuantizedGoogLeNetWeights(Weights): def googlenet( - weights: Optional[Union[QuantizedGoogLeNetWeights, GoogLeNetWeights]] = None, + weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -51,14 +51,12 @@ def googlenet( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - default_value = ( - QuantizedGoogLeNetWeights.ImageNet1K_FBGEMM_TFV1 if quantize else GoogLeNetWeights.ImageNet1K_TFV1 - ) + default_value = GoogLeNet_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else GoogLeNet_Weights.ImageNet1K_V1 weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: - weights = QuantizedGoogLeNetWeights.verify(weights) + weights = GoogLeNet_QuantizedWeights.verify(weights) else: - weights = GoogLeNetWeights.verify(weights) + weights = GoogLeNet_Weights.verify(weights) original_aux_logits = kwargs.get("aux_logits", False) if weights is not None: diff --git a/torchvision/prototype/models/quantization/inception.py b/torchvision/prototype/models/quantization/inception.py index 4949d0f4b2d..ff779076df6 100644 --- a/torchvision/prototype/models/quantization/inception.py +++ b/torchvision/prototype/models/quantization/inception.py @@ -9,21 +9,21 @@ _replace_relu, quantize_model, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -from ..inception import InceptionV3Weights +from ..inception import Inception_V3_Weights __all__ = [ "QuantizableInception3", - "QuantizedInceptionV3Weights", + "Inception_V3_QuantizedWeights", "inception_v3", ] -class QuantizedInceptionV3Weights(Weights): - ImageNet1K_FBGEMM_TFV1 = WeightEntry( +class Inception_V3_QuantizedWeights(WeightsEnum): + ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth", transforms=partial(ImageNetEval, crop_size=299, resize_size=342), meta={ @@ -33,7 +33,7 @@ class QuantizedInceptionV3Weights(Weights): "backend": "fbgemm", "quantization": "ptq", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models", - "unquantized": InceptionV3Weights.ImageNet1K_TFV1, + "unquantized": Inception_V3_Weights.ImageNet1K_V1, "acc@1": 77.176, "acc@5": 93.354, }, @@ -42,7 +42,7 @@ class QuantizedInceptionV3Weights(Weights): def inception_v3( - weights: Optional[Union[QuantizedInceptionV3Weights, InceptionV3Weights]] = None, + weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -51,13 +51,13 @@ def inception_v3( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: default_value = ( - QuantizedInceptionV3Weights.ImageNet1K_FBGEMM_TFV1 if quantize else InceptionV3Weights.ImageNet1K_TFV1 + Inception_V3_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else Inception_V3_Weights.ImageNet1K_V1 ) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: - weights = QuantizedInceptionV3Weights.verify(weights) + weights = Inception_V3_QuantizedWeights.verify(weights) else: - weights = InceptionV3Weights.verify(weights) + weights = Inception_V3_Weights.verify(weights) original_aux_logits = kwargs.get("aux_logits", False) if weights is not None: diff --git a/torchvision/prototype/models/quantization/mobilenetv2.py b/torchvision/prototype/models/quantization/mobilenetv2.py index 8c5c7fbf5b2..c5afd731fad 100644 --- a/torchvision/prototype/models/quantization/mobilenetv2.py +++ b/torchvision/prototype/models/quantization/mobilenetv2.py @@ -10,21 +10,21 @@ _replace_relu, quantize_model, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -from ..mobilenetv2 import MobileNetV2Weights +from ..mobilenetv2 import MobileNet_V2_Weights __all__ = [ "QuantizableMobileNetV2", - "QuantizedMobileNetV2Weights", + "MobileNet_V2_QuantizedWeights", "mobilenet_v2", ] -class QuantizedMobileNetV2Weights(Weights): - ImageNet1K_QNNPACK_RefV1 = WeightEntry( +class MobileNet_V2_QuantizedWeights(WeightsEnum): + ImageNet1K_QNNPACK_V1 = Weights( url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -34,7 +34,7 @@ class QuantizedMobileNetV2Weights(Weights): "backend": "qnnpack", "quantization": "qat", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2", - "unquantized": MobileNetV2Weights.ImageNet1K_RefV1, + "unquantized": MobileNet_V2_Weights.ImageNet1K_V1, "acc@1": 71.658, "acc@5": 90.150, }, @@ -43,7 +43,7 @@ class QuantizedMobileNetV2Weights(Weights): def mobilenet_v2( - weights: Optional[Union[QuantizedMobileNetV2Weights, MobileNetV2Weights]] = None, + weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -52,13 +52,13 @@ def mobilenet_v2( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: default_value = ( - QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1 if quantize else MobileNetV2Weights.ImageNet1K_RefV1 + MobileNet_V2_QuantizedWeights.ImageNet1K_QNNPACK_V1 if quantize else MobileNet_V2_Weights.ImageNet1K_V1 ) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: - weights = QuantizedMobileNetV2Weights.verify(weights) + weights = MobileNet_V2_QuantizedWeights.verify(weights) else: - weights = MobileNetV2Weights.verify(weights) + weights = MobileNet_V2_Weights.verify(weights) if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) diff --git a/torchvision/prototype/models/quantization/mobilenetv3.py b/torchvision/prototype/models/quantization/mobilenetv3.py index dd293fd4080..a29e3f44697 100644 --- a/torchvision/prototype/models/quantization/mobilenetv3.py +++ b/torchvision/prototype/models/quantization/mobilenetv3.py @@ -11,15 +11,15 @@ QuantizableMobileNetV3, _replace_relu, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -from ..mobilenetv3 import MobileNetV3LargeWeights, _mobilenet_v3_conf +from ..mobilenetv3 import MobileNet_V3_Large_Weights, _mobilenet_v3_conf __all__ = [ "QuantizableMobileNetV3", - "QuantizedMobileNetV3LargeWeights", + "MobileNet_V3_Large_QuantizedWeights", "mobilenet_v3_large", ] @@ -27,7 +27,7 @@ def _mobilenet_v3_model( inverted_residual_setting: List[InvertedResidualConfig], last_channel: int, - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, quantize: bool, **kwargs: Any, @@ -56,8 +56,8 @@ def _mobilenet_v3_model( return model -class QuantizedMobileNetV3LargeWeights(Weights): - ImageNet1K_QNNPACK_RefV1 = WeightEntry( +class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): + ImageNet1K_QNNPACK_V1 = Weights( url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -67,7 +67,7 @@ class QuantizedMobileNetV3LargeWeights(Weights): "backend": "qnnpack", "quantization": "qat", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3", - "unquantized": MobileNetV3LargeWeights.ImageNet1K_RefV1, + "unquantized": MobileNet_V3_Large_Weights.ImageNet1K_V1, "acc@1": 73.004, "acc@5": 90.858, }, @@ -76,7 +76,7 @@ class QuantizedMobileNetV3LargeWeights(Weights): def mobilenet_v3_large( - weights: Optional[Union[QuantizedMobileNetV3LargeWeights, MobileNetV3LargeWeights]] = None, + weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -85,15 +85,15 @@ def mobilenet_v3_large( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: default_value = ( - QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_RefV1 + MobileNet_V3_Large_QuantizedWeights.ImageNet1K_QNNPACK_V1 if quantize - else MobileNetV3LargeWeights.ImageNet1K_RefV1 + else MobileNet_V3_Large_Weights.ImageNet1K_V1 ) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: - weights = QuantizedMobileNetV3LargeWeights.verify(weights) + weights = MobileNet_V3_Large_QuantizedWeights.verify(weights) else: - weights = MobileNetV3LargeWeights.verify(weights) + weights = MobileNet_V3_Large_Weights.verify(weights) inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs) diff --git a/torchvision/prototype/models/quantization/resnet.py b/torchvision/prototype/models/quantization/resnet.py index 744b57d706a..0de4eb5557b 100644 --- a/torchvision/prototype/models/quantization/resnet.py +++ b/torchvision/prototype/models/quantization/resnet.py @@ -11,17 +11,17 @@ _replace_relu, quantize_model, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -from ..resnet import ResNet18Weights, ResNet50Weights, ResNeXt101_32x8dWeights +from ..resnet import ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights __all__ = [ "QuantizableResNet", - "QuantizedResNet18Weights", - "QuantizedResNet50Weights", - "QuantizedResNeXt101_32x8dWeights", + "ResNet18_QuantizedWeights", + "ResNet50_QuantizedWeights", + "ResNeXt101_32X8D_QuantizedWeights", "resnet18", "resnet50", "resnext101_32x8d", @@ -31,7 +31,7 @@ def _resnet( block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]], layers: List[int], - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, quantize: bool, **kwargs: Any, @@ -63,13 +63,13 @@ def _resnet( } -class QuantizedResNet18Weights(Weights): - ImageNet1K_FBGEMM_RefV1 = WeightEntry( +class ResNet18_QuantizedWeights(WeightsEnum): + ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ **_COMMON_META, - "unquantized": ResNet18Weights.ImageNet1K_RefV1, + "unquantized": ResNet18_Weights.ImageNet1K_V1, "acc@1": 69.494, "acc@5": 88.882, }, @@ -77,24 +77,24 @@ class QuantizedResNet18Weights(Weights): ) -class QuantizedResNet50Weights(Weights): - ImageNet1K_FBGEMM_RefV1 = WeightEntry( +class ResNet50_QuantizedWeights(WeightsEnum): + ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ **_COMMON_META, - "unquantized": ResNet50Weights.ImageNet1K_RefV1, + "unquantized": ResNet50_Weights.ImageNet1K_V1, "acc@1": 75.920, "acc@5": 92.814, }, default=False, ) - ImageNet1K_FBGEMM_RefV2 = WeightEntry( + ImageNet1K_FBGEMM_V2 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, - "unquantized": ResNet50Weights.ImageNet1K_RefV2, + "unquantized": ResNet50_Weights.ImageNet1K_V2, "acc@1": 80.282, "acc@5": 94.976, }, @@ -102,24 +102,24 @@ class QuantizedResNet50Weights(Weights): ) -class QuantizedResNeXt101_32x8dWeights(Weights): - ImageNet1K_FBGEMM_RefV1 = WeightEntry( +class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): + ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ **_COMMON_META, - "unquantized": ResNeXt101_32x8dWeights.ImageNet1K_RefV1, + "unquantized": ResNeXt101_32X8D_Weights.ImageNet1K_V1, "acc@1": 78.986, "acc@5": 94.480, }, default=False, ) - ImageNet1K_FBGEMM_RefV2 = WeightEntry( + ImageNet1K_FBGEMM_V2 = Weights( url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ **_COMMON_META, - "unquantized": ResNeXt101_32x8dWeights.ImageNet1K_RefV2, + "unquantized": ResNeXt101_32X8D_Weights.ImageNet1K_V2, "acc@1": 82.574, "acc@5": 96.132, }, @@ -128,7 +128,7 @@ class QuantizedResNeXt101_32x8dWeights(Weights): def resnet18( - weights: Optional[Union[QuantizedResNet18Weights, ResNet18Weights]] = None, + weights: Optional[Union[ResNet18_QuantizedWeights, ResNet18_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -136,20 +136,18 @@ def resnet18( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - default_value = ( - QuantizedResNet18Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet18Weights.ImageNet1K_RefV1 - ) + default_value = ResNet18_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else ResNet18_Weights.ImageNet1K_V1 weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: - weights = QuantizedResNet18Weights.verify(weights) + weights = ResNet18_QuantizedWeights.verify(weights) else: - weights = ResNet18Weights.verify(weights) + weights = ResNet18_Weights.verify(weights) return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs) def resnet50( - weights: Optional[Union[QuantizedResNet50Weights, ResNet50Weights]] = None, + weights: Optional[Union[ResNet50_QuantizedWeights, ResNet50_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -157,20 +155,18 @@ def resnet50( if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - default_value = ( - QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet50Weights.ImageNet1K_RefV1 - ) + default_value = ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize else ResNet50_Weights.ImageNet1K_V1 weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: - weights = QuantizedResNet50Weights.verify(weights) + weights = ResNet50_QuantizedWeights.verify(weights) else: - weights = ResNet50Weights.verify(weights) + weights = ResNet50_Weights.verify(weights) return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs) def resnext101_32x8d( - weights: Optional[Union[QuantizedResNeXt101_32x8dWeights, ResNeXt101_32x8dWeights]] = None, + weights: Optional[Union[ResNeXt101_32X8D_QuantizedWeights, ResNeXt101_32X8D_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -179,15 +175,15 @@ def resnext101_32x8d( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: default_value = ( - QuantizedResNeXt101_32x8dWeights.ImageNet1K_FBGEMM_RefV1 + ResNeXt101_32X8D_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize - else ResNeXt101_32x8dWeights.ImageNet1K_RefV1 + else ResNeXt101_32X8D_Weights.ImageNet1K_V1 ) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: - weights = QuantizedResNeXt101_32x8dWeights.verify(weights) + weights = ResNeXt101_32X8D_QuantizedWeights.verify(weights) else: - weights = ResNeXt101_32x8dWeights.verify(weights) + weights = ResNeXt101_32X8D_Weights.verify(weights) _ovewrite_named_param(kwargs, "groups", 32) _ovewrite_named_param(kwargs, "width_per_group", 8) diff --git a/torchvision/prototype/models/quantization/shufflenetv2.py b/torchvision/prototype/models/quantization/shufflenetv2.py index d9aade357b0..6677983a1d9 100644 --- a/torchvision/prototype/models/quantization/shufflenetv2.py +++ b/torchvision/prototype/models/quantization/shufflenetv2.py @@ -9,16 +9,16 @@ _replace_relu, quantize_model, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _IMAGENET_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -from ..shufflenetv2 import ShuffleNetV2_x0_5Weights, ShuffleNetV2_x1_0Weights +from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights __all__ = [ "QuantizableShuffleNetV2", - "QuantizedShuffleNetV2_x0_5Weights", - "QuantizedShuffleNetV2_x1_0Weights", + "ShuffleNet_V2_X0_5_QuantizedWeights", + "ShuffleNet_V2_X1_0_QuantizedWeights", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", ] @@ -27,7 +27,7 @@ def _shufflenetv2( stages_repeats: List[int], stages_out_channels: List[int], - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, quantize: bool, **kwargs: Any, @@ -59,13 +59,13 @@ def _shufflenetv2( } -class QuantizedShuffleNetV2_x0_5Weights(Weights): - ImageNet1K_FBGEMM_Community = WeightEntry( +class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): + ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ **_COMMON_META, - "unquantized": ShuffleNetV2_x0_5Weights.ImageNet1K_Community, + "unquantized": ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1, "acc@1": 57.972, "acc@5": 79.780, }, @@ -73,13 +73,13 @@ class QuantizedShuffleNetV2_x0_5Weights(Weights): ) -class QuantizedShuffleNetV2_x1_0Weights(Weights): - ImageNet1K_FBGEMM_Community = WeightEntry( +class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): + ImageNet1K_FBGEMM_V1 = Weights( url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ **_COMMON_META, - "unquantized": ShuffleNetV2_x1_0Weights.ImageNet1K_Community, + "unquantized": ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1, "acc@1": 68.360, "acc@5": 87.582, }, @@ -88,7 +88,7 @@ class QuantizedShuffleNetV2_x1_0Weights(Weights): def shufflenet_v2_x0_5( - weights: Optional[Union[QuantizedShuffleNetV2_x0_5Weights, ShuffleNetV2_x0_5Weights]] = None, + weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -97,21 +97,21 @@ def shufflenet_v2_x0_5( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: default_value = ( - QuantizedShuffleNetV2_x0_5Weights.ImageNet1K_FBGEMM_Community + ShuffleNet_V2_X0_5_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize - else ShuffleNetV2_x0_5Weights.ImageNet1K_Community + else ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1 ) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: - weights = QuantizedShuffleNetV2_x0_5Weights.verify(weights) + weights = ShuffleNet_V2_X0_5_QuantizedWeights.verify(weights) else: - weights = ShuffleNetV2_x0_5Weights.verify(weights) + weights = ShuffleNet_V2_X0_5_Weights.verify(weights) return _shufflenetv2([4, 8, 4], [24, 48, 96, 192, 1024], weights, progress, quantize, **kwargs) def shufflenet_v2_x1_0( - weights: Optional[Union[QuantizedShuffleNetV2_x1_0Weights, ShuffleNetV2_x1_0Weights]] = None, + weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None, progress: bool = True, quantize: bool = False, **kwargs: Any, @@ -120,14 +120,14 @@ def shufflenet_v2_x1_0( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: default_value = ( - QuantizedShuffleNetV2_x1_0Weights.ImageNet1K_FBGEMM_Community + ShuffleNet_V2_X1_0_QuantizedWeights.ImageNet1K_FBGEMM_V1 if quantize - else ShuffleNetV2_x1_0Weights.ImageNet1K_Community + else ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1 ) weights = _deprecated_param(kwargs, "pretrained", "weights", default_value) # type: ignore[assignment] if quantize: - weights = QuantizedShuffleNetV2_x1_0Weights.verify(weights) + weights = ShuffleNet_V2_X1_0_QuantizedWeights.verify(weights) else: - weights = ShuffleNetV2_x1_0Weights.verify(weights) + weights = ShuffleNet_V2_X1_0_Weights.verify(weights) return _shufflenetv2([4, 8, 4], [24, 116, 232, 464, 1024], weights, progress, quantize, **kwargs) diff --git a/torchvision/prototype/models/regnet.py b/torchvision/prototype/models/regnet.py index b89d882cc0b..1e12ae7bbd2 100644 --- a/torchvision/prototype/models/regnet.py +++ b/torchvision/prototype/models/regnet.py @@ -6,27 +6,27 @@ from torchvision.transforms.functional import InterpolationMode from ...models.regnet import RegNet, BlockParams -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param __all__ = [ "RegNet", - "RegNet_y_400mfWeights", - "RegNet_y_800mfWeights", - "RegNet_y_1_6gfWeights", - "RegNet_y_3_2gfWeights", - "RegNet_y_8gfWeights", - "RegNet_y_16gfWeights", - "RegNet_y_32gfWeights", - "RegNet_x_400mfWeights", - "RegNet_x_800mfWeights", - "RegNet_x_1_6gfWeights", - "RegNet_x_3_2gfWeights", - "RegNet_x_8gfWeights", - "RegNet_x_16gfWeights", - "RegNet_x_32gfWeights", + "RegNet_Y_400MF_Weights", + "RegNet_Y_800MF_Weights", + "RegNet_Y_1_6GF_Weights", + "RegNet_Y_3_2GF_Weights", + "RegNet_Y_8GF_Weights", + "RegNet_Y_16GF_Weights", + "RegNet_Y_32GF_Weights", + "RegNet_X_400MF_Weights", + "RegNet_X_800MF_Weights", + "RegNet_X_1_6GF_Weights", + "RegNet_X_3_2GF_Weights", + "RegNet_X_8GF_Weights", + "RegNet_X_16GF_Weights", + "RegNet_X_32GF_Weights", "regnet_y_400mf", "regnet_y_800mf", "regnet_y_1_6gf", @@ -48,7 +48,7 @@ def _regnet( block_params: BlockParams, - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> RegNet: @@ -64,8 +64,8 @@ def _regnet( return model -class RegNet_y_400mfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_Y_400MF_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -78,8 +78,8 @@ class RegNet_y_400mfWeights(Weights): ) -class RegNet_y_800mfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_Y_800MF_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -92,8 +92,8 @@ class RegNet_y_800mfWeights(Weights): ) -class RegNet_y_1_6gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_Y_1_6GF_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -106,8 +106,8 @@ class RegNet_y_1_6gfWeights(Weights): ) -class RegNet_y_3_2gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_Y_3_2GF_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -120,8 +120,8 @@ class RegNet_y_3_2gfWeights(Weights): ) -class RegNet_y_8gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_Y_8GF_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -134,8 +134,8 @@ class RegNet_y_8gfWeights(Weights): ) -class RegNet_y_16gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_Y_16GF_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -148,8 +148,8 @@ class RegNet_y_16gfWeights(Weights): ) -class RegNet_y_32gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_Y_32GF_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -162,8 +162,8 @@ class RegNet_y_32gfWeights(Weights): ) -class RegNet_x_400mfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_X_400MF_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -176,8 +176,8 @@ class RegNet_x_400mfWeights(Weights): ) -class RegNet_x_800mfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_X_800MF_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -190,8 +190,8 @@ class RegNet_x_800mfWeights(Weights): ) -class RegNet_x_1_6gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_X_1_6GF_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -204,8 +204,8 @@ class RegNet_x_1_6gfWeights(Weights): ) -class RegNet_x_3_2gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_X_3_2GF_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -218,8 +218,8 @@ class RegNet_x_3_2gfWeights(Weights): ) -class RegNet_x_8gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_X_8GF_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -232,8 +232,8 @@ class RegNet_x_8gfWeights(Weights): ) -class RegNet_x_16gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_X_16GF_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -246,8 +246,8 @@ class RegNet_x_16gfWeights(Weights): ) -class RegNet_x_32gfWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class RegNet_X_32GF_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -260,34 +260,34 @@ class RegNet_x_32gfWeights(Weights): ) -def regnet_y_400mf(weights: Optional[RegNet_y_400mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_y_400mf(weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_400mfWeights.ImageNet1K_RefV1) - weights = RegNet_y_400mfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_400MF_Weights.ImageNet1K_V1) + weights = RegNet_Y_400MF_Weights.verify(weights) params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_y_800mf(weights: Optional[RegNet_y_800mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_y_800mf(weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_800mfWeights.ImageNet1K_RefV1) - weights = RegNet_y_800mfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_800MF_Weights.ImageNet1K_V1) + weights = RegNet_Y_800MF_Weights.verify(weights) params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_y_1_6gf(weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_1_6gfWeights.ImageNet1K_RefV1) - weights = RegNet_y_1_6gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_1_6GF_Weights.ImageNet1K_V1) + weights = RegNet_Y_1_6GF_Weights.verify(weights) params = BlockParams.from_init_params( depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs @@ -295,12 +295,12 @@ def regnet_y_1_6gf(weights: Optional[RegNet_y_1_6gfWeights] = None, progress: bo return _regnet(params, weights, progress, **kwargs) -def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_y_3_2gf(weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_3_2gfWeights.ImageNet1K_RefV1) - weights = RegNet_y_3_2gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_3_2GF_Weights.ImageNet1K_V1) + weights = RegNet_Y_3_2GF_Weights.verify(weights) params = BlockParams.from_init_params( depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs @@ -308,12 +308,12 @@ def regnet_y_3_2gf(weights: Optional[RegNet_y_3_2gfWeights] = None, progress: bo return _regnet(params, weights, progress, **kwargs) -def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_y_8gf(weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_8gfWeights.ImageNet1K_RefV1) - weights = RegNet_y_8gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_8GF_Weights.ImageNet1K_V1) + weights = RegNet_Y_8GF_Weights.verify(weights) params = BlockParams.from_init_params( depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs @@ -321,12 +321,12 @@ def regnet_y_8gf(weights: Optional[RegNet_y_8gfWeights] = None, progress: bool = return _regnet(params, weights, progress, **kwargs) -def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_y_16gf(weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_16gfWeights.ImageNet1K_RefV1) - weights = RegNet_y_16gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_16GF_Weights.ImageNet1K_V1) + weights = RegNet_Y_16GF_Weights.verify(weights) params = BlockParams.from_init_params( depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs @@ -334,12 +334,12 @@ def regnet_y_16gf(weights: Optional[RegNet_y_16gfWeights] = None, progress: bool return _regnet(params, weights, progress, **kwargs) -def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_y_32gf(weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_y_32gfWeights.ImageNet1K_RefV1) - weights = RegNet_y_32gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_Y_32GF_Weights.ImageNet1K_V1) + weights = RegNet_Y_32GF_Weights.verify(weights) params = BlockParams.from_init_params( depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs @@ -347,78 +347,78 @@ def regnet_y_32gf(weights: Optional[RegNet_y_32gfWeights] = None, progress: bool return _regnet(params, weights, progress, **kwargs) -def regnet_x_400mf(weights: Optional[RegNet_x_400mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_x_400mf(weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_400mfWeights.ImageNet1K_RefV1) - weights = RegNet_x_400mfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_400MF_Weights.ImageNet1K_V1) + weights = RegNet_X_400MF_Weights.verify(weights) params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_x_800mf(weights: Optional[RegNet_x_800mfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_x_800mf(weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_800mfWeights.ImageNet1K_RefV1) - weights = RegNet_x_800mfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_800MF_Weights.ImageNet1K_V1) + weights = RegNet_X_800MF_Weights.verify(weights) params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_x_1_6gf(weights: Optional[RegNet_x_1_6gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_x_1_6gf(weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_1_6gfWeights.ImageNet1K_RefV1) - weights = RegNet_x_1_6gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_1_6GF_Weights.ImageNet1K_V1) + weights = RegNet_X_1_6GF_Weights.verify(weights) params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_x_3_2gf(weights: Optional[RegNet_x_3_2gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_x_3_2gf(weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_3_2gfWeights.ImageNet1K_RefV1) - weights = RegNet_x_3_2gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_3_2GF_Weights.ImageNet1K_V1) + weights = RegNet_X_3_2GF_Weights.verify(weights) params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_x_8gf(weights: Optional[RegNet_x_8gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_x_8gf(weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_8gfWeights.ImageNet1K_RefV1) - weights = RegNet_x_8gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_8GF_Weights.ImageNet1K_V1) + weights = RegNet_X_8GF_Weights.verify(weights) params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_x_16gf(weights: Optional[RegNet_x_16gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_x_16gf(weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_16gfWeights.ImageNet1K_RefV1) - weights = RegNet_x_16gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_16GF_Weights.ImageNet1K_V1) + weights = RegNet_X_16GF_Weights.verify(weights) params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs) return _regnet(params, weights, progress, **kwargs) -def regnet_x_32gf(weights: Optional[RegNet_x_32gfWeights] = None, progress: bool = True, **kwargs: Any) -> RegNet: +def regnet_x_32gf(weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_x_32gfWeights.ImageNet1K_RefV1) - weights = RegNet_x_32gfWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", RegNet_X_32GF_Weights.ImageNet1K_V1) + weights = RegNet_X_32GF_Weights.verify(weights) params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) return _regnet(params, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/resnet.py b/torchvision/prototype/models/resnet.py index 0ff4436bf63..e213864acbe 100644 --- a/torchvision/prototype/models/resnet.py +++ b/torchvision/prototype/models/resnet.py @@ -5,22 +5,22 @@ from torchvision.transforms.functional import InterpolationMode from ...models.resnet import BasicBlock, Bottleneck, ResNet -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param __all__ = [ "ResNet", - "ResNet18Weights", - "ResNet34Weights", - "ResNet50Weights", - "ResNet101Weights", - "ResNet152Weights", - "ResNeXt50_32x4dWeights", - "ResNeXt101_32x8dWeights", - "WideResNet50_2Weights", - "WideResNet101_2Weights", + "ResNet18_Weights", + "ResNet34_Weights", + "ResNet50_Weights", + "ResNet101_Weights", + "ResNet152_Weights", + "ResNeXt50_32X4D_Weights", + "ResNeXt101_32X8D_Weights", + "Wide_ResNet50_2_Weights", + "Wide_ResNet101_2_Weights", "resnet18", "resnet34", "resnet50", @@ -36,7 +36,7 @@ def _resnet( block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> ResNet: @@ -54,8 +54,8 @@ def _resnet( _COMMON_META = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR} -class ResNet18Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class ResNet18_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnet18-f37072fd.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -68,8 +68,8 @@ class ResNet18Weights(Weights): ) -class ResNet34Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class ResNet34_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnet34-b627a593.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -82,8 +82,8 @@ class ResNet34Weights(Weights): ) -class ResNet50Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class ResNet50_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnet50-0676ba61.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -94,7 +94,7 @@ class ResNet50Weights(Weights): }, default=False, ) - ImageNet1K_RefV2 = WeightEntry( + ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnet50-f46c3f97.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -107,8 +107,8 @@ class ResNet50Weights(Weights): ) -class ResNet101Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class ResNet101_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnet101-63fe2227.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -119,7 +119,7 @@ class ResNet101Weights(Weights): }, default=False, ) - ImageNet1K_RefV2 = WeightEntry( + ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -132,8 +132,8 @@ class ResNet101Weights(Weights): ) -class ResNet152Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class ResNet152_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnet152-394f9c45.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -144,7 +144,7 @@ class ResNet152Weights(Weights): }, default=False, ) - ImageNet1K_RefV2 = WeightEntry( + ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnet152-f82ba261.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -157,8 +157,8 @@ class ResNet152Weights(Weights): ) -class ResNeXt50_32x4dWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class ResNeXt50_32X4D_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -169,7 +169,7 @@ class ResNeXt50_32x4dWeights(Weights): }, default=False, ) - ImageNet1K_RefV2 = WeightEntry( + ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -182,8 +182,8 @@ class ResNeXt50_32x4dWeights(Weights): ) -class ResNeXt101_32x8dWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class ResNeXt101_32X8D_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -194,7 +194,7 @@ class ResNeXt101_32x8dWeights(Weights): }, default=False, ) - ImageNet1K_RefV2 = WeightEntry( + ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -207,8 +207,8 @@ class ResNeXt101_32x8dWeights(Weights): ) -class WideResNet50_2Weights(Weights): - ImageNet1K_Community = WeightEntry( +class Wide_ResNet50_2_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -219,7 +219,7 @@ class WideResNet50_2Weights(Weights): }, default=False, ) - ImageNet1K_RefV2 = WeightEntry( + ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -232,8 +232,8 @@ class WideResNet50_2Weights(Weights): ) -class WideResNet101_2Weights(Weights): - ImageNet1K_Community = WeightEntry( +class Wide_ResNet101_2_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -244,7 +244,7 @@ class WideResNet101_2Weights(Weights): }, default=False, ) - ImageNet1K_RefV2 = WeightEntry( + ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", transforms=partial(ImageNetEval, crop_size=224, resize_size=232), meta={ @@ -257,97 +257,101 @@ class WideResNet101_2Weights(Weights): ) -def resnet18(weights: Optional[ResNet18Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: +def resnet18(weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet18Weights.ImageNet1K_RefV1) - weights = ResNet18Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet18_Weights.ImageNet1K_V1) + weights = ResNet18_Weights.verify(weights) return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) -def resnet34(weights: Optional[ResNet34Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: +def resnet34(weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet34Weights.ImageNet1K_RefV1) - weights = ResNet34Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet34_Weights.ImageNet1K_V1) + weights = ResNet34_Weights.verify(weights) return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) -def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: +def resnet50(weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet50Weights.ImageNet1K_RefV1) - weights = ResNet50Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet50_Weights.ImageNet1K_V1) + weights = ResNet50_Weights.verify(weights) return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) -def resnet101(weights: Optional[ResNet101Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: +def resnet101(weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet101Weights.ImageNet1K_RefV1) - weights = ResNet101Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet101_Weights.ImageNet1K_V1) + weights = ResNet101_Weights.verify(weights) return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) -def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: +def resnet152(weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet152Weights.ImageNet1K_RefV1) - weights = ResNet152Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNet152_Weights.ImageNet1K_V1) + weights = ResNet152_Weights.verify(weights) return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs) -def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet: +def resnext50_32x4d(weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt50_32x4dWeights.ImageNet1K_RefV1) - weights = ResNeXt50_32x4dWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt50_32X4D_Weights.ImageNet1K_V1) + weights = ResNeXt50_32X4D_Weights.verify(weights) _ovewrite_named_param(kwargs, "groups", 32) _ovewrite_named_param(kwargs, "width_per_group", 4) return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) -def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet: +def resnext101_32x8d( + weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt101_32x8dWeights.ImageNet1K_RefV1) - weights = ResNeXt101_32x8dWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", ResNeXt101_32X8D_Weights.ImageNet1K_V1) + weights = ResNeXt101_32X8D_Weights.verify(weights) _ovewrite_named_param(kwargs, "groups", 32) _ovewrite_named_param(kwargs, "width_per_group", 8) return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) -def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: +def wide_resnet50_2(weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", WideResNet50_2Weights.ImageNet1K_Community) - weights = WideResNet50_2Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", Wide_ResNet50_2_Weights.ImageNet1K_V1) + weights = Wide_ResNet50_2_Weights.verify(weights) _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) -def wide_resnet101_2(weights: Optional[WideResNet101_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: +def wide_resnet101_2( + weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", WideResNet101_2Weights.ImageNet1K_Community) - weights = WideResNet101_2Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", Wide_ResNet101_2_Weights.ImageNet1K_V1) + weights = Wide_ResNet101_2_Weights.verify(weights) _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py index 93ff73f0032..638b771c333 100644 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ b/torchvision/prototype/models/segmentation/deeplabv3.py @@ -5,19 +5,19 @@ from torchvision.transforms.functional import InterpolationMode from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param -from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large +from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large from ..resnet import resnet50, resnet101 -from ..resnet import ResNet50Weights, ResNet101Weights +from ..resnet import ResNet50_Weights, ResNet101_Weights __all__ = [ "DeepLabV3", - "DeepLabV3ResNet50Weights", - "DeepLabV3ResNet101Weights", - "DeepLabV3MobileNetV3LargeWeights", + "DeepLabV3_ResNet50_Weights", + "DeepLabV3_ResNet101_Weights", + "DeepLabV3_MobileNet_V3_Large_Weights", "deeplabv3_mobilenet_v3_large", "deeplabv3_resnet50", "deeplabv3_resnet101", @@ -30,8 +30,8 @@ } -class DeepLabV3ResNet50Weights(Weights): - CocoWithVocLabels_RefV1 = WeightEntry( +class DeepLabV3_ResNet50_Weights(WeightsEnum): + CocoWithVocLabels_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", transforms=partial(VocEval, resize_size=520), meta={ @@ -44,8 +44,8 @@ class DeepLabV3ResNet50Weights(Weights): ) -class DeepLabV3ResNet101Weights(Weights): - CocoWithVocLabels_RefV1 = WeightEntry( +class DeepLabV3_ResNet101_Weights(WeightsEnum): + CocoWithVocLabels_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", transforms=partial(VocEval, resize_size=520), meta={ @@ -58,8 +58,8 @@ class DeepLabV3ResNet101Weights(Weights): ) -class DeepLabV3MobileNetV3LargeWeights(Weights): - CocoWithVocLabels_RefV1 = WeightEntry( +class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): + CocoWithVocLabels_V1 = Weights( url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", transforms=partial(VocEval, resize_size=520), meta={ @@ -73,25 +73,25 @@ class DeepLabV3MobileNetV3LargeWeights(Weights): def deeplabv3_resnet50( - weights: Optional[DeepLabV3ResNet50Weights] = None, + weights: Optional[DeepLabV3_ResNet50_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet50Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = None, **kwargs: Any, ) -> DeepLabV3: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3ResNet50Weights.CocoWithVocLabels_RefV1) - weights = DeepLabV3ResNet50Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3_ResNet50_Weights.CocoWithVocLabels_V1) + weights = DeepLabV3_ResNet50_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1 ) - weights_backbone = ResNet50Weights.verify(weights_backbone) + weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None @@ -110,25 +110,25 @@ def deeplabv3_resnet50( def deeplabv3_resnet101( - weights: Optional[DeepLabV3ResNet101Weights] = None, + weights: Optional[DeepLabV3_ResNet101_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet101Weights] = None, + weights_backbone: Optional[ResNet101_Weights] = None, **kwargs: Any, ) -> DeepLabV3: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3ResNet101Weights.CocoWithVocLabels_RefV1) - weights = DeepLabV3ResNet101Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", DeepLabV3_ResNet101_Weights.CocoWithVocLabels_V1) + weights = DeepLabV3_ResNet101_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet101Weights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet101_Weights.ImageNet1K_V1 ) - weights_backbone = ResNet101Weights.verify(weights_backbone) + weights_backbone = ResNet101_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None @@ -147,27 +147,27 @@ def deeplabv3_resnet101( def deeplabv3_mobilenet_v3_large( - weights: Optional[DeepLabV3MobileNetV3LargeWeights] = None, + weights: Optional[DeepLabV3_MobileNet_V3_Large_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - weights_backbone: Optional[MobileNetV3LargeWeights] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, **kwargs: Any, ) -> DeepLabV3: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param( - kwargs, "pretrained", "weights", DeepLabV3MobileNetV3LargeWeights.CocoWithVocLabels_RefV1 + kwargs, "pretrained", "weights", DeepLabV3_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1 ) - weights = DeepLabV3MobileNetV3LargeWeights.verify(weights) + weights = DeepLabV3_MobileNet_V3_Large_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1 ) - weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None diff --git a/torchvision/prototype/models/segmentation/fcn.py b/torchvision/prototype/models/segmentation/fcn.py index 138552d3aa0..841e2ea95c5 100644 --- a/torchvision/prototype/models/segmentation/fcn.py +++ b/torchvision/prototype/models/segmentation/fcn.py @@ -5,13 +5,13 @@ from torchvision.transforms.functional import InterpolationMode from ....models.segmentation.fcn import FCN, _fcn_resnet -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param -from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101 +from ..resnet import ResNet50_Weights, ResNet101_Weights, resnet50, resnet101 -__all__ = ["FCN", "FCNResNet50Weights", "FCNResNet101Weights", "fcn_resnet50", "fcn_resnet101"] +__all__ = ["FCN", "FCN_ResNet50_Weights", "FCN_ResNet101_Weights", "fcn_resnet50", "fcn_resnet101"] _COMMON_META = { @@ -20,8 +20,8 @@ } -class FCNResNet50Weights(Weights): - CocoWithVocLabels_RefV1 = WeightEntry( +class FCN_ResNet50_Weights(WeightsEnum): + CocoWithVocLabels_V1 = Weights( url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", transforms=partial(VocEval, resize_size=520), meta={ @@ -34,8 +34,8 @@ class FCNResNet50Weights(Weights): ) -class FCNResNet101Weights(Weights): - CocoWithVocLabels_RefV1 = WeightEntry( +class FCN_ResNet101_Weights(WeightsEnum): + CocoWithVocLabels_V1 = Weights( url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", transforms=partial(VocEval, resize_size=520), meta={ @@ -49,25 +49,25 @@ class FCNResNet101Weights(Weights): def fcn_resnet50( - weights: Optional[FCNResNet50Weights] = None, + weights: Optional[FCN_ResNet50_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet50Weights] = None, + weights_backbone: Optional[ResNet50_Weights] = None, **kwargs: Any, ) -> FCN: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", FCNResNet50Weights.CocoWithVocLabels_RefV1) - weights = FCNResNet50Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", FCN_ResNet50_Weights.CocoWithVocLabels_V1) + weights = FCN_ResNet50_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet50Weights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet50_Weights.ImageNet1K_V1 ) - weights_backbone = ResNet50Weights.verify(weights_backbone) + weights_backbone = ResNet50_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None @@ -86,25 +86,25 @@ def fcn_resnet50( def fcn_resnet101( - weights: Optional[FCNResNet101Weights] = None, + weights: Optional[FCN_ResNet101_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, aux_loss: Optional[bool] = None, - weights_backbone: Optional[ResNet101Weights] = None, + weights_backbone: Optional[ResNet101_Weights] = None, **kwargs: Any, ) -> FCN: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", FCNResNet101Weights.CocoWithVocLabels_RefV1) - weights = FCNResNet101Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", FCN_ResNet101_Weights.CocoWithVocLabels_V1) + weights = FCN_ResNet101_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", ResNet101Weights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", ResNet101_Weights.ImageNet1K_V1 ) - weights_backbone = ResNet101Weights.verify(weights_backbone) + weights_backbone = ResNet101_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None diff --git a/torchvision/prototype/models/segmentation/lraspp.py b/torchvision/prototype/models/segmentation/lraspp.py index 260093efc0f..9743e02fa16 100644 --- a/torchvision/prototype/models/segmentation/lraspp.py +++ b/torchvision/prototype/models/segmentation/lraspp.py @@ -5,17 +5,17 @@ from torchvision.transforms.functional import InterpolationMode from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3 -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _VOC_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_value_param -from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large +from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large -__all__ = ["LRASPP", "LRASPPMobileNetV3LargeWeights", "lraspp_mobilenet_v3_large"] +__all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"] -class LRASPPMobileNetV3LargeWeights(Weights): - CocoWithVocLabels_RefV1 = WeightEntry( +class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): + CocoWithVocLabels_V1 = Weights( url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", transforms=partial(VocEval, resize_size=520), meta={ @@ -30,10 +30,10 @@ class LRASPPMobileNetV3LargeWeights(Weights): def lraspp_mobilenet_v3_large( - weights: Optional[LRASPPMobileNetV3LargeWeights] = None, + weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None, progress: bool = True, num_classes: Optional[int] = None, - weights_backbone: Optional[MobileNetV3LargeWeights] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, **kwargs: Any, ) -> LRASPP: if kwargs.pop("aux_loss", False): @@ -43,16 +43,16 @@ def lraspp_mobilenet_v3_large( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param( - kwargs, "pretrained", "weights", LRASPPMobileNetV3LargeWeights.CocoWithVocLabels_RefV1 + kwargs, "pretrained", "weights", LRASPP_MobileNet_V3_Large_Weights.CocoWithVocLabels_V1 ) - weights = LRASPPMobileNetV3LargeWeights.verify(weights) + weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights) if type(weights_backbone) == bool and weights_backbone: _deprecated_positional(kwargs, "pretrained_backbone", "weights_backbone", True) if "pretrained_backbone" in kwargs: weights_backbone = _deprecated_param( - kwargs, "pretrained_backbone", "weights_backbone", MobileNetV3LargeWeights.ImageNet1K_RefV1 + kwargs, "pretrained_backbone", "weights_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1 ) - weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) if weights is not None: weights_backbone = None diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py index 006f4027766..9fa98c44223 100644 --- a/torchvision/prototype/models/shufflenetv2.py +++ b/torchvision/prototype/models/shufflenetv2.py @@ -5,17 +5,17 @@ from torchvision.transforms.functional import InterpolationMode from ...models.shufflenetv2 import ShuffleNetV2 -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param __all__ = [ "ShuffleNetV2", - "ShuffleNetV2_x0_5Weights", - "ShuffleNetV2_x1_0Weights", - "ShuffleNetV2_x1_5Weights", - "ShuffleNetV2_x2_0Weights", + "ShuffleNet_V2_X0_5_Weights", + "ShuffleNet_V2_X1_0_Weights", + "ShuffleNet_V2_X1_5_Weights", + "ShuffleNet_V2_X2_0_Weights", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", "shufflenet_v2_x1_5", @@ -24,7 +24,7 @@ def _shufflenetv2( - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, *args: Any, **kwargs: Any, @@ -48,8 +48,8 @@ def _shufflenetv2( } -class ShuffleNetV2_x0_5Weights(Weights): - ImageNet1K_Community = WeightEntry( +class ShuffleNet_V2_X0_5_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -61,8 +61,8 @@ class ShuffleNetV2_x0_5Weights(Weights): ) -class ShuffleNetV2_x1_0Weights(Weights): - ImageNet1K_Community = WeightEntry( +class ShuffleNet_V2_X1_0_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -74,57 +74,57 @@ class ShuffleNetV2_x1_0Weights(Weights): ) -class ShuffleNetV2_x1_5Weights(Weights): +class ShuffleNet_V2_X1_5_Weights(WeightsEnum): pass -class ShuffleNetV2_x2_0Weights(Weights): +class ShuffleNet_V2_X2_0_Weights(WeightsEnum): pass def shufflenet_v2_x0_5( - weights: Optional[ShuffleNetV2_x0_5Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any ) -> ShuffleNetV2: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNetV2_x0_5Weights.ImageNet1K_Community) - weights = ShuffleNetV2_x0_5Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNet_V2_X0_5_Weights.ImageNet1K_V1) + weights = ShuffleNet_V2_X0_5_Weights.verify(weights) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) def shufflenet_v2_x1_0( - weights: Optional[ShuffleNetV2_x1_0Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any ) -> ShuffleNetV2: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNetV2_x1_0Weights.ImageNet1K_Community) - weights = ShuffleNetV2_x1_0Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", ShuffleNet_V2_X1_0_Weights.ImageNet1K_V1) + weights = ShuffleNet_V2_X1_0_Weights.verify(weights) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) def shufflenet_v2_x1_5( - weights: Optional[ShuffleNetV2_x1_5Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any ) -> ShuffleNetV2: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param(kwargs, "pretrained", "weights", None) - weights = ShuffleNetV2_x1_5Weights.verify(weights) + weights = ShuffleNet_V2_X1_5_Weights.verify(weights) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) def shufflenet_v2_x2_0( - weights: Optional[ShuffleNetV2_x2_0Weights] = None, progress: bool = True, **kwargs: Any + weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any ) -> ShuffleNetV2: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param(kwargs, "pretrained", "weights", None) - weights = ShuffleNetV2_x2_0Weights.verify(weights) + weights = ShuffleNet_V2_X2_0_Weights.verify(weights) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) diff --git a/torchvision/prototype/models/squeezenet.py b/torchvision/prototype/models/squeezenet.py index 34144d15213..fdfaa01e8be 100644 --- a/torchvision/prototype/models/squeezenet.py +++ b/torchvision/prototype/models/squeezenet.py @@ -5,12 +5,12 @@ from torchvision.transforms.functional import InterpolationMode from ...models.squeezenet import SqueezeNet -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param -__all__ = ["SqueezeNet", "SqueezeNet1_0Weights", "SqueezeNet1_1Weights", "squeezenet1_0", "squeezenet1_1"] +__all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"] _COMMON_META = { @@ -21,8 +21,8 @@ } -class SqueezeNet1_0Weights(Weights): - ImageNet1K_Community = WeightEntry( +class SqueezeNet1_0_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -34,8 +34,8 @@ class SqueezeNet1_0Weights(Weights): ) -class SqueezeNet1_1Weights(Weights): - ImageNet1K_Community = WeightEntry( +class SqueezeNet1_1_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -47,12 +47,12 @@ class SqueezeNet1_1Weights(Weights): ) -def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: +def squeezenet1_0(weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_0Weights.ImageNet1K_Community) - weights = SqueezeNet1_0Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_0_Weights.ImageNet1K_V1) + weights = SqueezeNet1_0_Weights.verify(weights) if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) @@ -65,12 +65,12 @@ def squeezenet1_0(weights: Optional[SqueezeNet1_0Weights] = None, progress: bool return model -def squeezenet1_1(weights: Optional[SqueezeNet1_1Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: +def squeezenet1_1(weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_1Weights.ImageNet1K_Community) - weights = SqueezeNet1_1Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", SqueezeNet1_1_Weights.ImageNet1K_V1) + weights = SqueezeNet1_1_Weights.verify(weights) if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) diff --git a/torchvision/prototype/models/vgg.py b/torchvision/prototype/models/vgg.py index 8c5a58cb592..a357426693d 100644 --- a/torchvision/prototype/models/vgg.py +++ b/torchvision/prototype/models/vgg.py @@ -5,21 +5,21 @@ from torchvision.transforms.functional import InterpolationMode from ...models.vgg import VGG, make_layers, cfgs -from ._api import Weights, WeightEntry +from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param __all__ = [ "VGG", - "VGG11Weights", - "VGG11BNWeights", - "VGG13Weights", - "VGG13BNWeights", - "VGG16Weights", - "VGG16BNWeights", - "VGG19Weights", - "VGG19BNWeights", + "VGG11_Weights", + "VGG11_BN_Weights", + "VGG13_Weights", + "VGG13_BN_Weights", + "VGG16_Weights", + "VGG16_BN_Weights", + "VGG19_Weights", + "VGG19_BN_Weights", "vgg11", "vgg11_bn", "vgg13", @@ -31,7 +31,7 @@ ] -def _vgg(cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, **kwargs: Any) -> VGG: +def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG: if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) @@ -48,8 +48,8 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, } -class VGG11Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class VGG11_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg11-8a719046.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -61,8 +61,8 @@ class VGG11Weights(Weights): ) -class VGG11BNWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class VGG11_BN_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -74,8 +74,8 @@ class VGG11BNWeights(Weights): ) -class VGG13Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class VGG13_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg13-19584684.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -87,8 +87,8 @@ class VGG13Weights(Weights): ) -class VGG13BNWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class VGG13_BN_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -100,8 +100,8 @@ class VGG13BNWeights(Weights): ) -class VGG16Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class VGG16_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg16-397923af.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -114,7 +114,7 @@ class VGG16Weights(Weights): # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the # same input standardization method as the paper. Only the `features` weights have proper values, those on the # `classifier` module are filled with nans. - ImageNet1K_Features = WeightEntry( + ImageNet1K_Features = Weights( url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth", transforms=partial( ImageNetEval, crop_size=224, mean=(0.48235, 0.45882, 0.40784), std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0) @@ -131,8 +131,8 @@ class VGG16Weights(Weights): ) -class VGG16BNWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class VGG16_BN_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -144,8 +144,8 @@ class VGG16BNWeights(Weights): ) -class VGG19Weights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class VGG19_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -157,8 +157,8 @@ class VGG19Weights(Weights): ) -class VGG19BNWeights(Weights): - ImageNet1K_RefV1 = WeightEntry( +class VGG19_BN_Weights(WeightsEnum): + ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", transforms=partial(ImageNetEval, crop_size=224), meta={ @@ -170,81 +170,81 @@ class VGG19BNWeights(Weights): ) -def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: +def vgg11(weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11Weights.ImageNet1K_RefV1) - weights = VGG11Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11_Weights.ImageNet1K_V1) + weights = VGG11_Weights.verify(weights) return _vgg("A", False, weights, progress, **kwargs) -def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: +def vgg11_bn(weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11BNWeights.ImageNet1K_RefV1) - weights = VGG11BNWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG11_BN_Weights.ImageNet1K_V1) + weights = VGG11_BN_Weights.verify(weights) return _vgg("A", True, weights, progress, **kwargs) -def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: +def vgg13(weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13Weights.ImageNet1K_RefV1) - weights = VGG13Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13_Weights.ImageNet1K_V1) + weights = VGG13_Weights.verify(weights) return _vgg("B", False, weights, progress, **kwargs) -def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: +def vgg13_bn(weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13BNWeights.ImageNet1K_RefV1) - weights = VGG13BNWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG13_BN_Weights.ImageNet1K_V1) + weights = VGG13_BN_Weights.verify(weights) return _vgg("B", True, weights, progress, **kwargs) -def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: +def vgg16(weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16Weights.ImageNet1K_RefV1) - weights = VGG16Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16_Weights.ImageNet1K_V1) + weights = VGG16_Weights.verify(weights) return _vgg("D", False, weights, progress, **kwargs) -def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: +def vgg16_bn(weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16BNWeights.ImageNet1K_RefV1) - weights = VGG16BNWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG16_BN_Weights.ImageNet1K_V1) + weights = VGG16_BN_Weights.verify(weights) return _vgg("D", True, weights, progress, **kwargs) -def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: +def vgg19(weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19Weights.ImageNet1K_RefV1) - weights = VGG19Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19_Weights.ImageNet1K_V1) + weights = VGG19_Weights.verify(weights) return _vgg("E", False, weights, progress, **kwargs) -def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: +def vgg19_bn(weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19BNWeights.ImageNet1K_RefV1) - weights = VGG19BNWeights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", VGG19_BN_Weights.ImageNet1K_V1) + weights = VGG19_BN_Weights.verify(weights) return _vgg("E", True, weights, progress, **kwargs) diff --git a/torchvision/prototype/models/video/resnet.py b/torchvision/prototype/models/video/resnet.py index 2dcfdfba2c0..c75f618a8b1 100644 --- a/torchvision/prototype/models/video/resnet.py +++ b/torchvision/prototype/models/video/resnet.py @@ -15,16 +15,16 @@ R2Plus1dStem, VideoResNet, ) -from .._api import Weights, WeightEntry +from .._api import WeightsEnum, Weights from .._meta import _KINETICS400_CATEGORIES from .._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param __all__ = [ "VideoResNet", - "R3D_18Weights", - "MC3_18Weights", - "R2Plus1D_18Weights", + "R3D_18_Weights", + "MC3_18_Weights", + "R2Plus1D_18_Weights", "r3d_18", "mc3_18", "r2plus1d_18", @@ -36,7 +36,7 @@ def _video_resnet( conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]], layers: List[int], stem: Callable[..., nn.Module], - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> VideoResNet: @@ -59,8 +59,8 @@ def _video_resnet( } -class R3D_18Weights(Weights): - Kinetics400_RefV1 = WeightEntry( +class R3D_18_Weights(WeightsEnum): + Kinetics400_V1 = Weights( url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), meta={ @@ -72,8 +72,8 @@ class R3D_18Weights(Weights): ) -class MC3_18Weights(Weights): - Kinetics400_RefV1 = WeightEntry( +class MC3_18_Weights(WeightsEnum): + Kinetics400_V1 = Weights( url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), meta={ @@ -85,8 +85,8 @@ class MC3_18Weights(Weights): ) -class R2Plus1D_18Weights(Weights): - Kinetics400_RefV1 = WeightEntry( +class R2Plus1D_18_Weights(WeightsEnum): + Kinetics400_V1 = Weights( url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), meta={ @@ -98,12 +98,12 @@ class R2Plus1D_18Weights(Weights): ) -def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: +def r3d_18(weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", R3D_18Weights.Kinetics400_RefV1) - weights = R3D_18Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", R3D_18_Weights.Kinetics400_V1) + weights = R3D_18_Weights.verify(weights) return _video_resnet( BasicBlock, @@ -116,12 +116,12 @@ def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwa ) -def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: +def mc3_18(weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", MC3_18Weights.Kinetics400_RefV1) - weights = MC3_18Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", MC3_18_Weights.Kinetics400_V1) + weights = MC3_18_Weights.verify(weights) return _video_resnet( BasicBlock, @@ -134,12 +134,12 @@ def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwa ) -def r2plus1d_18(weights: Optional[R2Plus1D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: +def r2plus1d_18(weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: if type(weights) == bool and weights: _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: - weights = _deprecated_param(kwargs, "pretrained", "weights", R2Plus1D_18Weights.Kinetics400_RefV1) - weights = R2Plus1D_18Weights.verify(weights) + weights = _deprecated_param(kwargs, "pretrained", "weights", R2Plus1D_18_Weights.Kinetics400_V1) + weights = R2Plus1D_18_Weights.verify(weights) return _video_resnet( BasicBlock, diff --git a/torchvision/prototype/models/vision_transformer.py b/torchvision/prototype/models/vision_transformer.py index 987f3af1bb4..bbe5aba262c 100644 --- a/torchvision/prototype/models/vision_transformer.py +++ b/torchvision/prototype/models/vision_transformer.py @@ -11,16 +11,16 @@ import torch.nn as nn from torch import Tensor -from ._api import Weights +from ._api import WeightsEnum from ._utils import _deprecated_param, _deprecated_positional __all__ = [ "VisionTransformer", - "VisionTransformer_B_16Weights", - "VisionTransformer_B_32Weights", - "VisionTransformer_L_16Weights", - "VisionTransformer_L_32Weights", + "ViT_B_16_Weights", + "ViT_B_32_Weights", + "ViT_L_16_Weights", + "ViT_L_32_Weights", "vit_b_16", "vit_b_32", "vit_l_16", @@ -231,22 +231,22 @@ def forward(self, x: torch.Tensor): return x -class VisionTransformer_B_16Weights(Weights): +class ViT_B_16_Weights(WeightsEnum): # If a default model is added here the corresponding changes need to be done in vit_b_16 pass -class VisionTransformer_B_32Weights(Weights): +class ViT_B_32_Weights(WeightsEnum): # If a default model is added here the corresponding changes need to be done in vit_b_32 pass -class VisionTransformer_L_16Weights(Weights): +class ViT_L_16_Weights(WeightsEnum): # If a default model is added here the corresponding changes need to be done in vit_l_16 pass -class VisionTransformer_L_32Weights(Weights): +class ViT_L_32_Weights(WeightsEnum): # If a default model is added here the corresponding changes need to be done in vit_l_32 pass @@ -257,7 +257,7 @@ def _vision_transformer( num_heads: int, hidden_dim: int, mlp_dim: int, - weights: Optional[Weights], + weights: Optional[WeightsEnum], progress: bool, **kwargs: Any, ) -> VisionTransformer: @@ -279,15 +279,13 @@ def _vision_transformer( return model -def vit_b_16( - weights: Optional[VisionTransformer_B_16Weights] = None, progress: bool = True, **kwargs: Any -) -> VisionTransformer: +def vit_b_16(weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_b_16 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. Args: - weights (VisionTransformer_B_16Weights, optional): If not None, returns a model pre-trained on ImageNet. + weights (ViT_B_16Weights, optional): If not None, returns a model pre-trained on ImageNet. Default: None. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. """ @@ -295,7 +293,7 @@ def vit_b_16( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param(kwargs, "pretrained", "weights", None) - weights = VisionTransformer_B_16Weights.verify(weights) + weights = ViT_B_16_Weights.verify(weights) return _vision_transformer( patch_size=16, @@ -309,15 +307,13 @@ def vit_b_16( ) -def vit_b_32( - weights: Optional[VisionTransformer_B_32Weights] = None, progress: bool = True, **kwargs: Any -) -> VisionTransformer: +def vit_b_32(weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_b_32 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. Args: - weights (VisionTransformer_B_32Weights, optional): If not None, returns a model pre-trained on ImageNet. + weights (ViT_B_32Weights, optional): If not None, returns a model pre-trained on ImageNet. Default: None. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. """ @@ -325,7 +321,7 @@ def vit_b_32( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param(kwargs, "pretrained", "weights", None) - weights = VisionTransformer_B_32Weights.verify(weights) + weights = ViT_B_32_Weights.verify(weights) return _vision_transformer( patch_size=32, @@ -339,15 +335,13 @@ def vit_b_32( ) -def vit_l_16( - weights: Optional[VisionTransformer_L_16Weights] = None, progress: bool = True, **kwargs: Any -) -> VisionTransformer: +def vit_l_16(weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_l_16 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. Args: - weights (VisionTransformer_L_16Weights, optional): If not None, returns a model pre-trained on ImageNet. + weights (ViT_L_16Weights, optional): If not None, returns a model pre-trained on ImageNet. Default: None. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. """ @@ -355,7 +349,7 @@ def vit_l_16( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param(kwargs, "pretrained", "weights", None) - weights = VisionTransformer_L_16Weights.verify(weights) + weights = ViT_L_16_Weights.verify(weights) return _vision_transformer( patch_size=16, @@ -369,15 +363,13 @@ def vit_l_16( ) -def vit_l_32( - weights: Optional[VisionTransformer_B_32Weights] = None, progress: bool = True, **kwargs: Any -) -> VisionTransformer: +def vit_l_32(weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer: """ Constructs a vit_l_32 architecture from `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" `_. Args: - weights (VisionTransformer_L_16Weights, optional): If not None, returns a model pre-trained on ImageNet. + weights (ViT_L_32Weights, optional): If not None, returns a model pre-trained on ImageNet. Default: None. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True. """ @@ -385,7 +377,7 @@ def vit_l_32( _deprecated_positional(kwargs, "pretrained", "weights", True) if "pretrained" in kwargs: weights = _deprecated_param(kwargs, "pretrained", "weights", None) - weights = VisionTransformer_L_32Weights.verify(weights) + weights = ViT_L_32_Weights.verify(weights) return _vision_transformer( patch_size=32,