Skip to content

Commit ca52568

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] Refactor the get_weights API (#5006)
Summary: * Change the `default` weights mechanism to sue Enum aliases. * Change `get_weights` to work with full Enum names and make it public. * Applying improvements from code review. Reviewed By: NicolasHug Differential Revision: D32759199 fbshipit-source-id: 13cfa6201125db29f099d2e3a73260d62341a205
1 parent ce9d9a5 commit ca52568

37 files changed

+140
-146
lines changed

references/classification/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,7 @@ def load_data(traindir, valdir, args):
158158
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
159159
)
160160
else:
161-
fn = PM.quantization.__dict__[args.model] if hasattr(args, "backend") else PM.__dict__[args.model]
162-
weights = PM._api.get_weight(fn, args.weights)
161+
weights = PM.get_weight(args.weights)
163162
preprocessing = weights.transforms()
164163

165164
dataset_test = torchvision.datasets.ImageFolder(

references/detection/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ def get_transform(train, args):
5353
elif not args.weights:
5454
return presets.DetectionPresetEval()
5555
else:
56-
fn = PM.detection.__dict__[args.model]
57-
weights = PM._api.get_weight(fn, args.weights)
56+
weights = PM.get_weight(args.weights)
5857
return weights.transforms()
5958

6059

references/segmentation/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ def get_transform(train, args):
3838
elif not args.weights:
3939
return presets.SegmentationPresetEval(base_size=520)
4040
else:
41-
fn = PM.segmentation.__dict__[args.model]
42-
weights = PM._api.get_weight(fn, args.weights)
41+
weights = PM.get_weight(args.weights)
4342
return weights.transforms()
4443

4544

references/video_classification/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,7 @@ def main(args):
160160
if not args.weights:
161161
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
162162
else:
163-
fn = PM.video.__dict__[args.model]
164-
weights = PM._api.get_weight(fn, args.weights)
163+
weights = PM.get_weight(args.weights)
165164
transform_test = weights.transforms()
166165

167166
if args.cache_dataset and os.path.exists(cache_path):

test/test_models.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@
2222

2323
def get_models_from_module(module):
2424
# TODO add a registration mechanism to torchvision.models
25-
return [v for k, v in module.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
25+
return [
26+
v
27+
for k, v in module.__dict__.items()
28+
if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight"
29+
]
2630

2731

2832
@pytest.fixture

test/test_prototype_models.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,19 @@ def _get_parent_module(model_fn):
2424
return module
2525

2626

27+
def _get_model_weights(model_fn):
28+
module = _get_parent_module(model_fn)
29+
weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
30+
try:
31+
return next(
32+
v
33+
for k, v in module.__dict__.items()
34+
if k.endswith(weights_name) and k.replace(weights_name, "").lower() == model_fn.__name__
35+
)
36+
except StopIteration:
37+
return None
38+
39+
2740
def _build_model(fn, **kwargs):
2841
try:
2942
model = fn(**kwargs)
@@ -36,24 +49,22 @@ def _build_model(fn, **kwargs):
3649

3750

3851
@pytest.mark.parametrize(
39-
"model_fn, name, weight",
52+
"name, weight",
4053
[
41-
(models.resnet50, "ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1),
42-
(models.resnet50, "default", models.ResNet50_Weights.ImageNet1K_V2),
54+
("ResNet50_Weights.ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1),
55+
("ResNet50_Weights.default", models.ResNet50_Weights.ImageNet1K_V2),
4356
(
44-
models.quantization.resnet50,
45-
"default",
57+
"ResNet50_QuantizedWeights.default",
4658
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2,
4759
),
4860
(
49-
models.quantization.resnet50,
50-
"ImageNet1K_FBGEMM_V1",
61+
"ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1",
5162
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1,
5263
),
5364
],
5465
)
55-
def test_get_weight(model_fn, name, weight):
56-
assert models._api.get_weight(model_fn, name) == weight
66+
def test_get_weight(name, weight):
67+
assert models.get_weight(name) == weight
5768

5869

5970
@pytest.mark.parametrize(
@@ -65,10 +76,9 @@ def test_get_weight(model_fn, name, weight):
6576
+ TM.get_models_from_module(models.video),
6677
)
6778
def test_naming_conventions(model_fn):
68-
model_name = model_fn.__name__
69-
module = _get_parent_module(model_fn)
70-
weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
71-
assert model_name in set(x.replace(weights_name, "").lower() for x in module.__dict__ if x.endswith(weights_name))
79+
weights_enum = _get_model_weights(model_fn)
80+
assert weights_enum is not None
81+
assert len(weights_enum) == 0 or hasattr(weights_enum, "default")
7282

7383

7484
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))

torchvision/prototype/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515
from . import quantization
1616
from . import segmentation
1717
from . import video
18+
from ._api import get_weight

torchvision/prototype/models/_api.py

Lines changed: 26 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import importlib
2+
import inspect
3+
import sys
14
from collections import OrderedDict
25
from dataclasses import dataclass, fields
36
from enum import Enum
4-
from inspect import signature
57
from typing import Any, Callable, Dict
68

79
from ..._internally_replaced_utils import load_state_dict_from_url
@@ -30,7 +32,6 @@ class Weights:
3032
url: str
3133
transforms: Callable
3234
meta: Dict[str, Any]
33-
default: bool
3435

3536

3637
class WeightsEnum(Enum):
@@ -50,7 +51,7 @@ def __init__(self, value: Weights):
5051
def verify(cls, obj: Any) -> Any:
5152
if obj is not None:
5253
if type(obj) is str:
53-
obj = cls.from_str(obj)
54+
obj = cls.from_str(obj.replace(cls.__name__ + ".", ""))
5455
elif not isinstance(obj, cls):
5556
raise TypeError(
5657
f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
@@ -59,8 +60,8 @@ def verify(cls, obj: Any) -> Any:
5960

6061
@classmethod
6162
def from_str(cls, value: str) -> "WeightsEnum":
62-
for v in cls:
63-
if v._name_ == value or (value == "default" and v.default):
63+
for k, v in cls.__members__.items():
64+
if k == value:
6465
return v
6566
raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")
6667

@@ -78,41 +79,35 @@ def __getattr__(self, name):
7879
return super().__getattr__(name)
7980

8081

81-
def get_weight(fn: Callable, weight_name: str) -> WeightsEnum:
82+
def get_weight(name: str) -> WeightsEnum:
8283
"""
83-
Gets the weight enum of a specific model builder method and weight name combination.
84+
Gets the weight enum value by its full name. Example: "ResNet50_Weights.ImageNet1K_V1"
8485
8586
Args:
86-
fn (Callable): The builder method used to create the model.
87-
weight_name (str): The name of the weight enum entry of the specific model.
87+
name (str): The name of the weight enum entry.
8888
8989
Returns:
9090
WeightsEnum: The requested weight enum.
9191
"""
92-
sig = signature(fn)
93-
if "weights" not in sig.parameters:
94-
raise ValueError("The method is missing the 'weights' parameter.")
92+
try:
93+
enum_name, value_name = name.split(".")
94+
except ValueError:
95+
raise ValueError(f"Invalid weight name provided: '{name}'.")
96+
97+
base_module_name = ".".join(sys.modules[__name__].__name__.split(".")[:-1])
98+
base_module = importlib.import_module(base_module_name)
99+
model_modules = [base_module] + [
100+
x[1] for x in inspect.getmembers(base_module, inspect.ismodule) if x[1].__file__.endswith("__init__.py")
101+
]
95102

96-
ann = signature(fn).parameters["weights"].annotation
97103
weights_enum = None
98-
if isinstance(ann, type) and issubclass(ann, WeightsEnum):
99-
weights_enum = ann
100-
else:
101-
# handle cases like Union[Optional, T]
102-
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
103-
for t in ann.__args__: # type: ignore[union-attr]
104-
if isinstance(t, type) and issubclass(t, WeightsEnum):
105-
# ensure the name exists. handles builders with multiple types of weights like in quantization
106-
try:
107-
t.from_str(weight_name)
108-
except ValueError:
109-
continue
110-
weights_enum = t
111-
break
104+
for m in model_modules:
105+
potential_class = m.__dict__.get(enum_name, None)
106+
if potential_class is not None and issubclass(potential_class, WeightsEnum):
107+
weights_enum = potential_class
108+
break
112109

113110
if weights_enum is None:
114-
raise ValueError(
115-
"The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct."
116-
)
111+
raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")
117112

118-
return weights_enum.from_str(weight_name)
113+
return weights_enum.from_str(value_name)

torchvision/prototype/models/alexnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ class AlexNet_Weights(WeightsEnum):
2525
"acc@1": 56.522,
2626
"acc@5": 79.066,
2727
},
28-
default=True,
2928
)
29+
default = ImageNet1K_V1
3030

3131

3232
def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:

torchvision/prototype/models/densenet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ class DenseNet121_Weights(WeightsEnum):
8080
"acc@1": 74.434,
8181
"acc@5": 91.972,
8282
},
83-
default=True,
8483
)
84+
default = ImageNet1K_V1
8585

