Skip to content

Rename prototype weight names to comply with PEP8 #5257

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jan 24, 2022
24 changes: 14 additions & 10 deletions test/test_prototype_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ def _build_model(fn, **kwargs):
@pytest.mark.parametrize(
"name, weight",
[
("ResNet50_Weights.ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1),
("ResNet50_Weights.default", models.ResNet50_Weights.ImageNet1K_V2),
("ResNet50_Weights.IMAGENET1K_V1", models.ResNet50_Weights.IMAGENET1K_V1),
("ResNet50_Weights.DEFAULT", models.ResNet50_Weights.IMAGENET1K_V2),
(
"ResNet50_QuantizedWeights.default",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2,
"ResNet50_QuantizedWeights.DEFAULT",
models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2,
),
(
"ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1,
"ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1",
models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
),
],
)
Expand All @@ -83,7 +83,7 @@ def test_naming_conventions(model_fn):
weights_enum = _get_model_weights(model_fn)
print(weights_enum)
assert weights_enum is not None
assert len(weights_enum) == 0 or hasattr(weights_enum, "default")
assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")


@pytest.mark.parametrize(
Expand Down Expand Up @@ -117,25 +117,29 @@ def test_schema_meta_validation(model_fn):

problematic_weights = {}
incorrect_params = []
bad_names = []
for w in weights_enum:
missing_fields = fields - set(w.meta.keys())
if missing_fields:
problematic_weights[w] = missing_fields
if w == weights_enum.default:
if w == weights_enum.DEFAULT:
if module_name == "quantization":
# parametes() cound doesn't work well with quantization, so we check against the non-quantized
# parameters() count doesn't work well with quantization, so we check against the non-quantized
unquantized_w = w.meta.get("unquantized")
if unquantized_w is not None and w.meta.get("num_params") != unquantized_w.meta.get("num_params"):
incorrect_params.append(w)
else:
if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
incorrect_params.append(w)
else:
if w.meta.get("num_params") != weights_enum.default.meta.get("num_params"):
if w.meta.get("num_params") != weights_enum.DEFAULT.meta.get("num_params"):
incorrect_params.append(w)
if not w.name.isupper():
bad_names.append(w)

assert not problematic_weights
assert not incorrect_params
assert not bad_names


@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __getattr__(self, name):

def get_weight(name: str) -> WeightsEnum:
"""
Gets the weight enum value by its full name. Example: "ResNet50_Weights.ImageNet1K_V1"
Gets the weight enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"

Args:
name (str): The name of the weight enum entry.
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def inner_wrapper(*args: Any, **kwargs: Any) -> M:
)
if pretrained_arg:
msg = (
f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.default` "
f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` "
f"to get the most up-to-date weights."
)
warnings.warn(msg)
Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/models/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


class AlexNet_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
Expand All @@ -31,10 +31,10 @@ class AlexNet_Weights(WeightsEnum):
"acc@5": 79.066,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1


@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
weights = AlexNet_Weights.verify(weights)

Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/models/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def forward(self, x: Tensor) -> Tensor:


class ConvNeXt_Tiny_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=236),
meta={
Expand All @@ -195,10 +195,10 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
"acc@5": 96.146,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1


@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
r"""ConvNeXt model architecture from the
`"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
Expand Down
24 changes: 12 additions & 12 deletions torchvision/prototype/models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _densenet(


class DenseNet121_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
Expand All @@ -86,11 +86,11 @@ class DenseNet121_Weights(WeightsEnum):
"acc@5": 91.972,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1


class DenseNet161_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
Expand All @@ -100,11 +100,11 @@ class DenseNet161_Weights(WeightsEnum):
"acc@5": 93.560,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1


class DenseNet169_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
Expand All @@ -114,11 +114,11 @@ class DenseNet169_Weights(WeightsEnum):
"acc@5": 92.806,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1


class DenseNet201_Weights(WeightsEnum):
ImageNet1K_V1 = Weights(
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet201-c1103571.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
Expand All @@ -128,31 +128,31 @@ class DenseNet201_Weights(WeightsEnum):
"acc@5": 93.370,
},
)
default = ImageNet1K_V1
DEFAULT = IMAGENET1K_V1


@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet121_Weights.verify(weights)

return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1))
def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet161_Weights.verify(weights)

return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1))
def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet169_Weights.verify(weights)

return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.ImageNet1K_V1))
@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1))
def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
weights = DenseNet201_Weights.verify(weights)

Expand Down
26 changes: 13 additions & 13 deletions torchvision/prototype/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@


class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
Coco_V1 = Weights(
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
transforms=CocoEval,
meta={
Expand All @@ -50,11 +50,11 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
"map": 37.0,
},
)
default = Coco_V1
DEFAULT = COCO_V1


class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
Coco_V1 = Weights(
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
transforms=CocoEval,
meta={
Expand All @@ -64,11 +64,11 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
"map": 32.8,
},
)
default = Coco_V1
DEFAULT = COCO_V1


class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
Coco_V1 = Weights(
COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
transforms=CocoEval,
meta={
Expand All @@ -78,12 +78,12 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
"map": 22.8,
},
)
default = Coco_V1
DEFAULT = COCO_V1


@handle_legacy_interface(
weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fasterrcnn_resnet50_fpn(
*,
Expand Down Expand Up @@ -113,7 +113,7 @@ def fasterrcnn_resnet50_fpn(

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == FasterRCNN_ResNet50_FPN_Weights.Coco_V1:
if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)

return model
Expand Down Expand Up @@ -161,8 +161,8 @@ def _fasterrcnn_mobilenet_v3_large_fpn(


@handle_legacy_interface(
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def fasterrcnn_mobilenet_v3_large_fpn(
*,
Expand Down Expand Up @@ -192,8 +192,8 @@ def fasterrcnn_mobilenet_v3_large_fpn(


@handle_legacy_interface(
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def fasterrcnn_mobilenet_v3_large_320_fpn(
*,
Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/models/detection/fcos.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum):
"map": 39.2,
},
)
default = COCO_V1
DEFAULT = COCO_V1


@handle_legacy_interface(
weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def fcos_resnet50_fpn(
*,
Expand Down
14 changes: 7 additions & 7 deletions torchvision/prototype/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@


class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
Coco_Legacy = Weights(
COCO_LEGACY = Weights(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
transforms=CocoEval,
meta={
Expand All @@ -45,7 +45,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
"map_kp": 61.1,
},
)
Coco_V1 = Weights(
COCO_V1 = Weights(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
transforms=CocoEval,
meta={
Expand All @@ -56,17 +56,17 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
"map_kp": 65.0,
},
)
default = Coco_V1
DEFAULT = COCO_V1


@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.Coco_Legacy
lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY
if kwargs["pretrained"] == "legacy"
else KeypointRCNN_ResNet50_FPN_Weights.Coco_V1,
else KeypointRCNN_ResNet50_FPN_Weights.COCO_V1,
),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def keypointrcnn_resnet50_fpn(
*,
Expand Down Expand Up @@ -101,7 +101,7 @@ def keypointrcnn_resnet50_fpn(

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == KeypointRCNN_ResNet50_FPN_Weights.Coco_V1:
if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)

return model
10 changes: 5 additions & 5 deletions torchvision/prototype/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
Coco_V1 = Weights(
COCO_V1 = Weights(
url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
transforms=CocoEval,
meta={
Expand All @@ -39,12 +39,12 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
"map_mask": 34.6,
},
)
default = Coco_V1
DEFAULT = COCO_V1


@handle_legacy_interface(
weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.Coco_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1),
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def maskrcnn_resnet50_fpn(
*,
Expand Down Expand Up @@ -74,7 +74,7 @@ def maskrcnn_resnet50_fpn(

if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
if weights == MaskRCNN_ResNet50_FPN_Weights.Coco_V1:
if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)

return model
Loading