diff --git a/references/classification/train.py b/references/classification/train.py index b16ed3d2a42..b2c6844df9b 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -158,8 +158,7 @@ def load_data(traindir, valdir, args): crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation ) else: - fn = PM.quantization.__dict__[args.model] if hasattr(args, "backend") else PM.__dict__[args.model] - weights = PM._api.get_weight(fn, args.weights) + weights = PM.get_weight(args.weights) preprocessing = weights.transforms() dataset_test = torchvision.datasets.ImageFolder( diff --git a/references/detection/train.py b/references/detection/train.py index ae13a32bd22..0788895af20 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -53,8 +53,7 @@ def get_transform(train, args): elif not args.weights: return presets.DetectionPresetEval() else: - fn = PM.detection.__dict__[args.model] - weights = PM._api.get_weight(fn, args.weights) + weights = PM.get_weight(args.weights) return weights.transforms() diff --git a/references/segmentation/train.py b/references/segmentation/train.py index 2dbb962fe2f..72a9bdb01f5 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -38,8 +38,7 @@ def get_transform(train, args): elif not args.weights: return presets.SegmentationPresetEval(base_size=520) else: - fn = PM.segmentation.__dict__[args.model] - weights = PM._api.get_weight(fn, args.weights) + weights = PM.get_weight(args.weights) return weights.transforms() diff --git a/references/video_classification/train.py b/references/video_classification/train.py index d66879e5b46..1f363f57dad 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -160,8 +160,7 @@ def main(args): if not args.weights: transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112)) else: - fn = PM.video.__dict__[args.model] - weights = PM._api.get_weight(fn, args.weights) + weights = PM.get_weight(args.weights) transform_test = weights.transforms() if args.cache_dataset and os.path.exists(cache_path): diff --git a/test/test_models.py b/test/test_models.py index 5fbe0dca38f..b5500ef08b4 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -22,7 +22,11 @@ def get_models_from_module(module): # TODO add a registration mechanism to torchvision.models - return [v for k, v in module.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] + return [ + v + for k, v in module.__dict__.items() + if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight" + ] @pytest.fixture diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 92a88342534..1dc883528ef 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -24,6 +24,19 @@ def _get_parent_module(model_fn): return module +def _get_model_weights(model_fn): + module = _get_parent_module(model_fn) + weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights" + try: + return next( + v + for k, v in module.__dict__.items() + if k.endswith(weights_name) and k.replace(weights_name, "").lower() == model_fn.__name__ + ) + except StopIteration: + return None + + def _build_model(fn, **kwargs): try: model = fn(**kwargs) @@ -36,24 +49,22 @@ def _build_model(fn, **kwargs): @pytest.mark.parametrize( - "model_fn, name, weight", + "name, weight", [ - (models.resnet50, "ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1), - (models.resnet50, "default", models.ResNet50_Weights.ImageNet1K_V2), + ("ResNet50_Weights.ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1), + ("ResNet50_Weights.default", models.ResNet50_Weights.ImageNet1K_V2), ( - models.quantization.resnet50, - "default", + "ResNet50_QuantizedWeights.default", models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2, ), ( - models.quantization.resnet50, - "ImageNet1K_FBGEMM_V1", + "ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1", models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1, ), ], ) -def test_get_weight(model_fn, name, weight): - assert models._api.get_weight(model_fn, name) == weight +def test_get_weight(name, weight): + assert models.get_weight(name) == weight @pytest.mark.parametrize( @@ -65,10 +76,9 @@ def test_get_weight(model_fn, name, weight): + TM.get_models_from_module(models.video), ) def test_naming_conventions(model_fn): - model_name = model_fn.__name__ - module = _get_parent_module(model_fn) - weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights" - assert model_name in set(x.replace(weights_name, "").lower() for x in module.__dict__ if x.endswith(weights_name)) + weights_enum = _get_model_weights(model_fn) + assert weights_enum is not None + assert len(weights_enum) == 0 or hasattr(weights_enum, "default") @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models)) diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py index f675dc37f25..12a4738e53c 100644 --- a/torchvision/prototype/models/__init__.py +++ b/torchvision/prototype/models/__init__.py @@ -15,3 +15,4 @@ from . import quantization from . import segmentation from . import video +from ._api import get_weight diff --git a/torchvision/prototype/models/_api.py b/torchvision/prototype/models/_api.py index 2935039e087..1f66fd2be45 100644 --- a/torchvision/prototype/models/_api.py +++ b/torchvision/prototype/models/_api.py @@ -1,7 +1,9 @@ +import importlib +import inspect +import sys from collections import OrderedDict from dataclasses import dataclass, fields from enum import Enum -from inspect import signature from typing import Any, Callable, Dict from ..._internally_replaced_utils import load_state_dict_from_url @@ -30,7 +32,6 @@ class Weights: url: str transforms: Callable meta: Dict[str, Any] - default: bool class WeightsEnum(Enum): @@ -50,7 +51,7 @@ def __init__(self, value: Weights): def verify(cls, obj: Any) -> Any: if obj is not None: if type(obj) is str: - obj = cls.from_str(obj) + obj = cls.from_str(obj.replace(cls.__name__ + ".", "")) elif not isinstance(obj, cls): raise TypeError( f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}." @@ -59,8 +60,8 @@ def verify(cls, obj: Any) -> Any: @classmethod def from_str(cls, value: str) -> "WeightsEnum": - for v in cls: - if v._name_ == value or (value == "default" and v.default): + for k, v in cls.__members__.items(): + if k == value: return v raise ValueError(f"Invalid value {value} for enum {cls.__name__}.") @@ -78,41 +79,35 @@ def __getattr__(self, name): return super().__getattr__(name) -def get_weight(fn: Callable, weight_name: str) -> WeightsEnum: +def get_weight(name: str) -> WeightsEnum: """ - Gets the weight enum of a specific model builder method and weight name combination. + Gets the weight enum value by its full name. Example: "ResNet50_Weights.ImageNet1K_V1" Args: - fn (Callable): The builder method used to create the model. - weight_name (str): The name of the weight enum entry of the specific model. + name (str): The name of the weight enum entry. Returns: WeightsEnum: The requested weight enum. """ - sig = signature(fn) - if "weights" not in sig.parameters: - raise ValueError("The method is missing the 'weights' parameter.") + try: + enum_name, value_name = name.split(".") + except ValueError: + raise ValueError(f"Invalid weight name provided: '{name}'.") + + base_module_name = ".".join(sys.modules[__name__].__name__.split(".")[:-1]) + base_module = importlib.import_module(base_module_name) + model_modules = [base_module] + [ + x[1] for x in inspect.getmembers(base_module, inspect.ismodule) if x[1].__file__.endswith("__init__.py") + ] - ann = signature(fn).parameters["weights"].annotation weights_enum = None - if isinstance(ann, type) and issubclass(ann, WeightsEnum): - weights_enum = ann - else: - # handle cases like Union[Optional, T] - # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8 - for t in ann.__args__: # type: ignore[union-attr] - if isinstance(t, type) and issubclass(t, WeightsEnum): - # ensure the name exists. handles builders with multiple types of weights like in quantization - try: - t.from_str(weight_name) - except ValueError: - continue - weights_enum = t - break + for m in model_modules: + potential_class = m.__dict__.get(enum_name, None) + if potential_class is not None and issubclass(potential_class, WeightsEnum): + weights_enum = potential_class + break if weights_enum is None: - raise ValueError( - "The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct." - ) + raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.") - return weights_enum.from_str(weight_name) + return weights_enum.from_str(value_name) diff --git a/torchvision/prototype/models/alexnet.py b/torchvision/prototype/models/alexnet.py index b45ca1e7085..28b0fa60504 100644 --- a/torchvision/prototype/models/alexnet.py +++ b/torchvision/prototype/models/alexnet.py @@ -25,8 +25,8 @@ class AlexNet_Weights(WeightsEnum): "acc@1": 56.522, "acc@5": 79.066, }, - default=True, ) + default = ImageNet1K_V1 def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet: diff --git a/torchvision/prototype/models/densenet.py b/torchvision/prototype/models/densenet.py index e779a2cd239..b8abbdde947 100644 --- a/torchvision/prototype/models/densenet.py +++ b/torchvision/prototype/models/densenet.py @@ -80,8 +80,8 @@ class DenseNet121_Weights(WeightsEnum): "acc@1": 74.434, "acc@5": 91.972, }, - default=True, ) + default = ImageNet1K_V1 class DenseNet161_Weights(WeightsEnum): @@ -93,8 +93,8 @@ class DenseNet161_Weights(WeightsEnum): "acc@1": 77.138, "acc@5": 93.560, }, - default=True, ) + default = ImageNet1K_V1 class DenseNet169_Weights(WeightsEnum): @@ -106,8 +106,8 @@ class DenseNet169_Weights(WeightsEnum): "acc@1": 75.600, "acc@5": 92.806, }, - default=True, ) + default = ImageNet1K_V1 class DenseNet201_Weights(WeightsEnum): @@ -119,8 +119,8 @@ class DenseNet201_Weights(WeightsEnum): "acc@1": 76.896, "acc@5": 93.370, }, - default=True, ) + default = ImageNet1K_V1 def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet: diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index c83aaf222fb..1f5c6461698 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -45,8 +45,8 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn", "map": 37.0, }, - default=True, ) + default = Coco_V1 class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): @@ -58,8 +58,8 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn", "map": 32.8, }, - default=True, ) + default = Coco_V1 class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): @@ -71,8 +71,8 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn", "map": 22.8, }, - default=True, ) + default = Coco_V1 def fasterrcnn_resnet50_fpn( diff --git a/torchvision/prototype/models/detection/keypoint_rcnn.py b/torchvision/prototype/models/detection/keypoint_rcnn.py index 85250ac2a33..a811999681d 100644 --- a/torchvision/prototype/models/detection/keypoint_rcnn.py +++ b/torchvision/prototype/models/detection/keypoint_rcnn.py @@ -35,7 +35,6 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): "box_map": 50.6, "kp_map": 61.1, }, - default=False, ) Coco_V1 = Weights( url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", @@ -46,8 +45,8 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): "box_map": 54.6, "kp_map": 65.0, }, - default=True, ) + default = Coco_V1 def keypointrcnn_resnet50_fpn( diff --git a/torchvision/prototype/models/detection/mask_rcnn.py b/torchvision/prototype/models/detection/mask_rcnn.py index ea7ab4f5fc7..4eb285fac0d 100644 --- a/torchvision/prototype/models/detection/mask_rcnn.py +++ b/torchvision/prototype/models/detection/mask_rcnn.py @@ -34,8 +34,8 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): "box_map": 37.9, "mask_map": 34.6, }, - default=True, ) + default = Coco_V1 def maskrcnn_resnet50_fpn( diff --git a/torchvision/prototype/models/detection/retinanet.py b/torchvision/prototype/models/detection/retinanet.py index d442c79d5b6..799bc21c379 100644 --- a/torchvision/prototype/models/detection/retinanet.py +++ b/torchvision/prototype/models/detection/retinanet.py @@ -34,8 +34,8 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet", "map": 36.4, }, - default=True, ) + default = Coco_V1 def retinanet_resnet50_fpn( diff --git a/torchvision/prototype/models/detection/ssd.py b/torchvision/prototype/models/detection/ssd.py index 37f5c2a6944..f57b47c00d6 100644 --- a/torchvision/prototype/models/detection/ssd.py +++ b/torchvision/prototype/models/detection/ssd.py @@ -33,8 +33,8 @@ class SSD300_VGG16_Weights(WeightsEnum): "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16", "map": 25.1, }, - default=True, ) + default = Coco_V1 def ssd300_vgg16( diff --git a/torchvision/prototype/models/detection/ssdlite.py b/torchvision/prototype/models/detection/ssdlite.py index 309362f2f11..4a61c50101a 100644 --- a/torchvision/prototype/models/detection/ssdlite.py +++ b/torchvision/prototype/models/detection/ssdlite.py @@ -38,8 +38,8 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large", "map": 21.3, }, - default=True, ) + default = Coco_V1 def ssdlite320_mobilenet_v3_large( diff --git a/torchvision/prototype/models/efficientnet.py b/torchvision/prototype/models/efficientnet.py index 74ca6ccc71d..f4a69aac70c 100644 --- a/torchvision/prototype/models/efficientnet.py +++ b/torchvision/prototype/models/efficientnet.py @@ -79,8 +79,8 @@ class EfficientNet_B0_Weights(WeightsEnum): "acc@1": 77.692, "acc@5": 93.532, }, - default=True, ) + default = ImageNet1K_V1 class EfficientNet_B1_Weights(WeightsEnum): @@ -93,8 +93,8 @@ class EfficientNet_B1_Weights(WeightsEnum): "acc@1": 78.642, "acc@5": 94.186, }, - default=True, ) + default = ImageNet1K_V1 class EfficientNet_B2_Weights(WeightsEnum): @@ -107,8 +107,8 @@ class EfficientNet_B2_Weights(WeightsEnum): "acc@1": 80.608, "acc@5": 95.310, }, - default=True, ) + default = ImageNet1K_V1 class EfficientNet_B3_Weights(WeightsEnum): @@ -121,8 +121,8 @@ class EfficientNet_B3_Weights(WeightsEnum): "acc@1": 82.008, "acc@5": 96.054, }, - default=True, ) + default = ImageNet1K_V1 class EfficientNet_B4_Weights(WeightsEnum): @@ -135,8 +135,8 @@ class EfficientNet_B4_Weights(WeightsEnum): "acc@1": 83.384, "acc@5": 96.594, }, - default=True, ) + default = ImageNet1K_V1 class EfficientNet_B5_Weights(WeightsEnum): @@ -149,8 +149,8 @@ class EfficientNet_B5_Weights(WeightsEnum): "acc@1": 83.444, "acc@5": 96.628, }, - default=True, ) + default = ImageNet1K_V1 class EfficientNet_B6_Weights(WeightsEnum): @@ -163,8 +163,8 @@ class EfficientNet_B6_Weights(WeightsEnum): "acc@1": 84.008, "acc@5": 96.916, }, - default=True, ) + default = ImageNet1K_V1 class EfficientNet_B7_Weights(WeightsEnum): @@ -177,8 +177,8 @@ class EfficientNet_B7_Weights(WeightsEnum): "acc@1": 84.122, "acc@5": 96.908, }, - default=True, ) + default = ImageNet1K_V1 def efficientnet_b0( diff --git a/torchvision/prototype/models/googlenet.py b/torchvision/prototype/models/googlenet.py index 352c49d1a2e..f62c5a96e15 100644 --- a/torchvision/prototype/models/googlenet.py +++ b/torchvision/prototype/models/googlenet.py @@ -26,8 +26,8 @@ class GoogLeNet_Weights(WeightsEnum): "acc@1": 69.778, "acc@5": 89.530, }, - default=True, ) + default = ImageNet1K_V1 def googlenet(weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet: diff --git a/torchvision/prototype/models/inception.py b/torchvision/prototype/models/inception.py index 9837b1fc4a6..4814fa76c5c 100644 --- a/torchvision/prototype/models/inception.py +++ b/torchvision/prototype/models/inception.py @@ -25,8 +25,8 @@ class Inception_V3_Weights(WeightsEnum): "acc@1": 77.294, "acc@5": 93.450, }, - default=True, ) + default = ImageNet1K_V1 def inception_v3(weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3: diff --git a/torchvision/prototype/models/mnasnet.py b/torchvision/prototype/models/mnasnet.py index 73aaea0beca..554057a9ba1 100644 --- a/torchvision/prototype/models/mnasnet.py +++ b/torchvision/prototype/models/mnasnet.py @@ -40,8 +40,8 @@ class MNASNet0_5_Weights(WeightsEnum): "acc@1": 67.734, "acc@5": 87.490, }, - default=True, ) + default = ImageNet1K_V1 class MNASNet0_75_Weights(WeightsEnum): @@ -58,8 +58,8 @@ class MNASNet1_0_Weights(WeightsEnum): "acc@1": 73.456, "acc@5": 91.510, }, - default=True, ) + default = ImageNet1K_V1 class MNASNet1_3_Weights(WeightsEnum): diff --git a/torchvision/prototype/models/mobilenetv2.py b/torchvision/prototype/models/mobilenetv2.py index 0c0f80d081a..64c7221da6d 100644 --- a/torchvision/prototype/models/mobilenetv2.py +++ b/torchvision/prototype/models/mobilenetv2.py @@ -25,8 +25,8 @@ class MobileNet_V2_Weights(WeightsEnum): "acc@1": 71.878, "acc@5": 90.286, }, - default=True, ) + default = ImageNet1K_V1 def mobilenet_v2(weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2: diff --git a/torchvision/prototype/models/mobilenetv3.py b/torchvision/prototype/models/mobilenetv3.py index e014fb5acb2..a92c7667aab 100644 --- a/torchvision/prototype/models/mobilenetv3.py +++ b/torchvision/prototype/models/mobilenetv3.py @@ -54,7 +54,6 @@ class MobileNet_V3_Large_Weights(WeightsEnum): "acc@1": 74.042, "acc@5": 91.340, }, - default=False, ) ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth", @@ -65,8 +64,8 @@ class MobileNet_V3_Large_Weights(WeightsEnum): "acc@1": 75.274, "acc@5": 92.566, }, - default=True, ) + default = ImageNet1K_V2 class MobileNet_V3_Small_Weights(WeightsEnum): @@ -79,8 +78,8 @@ class MobileNet_V3_Small_Weights(WeightsEnum): "acc@1": 67.668, "acc@5": 87.402, }, - default=True, ) + default = ImageNet1K_V1 def mobilenet_v3_large( diff --git a/torchvision/prototype/models/quantization/googlenet.py b/torchvision/prototype/models/quantization/googlenet.py index 3d26fd7d607..dc3c875b79a 100644 --- a/torchvision/prototype/models/quantization/googlenet.py +++ b/torchvision/prototype/models/quantization/googlenet.py @@ -38,8 +38,8 @@ class GoogLeNet_QuantizedWeights(WeightsEnum): "acc@1": 69.826, "acc@5": 89.404, }, - default=True, ) + default = ImageNet1K_FBGEMM_V1 def googlenet( diff --git a/torchvision/prototype/models/quantization/inception.py b/torchvision/prototype/models/quantization/inception.py index ff779076df6..d1d5d4ca8fe 100644 --- a/torchvision/prototype/models/quantization/inception.py +++ b/torchvision/prototype/models/quantization/inception.py @@ -37,8 +37,8 @@ class Inception_V3_QuantizedWeights(WeightsEnum): "acc@1": 77.176, "acc@5": 93.354, }, - default=True, ) + default = ImageNet1K_FBGEMM_V1 def inception_v3( diff --git a/torchvision/prototype/models/quantization/mobilenetv2.py b/torchvision/prototype/models/quantization/mobilenetv2.py index c5afd731fad..81540f2f840 100644 --- a/torchvision/prototype/models/quantization/mobilenetv2.py +++ b/torchvision/prototype/models/quantization/mobilenetv2.py @@ -38,8 +38,8 @@ class MobileNet_V2_QuantizedWeights(WeightsEnum): "acc@1": 71.658, "acc@5": 90.150, }, - default=True, ) + default = ImageNet1K_QNNPACK_V1 def mobilenet_v2( diff --git a/torchvision/prototype/models/quantization/mobilenetv3.py b/torchvision/prototype/models/quantization/mobilenetv3.py index a29e3f44697..9d29484c18f 100644 --- a/torchvision/prototype/models/quantization/mobilenetv3.py +++ b/torchvision/prototype/models/quantization/mobilenetv3.py @@ -71,8 +71,8 @@ class MobileNet_V3_Large_QuantizedWeights(WeightsEnum): "acc@1": 73.004, "acc@5": 90.858, }, - default=True, ) + default = ImageNet1K_QNNPACK_V1 def mobilenet_v3_large( diff --git a/torchvision/prototype/models/quantization/resnet.py b/torchvision/prototype/models/quantization/resnet.py index 0de4eb5557b..c6bd530f393 100644 --- a/torchvision/prototype/models/quantization/resnet.py +++ b/torchvision/prototype/models/quantization/resnet.py @@ -73,8 +73,8 @@ class ResNet18_QuantizedWeights(WeightsEnum): "acc@1": 69.494, "acc@5": 88.882, }, - default=True, ) + default = ImageNet1K_FBGEMM_V1 class ResNet50_QuantizedWeights(WeightsEnum): @@ -87,7 +87,6 @@ class ResNet50_QuantizedWeights(WeightsEnum): "acc@1": 75.920, "acc@5": 92.814, }, - default=False, ) ImageNet1K_FBGEMM_V2 = Weights( url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth", @@ -98,8 +97,8 @@ class ResNet50_QuantizedWeights(WeightsEnum): "acc@1": 80.282, "acc@5": 94.976, }, - default=True, ) + default = ImageNet1K_FBGEMM_V2 class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): @@ -112,7 +111,6 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): "acc@1": 78.986, "acc@5": 94.480, }, - default=False, ) ImageNet1K_FBGEMM_V2 = Weights( url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth", @@ -123,8 +121,8 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum): "acc@1": 82.574, "acc@5": 96.132, }, - default=True, ) + default = ImageNet1K_FBGEMM_V2 def resnet18( diff --git a/torchvision/prototype/models/quantization/shufflenetv2.py b/torchvision/prototype/models/quantization/shufflenetv2.py index 6677983a1d9..111763f2614 100644 --- a/torchvision/prototype/models/quantization/shufflenetv2.py +++ b/torchvision/prototype/models/quantization/shufflenetv2.py @@ -69,8 +69,8 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): "acc@1": 57.972, "acc@5": 79.780, }, - default=True, ) + default = ImageNet1K_FBGEMM_V1 class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): @@ -83,8 +83,8 @@ class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): "acc@1": 68.360, "acc@5": 87.582, }, - default=True, ) + default = ImageNet1K_FBGEMM_V1 def shufflenet_v2_x0_5( diff --git a/torchvision/prototype/models/regnet.py b/torchvision/prototype/models/regnet.py index 1e12ae7bbd2..d810a0d1300 100644 --- a/torchvision/prototype/models/regnet.py +++ b/torchvision/prototype/models/regnet.py @@ -74,8 +74,8 @@ class RegNet_Y_400MF_Weights(WeightsEnum): "acc@1": 74.046, "acc@5": 91.716, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_Y_800MF_Weights(WeightsEnum): @@ -88,8 +88,8 @@ class RegNet_Y_800MF_Weights(WeightsEnum): "acc@1": 76.420, "acc@5": 93.136, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_Y_1_6GF_Weights(WeightsEnum): @@ -102,8 +102,8 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum): "acc@1": 77.950, "acc@5": 93.966, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_Y_3_2GF_Weights(WeightsEnum): @@ -116,8 +116,8 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum): "acc@1": 78.948, "acc@5": 94.576, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_Y_8GF_Weights(WeightsEnum): @@ -130,8 +130,8 @@ class RegNet_Y_8GF_Weights(WeightsEnum): "acc@1": 80.032, "acc@5": 95.048, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_Y_16GF_Weights(WeightsEnum): @@ -144,8 +144,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum): "acc@1": 80.424, "acc@5": 95.240, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_Y_32GF_Weights(WeightsEnum): @@ -158,8 +158,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum): "acc@1": 80.878, "acc@5": 95.340, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_X_400MF_Weights(WeightsEnum): @@ -172,8 +172,8 @@ class RegNet_X_400MF_Weights(WeightsEnum): "acc@1": 72.834, "acc@5": 90.950, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_X_800MF_Weights(WeightsEnum): @@ -186,8 +186,8 @@ class RegNet_X_800MF_Weights(WeightsEnum): "acc@1": 75.212, "acc@5": 92.348, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_X_1_6GF_Weights(WeightsEnum): @@ -200,8 +200,8 @@ class RegNet_X_1_6GF_Weights(WeightsEnum): "acc@1": 77.040, "acc@5": 93.440, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_X_3_2GF_Weights(WeightsEnum): @@ -214,8 +214,8 @@ class RegNet_X_3_2GF_Weights(WeightsEnum): "acc@1": 78.364, "acc@5": 93.992, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_X_8GF_Weights(WeightsEnum): @@ -228,8 +228,8 @@ class RegNet_X_8GF_Weights(WeightsEnum): "acc@1": 79.344, "acc@5": 94.686, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_X_16GF_Weights(WeightsEnum): @@ -242,8 +242,8 @@ class RegNet_X_16GF_Weights(WeightsEnum): "acc@1": 80.058, "acc@5": 94.944, }, - default=True, ) + default = ImageNet1K_V1 class RegNet_X_32GF_Weights(WeightsEnum): @@ -256,8 +256,8 @@ class RegNet_X_32GF_Weights(WeightsEnum): "acc@1": 80.622, "acc@5": 95.248, }, - default=True, ) + default = ImageNet1K_V1 def regnet_y_400mf(weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet: diff --git a/torchvision/prototype/models/resnet.py b/torchvision/prototype/models/resnet.py index e213864acbe..3c68f0a430c 100644 --- a/torchvision/prototype/models/resnet.py +++ b/torchvision/prototype/models/resnet.py @@ -64,8 +64,8 @@ class ResNet18_Weights(WeightsEnum): "acc@1": 69.758, "acc@5": 89.078, }, - default=True, ) + default = ImageNet1K_V1 class ResNet34_Weights(WeightsEnum): @@ -78,8 +78,8 @@ class ResNet34_Weights(WeightsEnum): "acc@1": 73.314, "acc@5": 91.420, }, - default=True, ) + default = ImageNet1K_V1 class ResNet50_Weights(WeightsEnum): @@ -92,7 +92,6 @@ class ResNet50_Weights(WeightsEnum): "acc@1": 76.130, "acc@5": 92.862, }, - default=False, ) ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnet50-f46c3f97.pth", @@ -103,8 +102,8 @@ class ResNet50_Weights(WeightsEnum): "acc@1": 80.674, "acc@5": 95.166, }, - default=True, ) + default = ImageNet1K_V2 class ResNet101_Weights(WeightsEnum): @@ -117,7 +116,6 @@ class ResNet101_Weights(WeightsEnum): "acc@1": 77.374, "acc@5": 93.546, }, - default=False, ) ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", @@ -128,8 +126,8 @@ class ResNet101_Weights(WeightsEnum): "acc@1": 81.886, "acc@5": 95.780, }, - default=True, ) + default = ImageNet1K_V2 class ResNet152_Weights(WeightsEnum): @@ -142,7 +140,6 @@ class ResNet152_Weights(WeightsEnum): "acc@1": 78.312, "acc@5": 94.046, }, - default=False, ) ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnet152-f82ba261.pth", @@ -153,8 +150,8 @@ class ResNet152_Weights(WeightsEnum): "acc@1": 82.284, "acc@5": 96.002, }, - default=True, ) + default = ImageNet1K_V2 class ResNeXt50_32X4D_Weights(WeightsEnum): @@ -167,7 +164,6 @@ class ResNeXt50_32X4D_Weights(WeightsEnum): "acc@1": 77.618, "acc@5": 93.698, }, - default=False, ) ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", @@ -178,8 +174,8 @@ class ResNeXt50_32X4D_Weights(WeightsEnum): "acc@1": 81.198, "acc@5": 95.340, }, - default=True, ) + default = ImageNet1K_V2 class ResNeXt101_32X8D_Weights(WeightsEnum): @@ -192,7 +188,6 @@ class ResNeXt101_32X8D_Weights(WeightsEnum): "acc@1": 79.312, "acc@5": 94.526, }, - default=False, ) ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", @@ -203,8 +198,8 @@ class ResNeXt101_32X8D_Weights(WeightsEnum): "acc@1": 82.834, "acc@5": 96.228, }, - default=True, ) + default = ImageNet1K_V2 class Wide_ResNet50_2_Weights(WeightsEnum): @@ -217,7 +212,6 @@ class Wide_ResNet50_2_Weights(WeightsEnum): "acc@1": 78.468, "acc@5": 94.086, }, - default=False, ) ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", @@ -228,8 +222,8 @@ class Wide_ResNet50_2_Weights(WeightsEnum): "acc@1": 81.602, "acc@5": 95.758, }, - default=True, ) + default = ImageNet1K_V2 class Wide_ResNet101_2_Weights(WeightsEnum): @@ -242,7 +236,6 @@ class Wide_ResNet101_2_Weights(WeightsEnum): "acc@1": 78.848, "acc@5": 94.284, }, - default=False, ) ImageNet1K_V2 = Weights( url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", @@ -253,8 +246,8 @@ class Wide_ResNet101_2_Weights(WeightsEnum): "acc@1": 82.510, "acc@5": 96.020, }, - default=True, ) + default = ImageNet1K_V2 def resnet18(weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py index 638b771c333..30c90013c9b 100644 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ b/torchvision/prototype/models/segmentation/deeplabv3.py @@ -40,8 +40,8 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum): "mIoU": 66.4, "acc": 92.4, }, - default=True, ) + default = CocoWithVocLabels_V1 class DeepLabV3_ResNet101_Weights(WeightsEnum): @@ -54,8 +54,8 @@ class DeepLabV3_ResNet101_Weights(WeightsEnum): "mIoU": 67.4, "acc": 92.4, }, - default=True, ) + default = CocoWithVocLabels_V1 class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): @@ -68,8 +68,8 @@ class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): "mIoU": 60.3, "acc": 91.2, }, - default=True, ) + default = CocoWithVocLabels_V1 def deeplabv3_resnet50( diff --git a/torchvision/prototype/models/segmentation/fcn.py b/torchvision/prototype/models/segmentation/fcn.py index 841e2ea95c5..42d15a0c3cf 100644 --- a/torchvision/prototype/models/segmentation/fcn.py +++ b/torchvision/prototype/models/segmentation/fcn.py @@ -30,8 +30,8 @@ class FCN_ResNet50_Weights(WeightsEnum): "mIoU": 60.5, "acc": 91.4, }, - default=True, ) + default = CocoWithVocLabels_V1 class FCN_ResNet101_Weights(WeightsEnum): @@ -44,8 +44,8 @@ class FCN_ResNet101_Weights(WeightsEnum): "mIoU": 63.7, "acc": 91.9, }, - default=True, ) + default = CocoWithVocLabels_V1 def fcn_resnet50( diff --git a/torchvision/prototype/models/segmentation/lraspp.py b/torchvision/prototype/models/segmentation/lraspp.py index 9743e02fa16..f80e1079c87 100644 --- a/torchvision/prototype/models/segmentation/lraspp.py +++ b/torchvision/prototype/models/segmentation/lraspp.py @@ -25,8 +25,8 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): "mIoU": 57.9, "acc": 91.2, }, - default=True, ) + default = CocoWithVocLabels_V1 def lraspp_mobilenet_v3_large( diff --git a/torchvision/prototype/models/shufflenetv2.py b/torchvision/prototype/models/shufflenetv2.py index 9fa98c44223..a8857c2996e 100644 --- a/torchvision/prototype/models/shufflenetv2.py +++ b/torchvision/prototype/models/shufflenetv2.py @@ -57,8 +57,8 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum): "acc@1": 69.362, "acc@5": 88.316, }, - default=True, ) + default = ImageNet1K_V1 class ShuffleNet_V2_X1_0_Weights(WeightsEnum): @@ -70,8 +70,8 @@ class ShuffleNet_V2_X1_0_Weights(WeightsEnum): "acc@1": 60.552, "acc@5": 81.746, }, - default=True, ) + default = ImageNet1K_V1 class ShuffleNet_V2_X1_5_Weights(WeightsEnum): diff --git a/torchvision/prototype/models/squeezenet.py b/torchvision/prototype/models/squeezenet.py index fdfaa01e8be..77c9a1629d4 100644 --- a/torchvision/prototype/models/squeezenet.py +++ b/torchvision/prototype/models/squeezenet.py @@ -30,8 +30,8 @@ class SqueezeNet1_0_Weights(WeightsEnum): "acc@1": 58.092, "acc@5": 80.420, }, - default=True, ) + default = ImageNet1K_V1 class SqueezeNet1_1_Weights(WeightsEnum): @@ -43,8 +43,8 @@ class SqueezeNet1_1_Weights(WeightsEnum): "acc@1": 58.178, "acc@5": 80.624, }, - default=True, ) + default = ImageNet1K_V1 def squeezenet1_0(weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> SqueezeNet: diff --git a/torchvision/prototype/models/vgg.py b/torchvision/prototype/models/vgg.py index a357426693d..708608826e0 100644 --- a/torchvision/prototype/models/vgg.py +++ b/torchvision/prototype/models/vgg.py @@ -57,8 +57,8 @@ class VGG11_Weights(WeightsEnum): "acc@1": 69.020, "acc@5": 88.628, }, - default=True, ) + default = ImageNet1K_V1 class VGG11_BN_Weights(WeightsEnum): @@ -70,8 +70,8 @@ class VGG11_BN_Weights(WeightsEnum): "acc@1": 70.370, "acc@5": 89.810, }, - default=True, ) + default = ImageNet1K_V1 class VGG13_Weights(WeightsEnum): @@ -83,8 +83,8 @@ class VGG13_Weights(WeightsEnum): "acc@1": 69.928, "acc@5": 89.246, }, - default=True, ) + default = ImageNet1K_V1 class VGG13_BN_Weights(WeightsEnum): @@ -96,8 +96,8 @@ class VGG13_BN_Weights(WeightsEnum): "acc@1": 71.586, "acc@5": 90.374, }, - default=True, ) + default = ImageNet1K_V1 class VGG16_Weights(WeightsEnum): @@ -109,7 +109,6 @@ class VGG16_Weights(WeightsEnum): "acc@1": 71.592, "acc@5": 90.382, }, - default=True, ) # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the # same input standardization method as the paper. Only the `features` weights have proper values, those on the @@ -127,8 +126,8 @@ class VGG16_Weights(WeightsEnum): "acc@1": float("nan"), "acc@5": float("nan"), }, - default=False, ) + default = ImageNet1K_V1 class VGG16_BN_Weights(WeightsEnum): @@ -140,8 +139,8 @@ class VGG16_BN_Weights(WeightsEnum): "acc@1": 73.360, "acc@5": 91.516, }, - default=True, ) + default = ImageNet1K_V1 class VGG19_Weights(WeightsEnum): @@ -153,8 +152,8 @@ class VGG19_Weights(WeightsEnum): "acc@1": 72.376, "acc@5": 90.876, }, - default=True, ) + default = ImageNet1K_V1 class VGG19_BN_Weights(WeightsEnum): @@ -166,8 +165,8 @@ class VGG19_BN_Weights(WeightsEnum): "acc@1": 74.218, "acc@5": 91.842, }, - default=True, ) + default = ImageNet1K_V1 def vgg11(weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: diff --git a/torchvision/prototype/models/video/resnet.py b/torchvision/prototype/models/video/resnet.py index c75f618a8b1..48c4293f0e1 100644 --- a/torchvision/prototype/models/video/resnet.py +++ b/torchvision/prototype/models/video/resnet.py @@ -68,8 +68,8 @@ class R3D_18_Weights(WeightsEnum): "acc@1": 52.75, "acc@5": 75.45, }, - default=True, ) + default = Kinetics400_V1 class MC3_18_Weights(WeightsEnum): @@ -81,8 +81,8 @@ class MC3_18_Weights(WeightsEnum): "acc@1": 53.90, "acc@5": 76.29, }, - default=True, ) + default = Kinetics400_V1 class R2Plus1D_18_Weights(WeightsEnum): @@ -94,8 +94,8 @@ class R2Plus1D_18_Weights(WeightsEnum): "acc@1": 57.50, "acc@5": 78.81, }, - default=True, ) + default = Kinetics400_V1 def r3d_18(weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet: