Skip to content
Merged
24 changes: 14 additions & 10 deletions test/test_prototype_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ def _build_model(fn, **kwargs):
@pytest.mark.parametrize(
"name, weight",
[
("ResNet50_Weights.ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1),
("ResNet50_Weights.default", models.ResNet50_Weights.ImageNet1K_V2),
("ResNet50_Weights.IMAGENET1K_V1", models.ResNet50_Weights.IMAGENET1K_V1),
("ResNet50_Weights.DEFAULT", models.ResNet50_Weights.IMAGENET1K_V2),
(
"ResNet50_QuantizedWeights.default",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2,
"ResNet50_QuantizedWeights.DEFAULT",
models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2,
),
(
"ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1,
"ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1",
models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
),
],
)
Expand All @@ -83,7 +83,7 @@ def test_naming_conventions(model_fn):
weights_enum = _get_model_weights(model_fn)
print(weights_enum)
assert weights_enum is not None
assert len(weights_enum) == 0 or hasattr(weights_enum, "default")
assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")


@pytest.mark.parametrize(
Expand Down Expand Up @@ -117,25 +117,29 @@ def test_schema_meta_validation(model_fn):

problematic_weights = {}
incorrect_params = []
bad_names = []
for w in weights_enum:
missing_fields = fields - set(w.meta.keys())
if missing_fields:
problematic_weights[w] = missing_fields
if w == weights_enum.default:
if w == weights_enum.DEFAULT:
if module_name == "quantization":
# parametes() cound doesn't work well with quantization, so we check against the non-quantized
# parameters() count doesn't work well with quantization, so we check against the non-quantized
unquantized_w = w.meta.get("unquantized")
if unquantized_w is not None and w.meta.get("num_params") != unquantized_w.meta.get("num_params"):
incorrect_params.append(w)
else:
if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
incorrect_params.append(w)
else:
if w.meta.get("num_params") != weights_enum.default.meta.get("num_params"):
if w.meta.get("num_params") != weights_enum.DEFAULT.meta.get("num_params"):
incorrect_params.append(w)
if not w.name.isupper():
bad_names.append(w)

assert not problematic_weights
assert not incorrect_params
assert not bad_names


@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __getattr__(self, name):

def get_weight(name: str) -> WeightsEnum:
"""
Gets the weight enum value by its full name. Example: "ResNet50_Weights.ImageNet1K_V1"
Gets the weight enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"

Args:
name (str): The name of the weight enum entry.
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def inner_wrapper(*args: Any, **kwargs: Any) -> M:
)
if pretrained_arg:
msg = (
f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.default` "
f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` "
f"to get the most up-to-date weights."
)
warnings.warn(msg)
Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/models/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


class AlexNet_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
Expand All @@ -31,10 +31,10 @@ class AlexNet_Weights(WeightsEnum):
"acc@5": 79.066,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1


@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
weights = AlexNet_Weights.verify(weights)

Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/models/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def forward(self, x: Tensor) -> Tensor:


class ConvNeXt_Tiny_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=236),
meta={
Expand All @@ -195,10 +195,10 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
"acc@5": 96.146,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1


@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
r"""ConvNeXt model architecture from the
`"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
Expand Down
24 changes: 12 additions & 12 deletions torchvision/prototype/models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _densenet(


class DenseNet121_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
Expand All @@ -86,11 +86,11 @@ class DenseNet121_Weights(WeightsEnum):
"acc@5": 91.972,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1


class DenseNet161_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
Expand All @@ -100,11 +100,11 @@ class DenseNet161_Weights(WeightsEnum):
"acc@5": 93.560,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1


class DenseNet169_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
Expand All @@ -114,11 +114,11 @@ class DenseNet169_Weights(WeightsEnum):
"acc@5": 92.806,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1


class DenseNet201_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet201-c1103571.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
Expand All @@ -128,31 +128,31 @@ class DenseNet201_Weights(WeightsEnum):
"acc@5": 93.370,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1


@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet121_Weights.verify(weights)

return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1))
def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet161_Weights.verify(weights)

return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1))
def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet169_Weights.verify(weights)

return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1))
def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet201_Weights.verify(weights)

Expand Down
26 changes: 13 additions & 13 deletions torchvision/prototype/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@


class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
Coco_V1 = Weights(
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
transforms=CocoEval,
meta={
Expand All @@ -50,11 +50,11 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
"map": 37.0,
},
)
default = Coco_V1
DEFAULT = COCO_V1


class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
Coco_V1 = Weights(
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
transforms=CocoEval,
meta={
Expand All @@ -64,11 +64,11 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
"map": 32.8,
},
)
default = Coco_V1
DEFAULT = COCO_V1


class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
Coco_V1 = Weights(
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
transforms=CocoEval,
meta={
Expand All @@ -78,12 +78,12 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
"map": 22.8,
},
)
default = Coco_V1
DEFAULT = COCO_V1


@handle_legacy_interface(
weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fasterrcnn_resnet50_fpn(
*,
Expand Down Expand Up @@ -113,7 +113,7 @@ def fasterrcnn_resnet50_fpn(

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == FasterRCNN_ResNet50_FPN_Weights.Coco_V1:
if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)

return model
Expand Down Expand Up @@ -161,8 +161,8 @@ def _fasterrcnn_mobilenet_v3_large_fpn(


@handle_legacy_interface(
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def fasterrcnn_mobilenet_v3_large_fpn(
*,
Expand Down Expand Up @@ -192,8 +192,8 @@ def fasterrcnn_mobilenet_v3_large_fpn(


@handle_legacy_interface(
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def fasterrcnn_mobilenet_v3_large_320_fpn(
*,
Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/models/detection/fcos.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum):
"map": 39.2,
},
)
default = COCO_V1
DEFAULT = COCO_V1


@handle_legacy_interface(
weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fcos_resnet50_fpn(
*,
Expand Down
14 changes: 7 additions & 7 deletions torchvision/prototype/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@


class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
Coco_Legacy = Weights(
COCO_LEGACY = Weights(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
transforms=CocoEval,
meta={
Expand All @@ -45,7 +45,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
"map_kp": 61.1,
},
)
Coco_V1 = Weights(
COCO_V1 = Weights(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
transforms=CocoEval,
meta={
Expand All @@ -56,17 +56,17 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
"map_kp": 65.0,
},
)
default = Coco_V1
DEFAULT = COCO_V1


@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.Coco_Legacy
lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY
if kwargs["pretrained"] == "legacy"
else KeypointRCNN_ResNet50_FPN_Weights.Coco_V1,
else KeypointRCNN_ResNet50_FPN_Weights.COCO_V1,
),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def keypointrcnn_resnet50_fpn(
*,
Expand Down Expand Up @@ -101,7 +101,7 @@ def keypointrcnn_resnet50_fpn(

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == KeypointRCNN_ResNet50_FPN_Weights.Coco_V1:
if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)

return model
10 changes: 5 additions & 5 deletions torchvision/prototype/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
Coco_V1 = Weights(
COCO_V1 = Weights(
url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
transforms=CocoEval,
meta={
Expand All @@ -39,12 +39,12 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
"map_mask": 34.6,
},
)
default = Coco_V1
DEFAULT = COCO_V1


@handle_legacy_interface(
weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def maskrcnn_resnet50_fpn(
*,
Expand Down Expand Up @@ -74,7 +74,7 @@ def maskrcnn_resnet50_fpn(

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == MaskRCNN_ResNet50_FPN_Weights.Coco_V1:
if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)

return model
Loading