Skip to content

Commit bf8efe6

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] Rename prototype weight names to comply with PEP8 (#5257)
Summary: * renamed ImageNet weights * renamed COCO weights * renamed COCO with VOC labels weights * renamed Kinetics 400 weights * rename default with DEFAULT * update test * fix typos * update test * update test * update test * indent as w was weight_enum * revert * Adding back the capitalization test Reviewed By: jdsgomes, prabhat00155 Differential Revision: D33739379 fbshipit-source-id: cec33c6bd6be59f4f44bd2adc9906eaa9e1c6696 Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 2f18f54 commit bf8efe6

36 files changed

+344
-340
lines changed

test/test_prototype_models.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,15 @@ def _build_model(fn, **kwargs):
5454
@pytest.mark.parametrize(
5555
"name, weight",
5656
[
57-
("ResNet50_Weights.ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1),
58-
("ResNet50_Weights.default", models.ResNet50_Weights.ImageNet1K_V2),
57+
("ResNet50_Weights.IMAGENET1K_V1", models.ResNet50_Weights.IMAGENET1K_V1),
58+
("ResNet50_Weights.DEFAULT", models.ResNet50_Weights.IMAGENET1K_V2),
5959
(
60-
"ResNet50_QuantizedWeights.default",
61-
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2,
60+
"ResNet50_QuantizedWeights.DEFAULT",
61+
models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2,
6262
),
6363
(
64-
"ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1",
65-
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1,
64+
"ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1",
65+
models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
6666
),
6767
],
6868
)
@@ -83,7 +83,7 @@ def test_naming_conventions(model_fn):
8383
weights_enum = _get_model_weights(model_fn)
8484
print(weights_enum)
8585
assert weights_enum is not None
86-
assert len(weights_enum) == 0 or hasattr(weights_enum, "default")
86+
assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")
8787

8888

8989
@pytest.mark.parametrize(
@@ -117,25 +117,29 @@ def test_schema_meta_validation(model_fn):
117117

118118
problematic_weights = {}
119119
incorrect_params = []
120+
bad_names = []
120121
for w in weights_enum:
121122
missing_fields = fields - set(w.meta.keys())
122123
if missing_fields:
123124
problematic_weights[w] = missing_fields
124-
if w == weights_enum.default:
125+
if w == weights_enum.DEFAULT:
125126
if module_name == "quantization":
126-
# parametes() cound doesn't work well with quantization, so we check against the non-quantized
127+
# parameters() count doesn't work well with quantization, so we check against the non-quantized
127128
unquantized_w = w.meta.get("unquantized")
128129
if unquantized_w is not None and w.meta.get("num_params") != unquantized_w.meta.get("num_params"):
129130
incorrect_params.append(w)
130131
else:
131132
if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
132133
incorrect_params.append(w)
133134
else:
134-
if w.meta.get("num_params") != weights_enum.default.meta.get("num_params"):
135+
if w.meta.get("num_params") != weights_enum.DEFAULT.meta.get("num_params"):
135136
incorrect_params.append(w)
137+
if not w.name.isupper():
138+
bad_names.append(w)
136139

137140
assert not problematic_weights
138141
assert not incorrect_params
142+
assert not bad_names
139143

140144

141145
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))

torchvision/prototype/models/_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __getattr__(self, name):
8080

8181
def get_weight(name: str) -> WeightsEnum:
8282
"""
83-
Gets the weight enum value by its full name. Example: "ResNet50_Weights.ImageNet1K_V1"
83+
Gets the weight enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"
8484
8585
Args:
8686
name (str): The name of the weight enum entry.

torchvision/prototype/models/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def inner_wrapper(*args: Any, **kwargs: Any) -> M:
7878
)
7979
if pretrained_arg:
8080
msg = (
81-
f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.default` "
81+
f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` "
8282
f"to get the most up-to-date weights."
8383
)
8484
warnings.warn(msg)

torchvision/prototype/models/alexnet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
class AlexNet_Weights(WeightsEnum):
17-
ImageNet1K_V1 = Weights(
17+
IMAGENET1K_V1 = Weights(
1818
url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
1919
transforms=partial(ImageNetEval, crop_size=224),
2020
meta={
@@ -31,10 +31,10 @@ class AlexNet_Weights(WeightsEnum):
3131
"acc@5": 79.066,
3232
},
3333
)
34-
default = ImageNet1K_V1
34+
DEFAULT = IMAGENET1K_V1
3535

3636

37-
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.ImageNet1K_V1))
37+
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
3838
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
3939
weights = AlexNet_Weights.verify(weights)
4040

torchvision/prototype/models/convnext.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def forward(self, x: Tensor) -> Tensor:
178178

179179

180180
class ConvNeXt_Tiny_Weights(WeightsEnum):
181-
ImageNet1K_V1 = Weights(
181+
IMAGENET1K_V1 = Weights(
182182
url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth",
183183
transforms=partial(ImageNetEval, crop_size=224, resize_size=236),
184184
meta={
@@ -195,10 +195,10 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
195195
"acc@5": 96.146,
196196
},
197197
)
198-
default = ImageNet1K_V1
198+
DEFAULT = IMAGENET1K_V1
199199

200200

201-
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.ImageNet1K_V1))
201+
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
202202
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
203203
r"""ConvNeXt model architecture from the
204204
`"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.

torchvision/prototype/models/densenet.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def _densenet(
7676

7777

7878
class DenseNet121_Weights(WeightsEnum):
79-
ImageNet1K_V1 = Weights(
79+
IMAGENET1K_V1 = Weights(
8080
url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
8181
transforms=partial(ImageNetEval, crop_size=224),
8282
meta={
@@ -86,11 +86,11 @@ class DenseNet121_Weights(WeightsEnum):
8686
"acc@5": 91.972,
8787
},
8888
)
89-
default = ImageNet1K_V1
89+
DEFAULT = IMAGENET1K_V1
9090

9191

9292
class DenseNet161_Weights(WeightsEnum):
93-
ImageNet1K_V1 = Weights(
93+
IMAGENET1K_V1 = Weights(
9494
url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
9595
transforms=partial(ImageNetEval, crop_size=224),
9696
meta={
@@ -100,11 +100,11 @@ class DenseNet161_Weights(WeightsEnum):
100100
"acc@5": 93.560,
101101
},
102102
)
103-
default = ImageNet1K_V1
103+
DEFAULT = IMAGENET1K_V1
104104

105105

106106
class DenseNet169_Weights(WeightsEnum):
107-
ImageNet1K_V1 = Weights(
107+
IMAGENET1K_V1 = Weights(
108108
url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
109109
transforms=partial(ImageNetEval, crop_size=224),
110110
meta={
@@ -114,11 +114,11 @@ class DenseNet169_Weights(WeightsEnum):
114114
"acc@5": 92.806,
115115
},
116116
)
117-
default = ImageNet1K_V1
117+
DEFAULT = IMAGENET1K_V1
118118

119119

120120
class DenseNet201_Weights(WeightsEnum):
121-
ImageNet1K_V1 = Weights(
121+
IMAGENET1K_V1 = Weights(
122122
url="https://download.pytorch.org/models/densenet201-c1103571.pth",
123123
transforms=partial(ImageNetEval, crop_size=224),
124124
meta={
@@ -128,31 +128,31 @@ class DenseNet201_Weights(WeightsEnum):
128128
"acc@5": 93.370,
129129
},
130130
)
131-
default = ImageNet1K_V1
131+
DEFAULT = IMAGENET1K_V1
132132

133133

134-
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.ImageNet1K_V1))
134+
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
135135
def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
136136
weights = DenseNet121_Weights.verify(weights)
137137

138138
return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
139139

140140

141-
@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.ImageNet1K_V1))
141+
@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1))
142142
def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
143143
weights = DenseNet161_Weights.verify(weights)
144144

145145
return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)
146146

147147

148-
@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.ImageNet1K_V1))
148+
@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1))
149149
def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
150150
weights = DenseNet169_Weights.verify(weights)
151151

152152
return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)
153153

154154

155-
@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.ImageNet1K_V1))
155+
@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1))
156156
def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
157157
weights = DenseNet201_Weights.verify(weights)
158158

torchvision/prototype/models/detection/faster_rcnn.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040

4141

4242
class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
43-
Coco_V1 = Weights(
43+
COCO_V1 = Weights(
4444
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
4545
transforms=CocoEval,
4646
meta={
@@ -50,11 +50,11 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
5050
"map": 37.0,
5151
},
5252
)
53-
default = Coco_V1
53+
DEFAULT = COCO_V1
5454

5555

5656
class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
57-
Coco_V1 = Weights(
57+
COCO_V1 = Weights(
5858
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
5959
transforms=CocoEval,
6060
meta={
@@ -64,11 +64,11 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
6464
"map": 32.8,
6565
},
6666
)
67-
default = Coco_V1
67+
DEFAULT = COCO_V1
6868

6969

7070
class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
71-
Coco_V1 = Weights(
71+
COCO_V1 = Weights(
7272
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
7373
transforms=CocoEval,
7474
meta={
@@ -78,12 +78,12 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
7878
"map": 22.8,
7979
},
8080
)
81-
default = Coco_V1
81+
DEFAULT = COCO_V1
8282

8383

8484
@handle_legacy_interface(
85-
weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.Coco_V1),
86-
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
85+
weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
86+
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
8787
)
8888
def fasterrcnn_resnet50_fpn(
8989
*,
@@ -113,7 +113,7 @@ def fasterrcnn_resnet50_fpn(
113113

114114
if weights is not None:
115115
model.load_state_dict(weights.get_state_dict(progress=progress))
116-
if weights == FasterRCNN_ResNet50_FPN_Weights.Coco_V1:
116+
if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
117117
overwrite_eps(model, 0.0)
118118

119119
return model
@@ -161,8 +161,8 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
161161

162162

163163
@handle_legacy_interface(
164-
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.Coco_V1),
165-
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
164+
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
165+
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
166166
)
167167
def fasterrcnn_mobilenet_v3_large_fpn(
168168
*,
@@ -192,8 +192,8 @@ def fasterrcnn_mobilenet_v3_large_fpn(
192192

193193

194194
@handle_legacy_interface(
195-
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.Coco_V1),
196-
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.ImageNet1K_V1),
195+
weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
196+
weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
197197
)
198198
def fasterrcnn_mobilenet_v3_large_320_fpn(
199199
*,

torchvision/prototype/models/detection/fcos.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum):
3838
"map": 39.2,
3939
},
4040
)
41-
default = COCO_V1
41+
DEFAULT = COCO_V1
4242

4343

4444
@handle_legacy_interface(
4545
weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1),
46-
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
46+
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
4747
)
4848
def fcos_resnet50_fpn(
4949
*,

torchvision/prototype/models/detection/keypoint_rcnn.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535

3636
class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
37-
Coco_Legacy = Weights(
37+
COCO_LEGACY = Weights(
3838
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
3939
transforms=CocoEval,
4040
meta={
@@ -45,7 +45,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
4545
"map_kp": 61.1,
4646
},
4747
)
48-
Coco_V1 = Weights(
48+
COCO_V1 = Weights(
4949
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
5050
transforms=CocoEval,
5151
meta={
@@ -56,17 +56,17 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
5656
"map_kp": 65.0,
5757
},
5858
)
59-
default = Coco_V1
59+
DEFAULT = COCO_V1
6060

6161

6262
@handle_legacy_interface(
6363
weights=(
6464
"pretrained",
65-
lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.Coco_Legacy
65+
lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY
6666
if kwargs["pretrained"] == "legacy"
67-
else KeypointRCNN_ResNet50_FPN_Weights.Coco_V1,
67+
else KeypointRCNN_ResNet50_FPN_Weights.COCO_V1,
6868
),
69-
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
69+
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
7070
)
7171
def keypointrcnn_resnet50_fpn(
7272
*,
@@ -101,7 +101,7 @@ def keypointrcnn_resnet50_fpn(
101101

102102
if weights is not None:
103103
model.load_state_dict(weights.get_state_dict(progress=progress))
104-
if weights == KeypointRCNN_ResNet50_FPN_Weights.Coco_V1:
104+
if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
105105
overwrite_eps(model, 0.0)
106106

107107
return model

torchvision/prototype/models/detection/mask_rcnn.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
27-
Coco_V1 = Weights(
27+
COCO_V1 = Weights(
2828
url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
2929
transforms=CocoEval,
3030
meta={
@@ -39,12 +39,12 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
3939
"map_mask": 34.6,
4040
},
4141
)
42-
default = Coco_V1
42+
DEFAULT = COCO_V1
4343

4444

4545
@handle_legacy_interface(
46-
weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.Coco_V1),
47-
weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1),
46+
weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1),
47+
weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
4848
)
4949
def maskrcnn_resnet50_fpn(
5050
*,
@@ -74,7 +74,7 @@ def maskrcnn_resnet50_fpn(
7474

7575
if weights is not None:
7676
model.load_state_dict(weights.get_state_dict(progress=progress))
77-
if weights == MaskRCNN_ResNet50_FPN_Weights.Coco_V1:
77+
if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1:
7878
overwrite_eps(model, 0.0)
7979

8080
return model

0 commit comments

Comments
 (0)