8686

8787
class DenseNet161_Weights(WeightsEnum):
@@ -93,8 +93,8 @@ class DenseNet161_Weights(WeightsEnum):
9393
"acc@1": 77.138,
9494
"acc@5": 93.560,
9595
},
96-
default=True,
9796
)
97+
default = ImageNet1K_V1
9898

9999

100100
class DenseNet169_Weights(WeightsEnum):
@@ -106,8 +106,8 @@ class DenseNet169_Weights(WeightsEnum):
106106
"acc@1": 75.600,
107107
"acc@5": 92.806,
108108
},
109-
default=True,
110109
)
110+
default = ImageNet1K_V1
111111

112112

113113
class DenseNet201_Weights(WeightsEnum):
@@ -119,8 +119,8 @@ class DenseNet201_Weights(WeightsEnum):
119119
"acc@1": 76.896,
120120
"acc@5": 93.370,
121121
},
122-
default=True,
123122
)
123+
default = ImageNet1K_V1
124124

125125

126126
def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:

torchvision/prototype/models/detection/faster_rcnn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
4545
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
4646
"map": 37.0,
4747
},
48-
default=True,
4948
)
49+
default = Coco_V1
5050

5151

5252
class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
@@ -58,8 +58,8 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
5858
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
5959
"map": 32.8,
6060
},
61-
default=True,
6261
)
62+
default = Coco_V1
6363

