Skip to content

Commit 1deb2ec

Browse files
authored
Cleanup Models prototype implementation (#4940)
* Disable WeightEntry to pass-through `Weights.verify()` * Rename `Weights.state_dict()` to `Weights.get_state_dict()` * Add TODO for missing doc. * Moving warning messages for googlenet. * Upper-case global `_COMMON_META` var * Replace argument with parameter in all warnings.
1 parent 30f4d10 commit 1deb2ec

30 files changed

+238
-238
lines changed

torchvision/prototype/models/_api.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def verify(cls, obj: Any) -> Any:
5050
if obj is not None:
5151
if type(obj) is str:
5252
obj = cls.from_str(obj)
53-
elif not isinstance(obj, cls) and not isinstance(obj, WeightEntry):
53+
elif not isinstance(obj, cls):
5454
raise TypeError(
5555
f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
5656
)
@@ -63,7 +63,7 @@ def from_str(cls, value: str) -> "Weights":
6363
return v
6464
raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")
6565

66-
def state_dict(self, progress: bool) -> OrderedDict:
66+
def get_state_dict(self, progress: bool) -> OrderedDict:
6767
return load_state_dict_from_url(self.url, progress=progress)
6868

6969
def __repr__(self):
@@ -90,7 +90,7 @@ def get_weight(fn: Callable, weight_name: str) -> Weights:
9090
"""
9191
sig = signature(fn)
9292
if "weights" not in sig.parameters:
93-
raise ValueError("The method is missing the 'weights' argument.")
93+
raise ValueError("The method is missing the 'weights' parameter.")
9494

9595
ann = signature(fn).parameters["weights"].annotation
9696
weights_class = None

torchvision/prototype/models/alexnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class AlexNetWeights(Weights):
3030

3131
def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
3232
if "pretrained" in kwargs:
33-
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
33+
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
3434
weights = AlexNetWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
3535
weights = AlexNetWeights.verify(weights)
3636
if weights is not None:
@@ -39,6 +39,6 @@ def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **k
3939
model = AlexNet(**kwargs)
4040

4141
if weights is not None:
42-
model.load_state_dict(weights.state_dict(progress=progress))
42+
model.load_state_dict(weights.get_state_dict(progress=progress))
4343

4444
return model

torchvision/prototype/models/densenet.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _load_state_dict(model: nn.Module, weights: Weights, progress: bool) -> None
3434
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
3535
)
3636

37-
state_dict = weights.state_dict(progress=progress)
37+
state_dict = weights.get_state_dict(progress=progress)
3838
for key in list(state_dict.keys()):
3939
res = pattern.match(key)
4040
if res:
@@ -63,11 +63,11 @@ def _densenet(
6363
return model
6464

6565

66-
_common_meta = {
66+
_COMMON_META = {
6767
"size": (224, 224),
6868
"categories": _IMAGENET_CATEGORIES,
6969
"interpolation": InterpolationMode.BILINEAR,
70-
"recipe": None, # weights ported from LuaTorch
70+
"recipe": None, # TODO: add here a URL to documentation stating that the weights were ported from LuaTorch
7171
}
7272

7373

@@ -76,7 +76,7 @@ class DenseNet121Weights(Weights):
7676
url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
7777
transforms=partial(ImageNetEval, crop_size=224),
7878
meta={
79-
**_common_meta,
79+
**_COMMON_META,
8080
"acc@1": 74.434,
8181
"acc@5": 91.972,
8282
},
@@ -88,7 +88,7 @@ class DenseNet161Weights(Weights):
8888
url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
8989
transforms=partial(ImageNetEval, crop_size=224),
9090
meta={
91-
**_common_meta,
91+
**_COMMON_META,
9292
"acc@1": 77.138,
9393
"acc@5": 93.560,
9494
},
@@ -100,7 +100,7 @@ class DenseNet169Weights(Weights):
100100
url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
101101
transforms=partial(ImageNetEval, crop_size=224),
102102
meta={
103-
**_common_meta,
103+
**_COMMON_META,
104104
"acc@1": 75.600,
105105
"acc@5": 92.806,
106106
},
@@ -112,7 +112,7 @@ class DenseNet201Weights(Weights):
112112
url="https://download.pytorch.org/models/densenet201-c1103571.pth",
113113
transforms=partial(ImageNetEval, crop_size=224),
114114
meta={
115-
**_common_meta,
115+
**_COMMON_META,
116116
"acc@1": 76.896,
117117
"acc@5": 93.370,
118118
},
@@ -121,7 +121,7 @@ class DenseNet201Weights(Weights):
121121

122122
def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
123123
if "pretrained" in kwargs:
124-
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
124+
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
125125
weights = DenseNet121Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
126126
weights = DenseNet121Weights.verify(weights)
127127

@@ -130,7 +130,7 @@ def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = T
130130

131131
def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
132132
if "pretrained" in kwargs:
133-
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
133+
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
134134
weights = DenseNet161Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
135135
weights = DenseNet161Weights.verify(weights)
136136

@@ -139,7 +139,7 @@ def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = T
139139

140140
def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
141141
if "pretrained" in kwargs:
142-
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
142+
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
143143
weights = DenseNet169Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
144144
weights = DenseNet169Weights.verify(weights)
145145

@@ -148,7 +148,7 @@ def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = T
148148

149149
def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
150150
if "pretrained" in kwargs:
151-
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
151+
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
152152
weights = DenseNet201Weights.ImageNet1K_Community if kwargs.pop("pretrained") else None
153153
weights = DenseNet201Weights.verify(weights)
154154

torchvision/prototype/models/detection/faster_rcnn.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
]
3131

3232

33-
_common_meta = {
33+
_COMMON_META = {
3434
"categories": _COCO_CATEGORIES,
3535
"interpolation": InterpolationMode.BILINEAR,
3636
}
@@ -41,7 +41,7 @@ class FasterRCNNResNet50FPNWeights(Weights):
4141
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
4242
transforms=CocoEval,
4343
meta={
44-
**_common_meta,
44+
**_COMMON_META,
4545
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
4646
"map": 37.0,
4747
},
@@ -53,7 +53,7 @@ class FasterRCNNMobileNetV3LargeFPNWeights(Weights):
5353
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
5454
transforms=CocoEval,
5555
meta={
56-
**_common_meta,
56+
**_COMMON_META,
5757
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
5858
"map": 32.8,
5959
},
@@ -65,7 +65,7 @@ class FasterRCNNMobileNetV3Large320FPNWeights(Weights):
6565
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
6666
transforms=CocoEval,
6767
meta={
68-
**_common_meta,
68+
**_COMMON_META,
6969
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
7070
"map": 22.8,
7171
},
@@ -81,11 +81,11 @@ def fasterrcnn_resnet50_fpn(
8181
**kwargs: Any,
8282
) -> FasterRCNN:
8383
if "pretrained" in kwargs:
84-
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
84+
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
8585
weights = FasterRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
8686
weights = FasterRCNNResNet50FPNWeights.verify(weights)
8787
if "pretrained_backbone" in kwargs:
88-
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
88+
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
8989
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
9090
weights_backbone = ResNet50Weights.verify(weights_backbone)
9191

@@ -102,7 +102,7 @@ def fasterrcnn_resnet50_fpn(
102102
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
103103

104104
if weights is not None:
105-
model.load_state_dict(weights.state_dict(progress=progress))
105+
model.load_state_dict(weights.get_state_dict(progress=progress))
106106
if weights == FasterRCNNResNet50FPNWeights.Coco_RefV1:
107107
overwrite_eps(model, 0.0)
108108

@@ -142,7 +142,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
142142
)
143143

144144
if weights is not None:
145-
model.load_state_dict(weights.state_dict(progress=progress))
145+
model.load_state_dict(weights.get_state_dict(progress=progress))
146146

147147
return model
148148

@@ -156,11 +156,11 @@ def fasterrcnn_mobilenet_v3_large_fpn(
156156
**kwargs: Any,
157157
) -> FasterRCNN:
158158
if "pretrained" in kwargs:
159-
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
159+
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
160160
weights = FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
161161
weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights)
162162
if "pretrained_backbone" in kwargs:
163-
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
163+
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
164164
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
165165
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
166166

@@ -188,11 +188,11 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
188188
**kwargs: Any,
189189
) -> FasterRCNN:
190190
if "pretrained" in kwargs:
191-
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
191+
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
192192
weights = FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
193193
weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights)
194194
if "pretrained_backbone" in kwargs:
195-
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
195+
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
196196
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
197197
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
198198

torchvision/prototype/models/detection/keypoint_rcnn.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@
2222
]
2323

2424

25-
_common_meta = {"categories": _COCO_PERSON_CATEGORIES, "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES}
25+
_COMMON_META = {"categories": _COCO_PERSON_CATEGORIES, "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES}
2626

2727

2828
class KeypointRCNNResNet50FPNWeights(Weights):
2929
Coco_RefV1_Legacy = WeightEntry(
3030
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
3131
transforms=CocoEval,
3232
meta={
33-
**_common_meta,
33+
**_COMMON_META,
3434
"recipe": "https://github.com/pytorch/vision/issues/1606",
3535
"box_map": 50.6,
3636
"kp_map": 61.1,
@@ -40,7 +40,7 @@ class KeypointRCNNResNet50FPNWeights(Weights):
4040
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
4141
transforms=CocoEval,
4242
meta={
43-
**_common_meta,
43+
**_COMMON_META,
4444
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
4545
"box_map": 54.6,
4646
"kp_map": 65.0,
@@ -58,7 +58,7 @@ def keypointrcnn_resnet50_fpn(
5858
**kwargs: Any,
5959
) -> KeypointRCNN:
6060
if "pretrained" in kwargs:
61-
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
61+
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
6262
pretrained = kwargs.pop("pretrained")
6363
if type(pretrained) == str and pretrained == "legacy":
6464
weights = KeypointRCNNResNet50FPNWeights.Coco_RefV1_Legacy
@@ -68,7 +68,7 @@ def keypointrcnn_resnet50_fpn(
6868
weights = None
6969
weights = KeypointRCNNResNet50FPNWeights.verify(weights)
7070
if "pretrained_backbone" in kwargs:
71-
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
71+
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
7272
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
7373
weights_backbone = ResNet50Weights.verify(weights_backbone)
7474

@@ -86,7 +86,7 @@ def keypointrcnn_resnet50_fpn(
8686
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
8787

8888
if weights is not None:
89-
model.load_state_dict(weights.state_dict(progress=progress))
89+
model.load_state_dict(weights.get_state_dict(progress=progress))
9090
if weights == KeypointRCNNResNet50FPNWeights.Coco_RefV1:
9191
overwrite_eps(model, 0.0)
9292

torchvision/prototype/models/detection/mask_rcnn.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@ def maskrcnn_resnet50_fpn(
4646
**kwargs: Any,
4747
) -> MaskRCNN:
4848
if "pretrained" in kwargs:
49-
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
49+
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
5050
weights = MaskRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
5151
weights = MaskRCNNResNet50FPNWeights.verify(weights)
5252
if "pretrained_backbone" in kwargs:
53-
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
53+
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
5454
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
5555
weights_backbone = ResNet50Weights.verify(weights_backbone)
5656

@@ -67,7 +67,7 @@ def maskrcnn_resnet50_fpn(
6767
model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)
6868

6969
if weights is not None:
70-
model.load_state_dict(weights.state_dict(progress=progress))
70+
model.load_state_dict(weights.get_state_dict(progress=progress))
7171
if weights == MaskRCNNResNet50FPNWeights.Coco_RefV1:
7272
overwrite_eps(model, 0.0)
7373

torchvision/prototype/models/detection/retinanet.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@ def retinanet_resnet50_fpn(
4646
**kwargs: Any,
4747
) -> RetinaNet:
4848
if "pretrained" in kwargs:
49-
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
49+
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
5050
weights = RetinaNetResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
5151
weights = RetinaNetResNet50FPNWeights.verify(weights)
5252
if "pretrained_backbone" in kwargs:
53-
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
53+
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
5454
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
5555
weights_backbone = ResNet50Weights.verify(weights_backbone)
5656

@@ -70,7 +70,7 @@ def retinanet_resnet50_fpn(
7070
model = RetinaNet(backbone, num_classes, **kwargs)
7171

7272
if weights is not None:
73-
model.load_state_dict(weights.state_dict(progress=progress))
73+
model.load_state_dict(weights.get_state_dict(progress=progress))
7474
if weights == RetinaNetResNet50FPNWeights.Coco_RefV1:
7575
overwrite_eps(model, 0.0)
7676

torchvision/prototype/models/detection/ssd.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,16 @@ def ssd300_vgg16(
4444
**kwargs: Any,
4545
) -> SSD:
4646
if "pretrained" in kwargs:
47-
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
47+
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
4848
weights = SSD300VGG16Weights.Coco_RefV1 if kwargs.pop("pretrained") else None
4949
weights = SSD300VGG16Weights.verify(weights)
5050
if "pretrained_backbone" in kwargs:
51-
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
51+
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
5252
weights_backbone = VGG16Weights.ImageNet1K_Features if kwargs.pop("pretrained_backbone") else None
5353
weights_backbone = VGG16Weights.verify(weights_backbone)
5454

5555
if "size" in kwargs:
56-
warnings.warn("The size of the model is already fixed; ignoring the argument.")
56+
warnings.warn("The size of the model is already fixed; ignoring the parameter.")
5757

5858
if weights is not None:
5959
weights_backbone = None
@@ -81,6 +81,6 @@ def ssd300_vgg16(
8181
model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)
8282

8383
if weights is not None:
84-
model.load_state_dict(weights.state_dict(progress=progress))
84+
model.load_state_dict(weights.get_state_dict(progress=progress))
8585

8686
return model

torchvision/prototype/models/detection/ssdlite.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,16 @@ def ssdlite320_mobilenet_v3_large(
5050
**kwargs: Any,
5151
) -> SSD:
5252
if "pretrained" in kwargs:
53-
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
53+
warnings.warn("The parameter pretrained is deprecated, please use weights instead.")
5454
weights = SSDlite320MobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
5555
weights = SSDlite320MobileNetV3LargeFPNWeights.verify(weights)
5656
if "pretrained_backbone" in kwargs:
57-
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
57+
warnings.warn("The parameter pretrained_backbone is deprecated, please use weights_backbone instead.")
5858
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
5959
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
6060

6161
if "size" in kwargs:
62-
warnings.warn("The size of the model is already fixed; ignoring the argument.")
62+
warnings.warn("The size of the model is already fixed; ignoring the parameter.")
6363

6464
if weights is not None:
6565
weights_backbone = None
@@ -114,6 +114,6 @@ def ssdlite320_mobilenet_v3_large(
114114
)
115115

116116
if weights is not None:
117-
model.load_state_dict(weights.state_dict(progress=progress))
117+
model.load_state_dict(weights.get_state_dict(progress=progress))
118118

119119
return model

0 commit comments

Comments
 (0)