6464

6565
class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
@@ -71,8 +71,8 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
7171
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
7272
"map": 22.8,
7373
},
74-
default=True,
7574
)
75+
default = Coco_V1
7676

7777

7878
def fasterrcnn_resnet50_fpn(

torchvision/prototype/models/detection/keypoint_rcnn.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
3535
"box_map": 50.6,
3636
"kp_map": 61.1,
3737
},
38-
default=False,
3938
)
4039
Coco_V1 = Weights(
4140
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
@@ -46,8 +45,8 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
4645
"box_map": 54.6,
4746
"kp_map": 65.0,
4847
},
49-
default=True,
5048
)
49+
default = Coco_V1
5150

5251

5352
def keypointrcnn_resnet50_fpn(

torchvision/prototype/models/detection/mask_rcnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
3434
"box_map": 37.9,
3535
"mask_map": 34.6,
3636
},
37-
default=True,
3837
)
38+
default = Coco_V1
3939

4040

4141
def maskrcnn_resnet50_fpn(

torchvision/prototype/models/detection/retinanet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
3434
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
3535
"map": 36.4,
3636
},
37-
default=True,
3837
)
38+
default = Coco_V1
3939

4040

4141
def retinanet_resnet50_fpn(

torchvision/prototype/models/detection/ssd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ class SSD300_VGG16_Weights(WeightsEnum):
3333
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16",
3434
"map": 25.1,
3535
},
36-
default=True,
3736
)
37+
default = Coco_V1
3838

3939

4040
def ssd300_vgg16(

torchvision/prototype/models/detection/ssdlite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
3838
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large",
3939
"map": 21.3,
4040
},
41-
default=True,
4241
)
42+
default = Coco_V1
4343

4444

4545
def ssdlite320_mobilenet_v3_large(

0 commit comments

Comments
 (0)