Cleanup namings of Multi-weights classes and enums #5003
Merged (5 commits) on Nov 29, 2021
54 changes: 38 additions & 16 deletions test/test_prototype_models.py
@@ -18,6 +18,12 @@ def _get_original_model(model_fn):
return module.__dict__[model_fn.__name__]


+def _get_parent_module(model_fn):
+    parent_module_name = ".".join(model_fn.__module__.split(".")[:-1])
+    module = importlib.import_module(parent_module_name)
+    return module
+
+
def _build_model(fn, **kwargs):
try:
model = fn(**kwargs)
@@ -29,27 +35,42 @@ def _build_model(fn, **kwargs):
return model.eval()


-def get_models_with_module_names(module):
-    module_name = module.__name__.split(".")[-1]
-    return [(fn, module_name) for fn in TM.get_models_from_module(module)]
-
-
@pytest.mark.parametrize(
"model_fn, name, weight",
[
-        (models.resnet50, "ImageNet1K_RefV1", models.ResNet50Weights.ImageNet1K_RefV1),
-        (models.resnet50, "default", models.ResNet50Weights.ImageNet1K_RefV2),
+        (models.resnet50, "ImageNet1K_V1", models.ResNet50_Weights.ImageNet1K_V1),
+        (models.resnet50, "default", models.ResNet50_Weights.ImageNet1K_V2),
(
models.quantization.resnet50,
"default",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V2,
),
(
models.quantization.resnet50,
"ImageNet1K_FBGEMM_RefV1",
models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1,
"ImageNet1K_FBGEMM_V1",
models.quantization.ResNet50_QuantizedWeights.ImageNet1K_FBGEMM_V1,
),
],
)
def test_get_weight(model_fn, name, weight):
assert models._api.get_weight(model_fn, name) == weight


+@pytest.mark.parametrize(
+    "model_fn",
+    TM.get_models_from_module(models)
+    + TM.get_models_from_module(models.detection)
+    + TM.get_models_from_module(models.quantization)
+    + TM.get_models_from_module(models.segmentation)
+    + TM.get_models_from_module(models.video),
+)
+def test_naming_conventions(model_fn):
+    model_name = model_fn.__name__
+    module = _get_parent_module(model_fn)
+    weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
+    assert model_name in set(x.replace(weights_name, "").lower() for x in module.__dict__ if x.endswith(weights_name))
+
+
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
@pytest.mark.parametrize("dev", cpu_and_gpu())
@run_if_test_with_prototype
@@ -85,16 +106,16 @@ def test_video_model(model_fn, dev):


@pytest.mark.parametrize(
"model_fn, module_name",
get_models_with_module_names(models)
+ get_models_with_module_names(models.detection)
+ get_models_with_module_names(models.quantization)
+ get_models_with_module_names(models.segmentation)
+ get_models_with_module_names(models.video),
"model_fn",
TM.get_models_from_module(models)
+ TM.get_models_from_module(models.detection)
+ TM.get_models_from_module(models.quantization)
+ TM.get_models_from_module(models.segmentation)
+ TM.get_models_from_module(models.video),
)
@pytest.mark.parametrize("dev", cpu_and_gpu())
@run_if_test_with_prototype
-def test_old_vs_new_factory(model_fn, module_name, dev):
+def test_old_vs_new_factory(model_fn, dev):
defaults = {
"models": {
"input_shape": (1, 3, 224, 224),
@@ -114,6 +135,7 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
},
}
model_name = model_fn.__name__
+    module_name = model_fn.__module__.split(".")[-2]
kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})}
input_shape = kwargs.pop("input_shape")
kwargs.pop("num_classes", None) # ignore this as it's an incompatible speed optimization for pre-trained models
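The new test_naming_conventions encodes the rule this PR introduces: every model builder must have a companion weights enum named after it, suffixed `_Weights` (or `_QuantizedWeights` under `models.quantization`). A minimal self-contained sketch of that rule; `matches_convention` is illustrative and not part of the test module:

def matches_convention(builder_name: str, enum_name: str, quantized: bool = False) -> bool:
    # e.g. "ResNet50_Weights" -> "resnet50" must equal the builder's name.
    suffix = "_QuantizedWeights" if quantized else "_Weights"
    return enum_name.endswith(suffix) and enum_name.replace(suffix, "").lower() == builder_name

assert matches_convention("resnet50", "ResNet50_Weights")
assert matches_convention("resnet50", "ResNet50_QuantizedWeights", quantized=True)
assert matches_convention("alexnet", "AlexNet_Weights")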
36 changes: 18 additions & 18 deletions torchvision/prototype/models/_api.py
@@ -7,11 +7,11 @@
from ..._internally_replaced_utils import load_state_dict_from_url


__all__ = ["Weights", "WeightEntry", "get_weight"]
__all__ = ["WeightsEnum", "Weights", "get_weight"]


@dataclass
-class WeightEntry:
+class Weights:
"""
This class is used to group important attributes associated with the pre-trained weights.

@@ -33,17 +33,17 @@ class WeightEntry:
default: bool


-class Weights(Enum):
+class WeightsEnum(Enum):
"""
This class is the parent class of all model weights. Each model building method receives an optional `weights`
parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
-    `WeightEntry`.
+    `Weights`.

Args:
-        value (WeightEntry): The data class entry with the weight information.
+        value (Weights): The data class entry with the weight information.
"""

-    def __init__(self, value: WeightEntry):
+    def __init__(self, value: Weights):
self._value_ = value

@classmethod
@@ -58,7 +58,7 @@ def verify(cls, obj: Any) -> Any:
return obj

@classmethod
-    def from_str(cls, value: str) -> "Weights":
+    def from_str(cls, value: str) -> "WeightsEnum":
for v in cls:
if v._name_ == value or (value == "default" and v.default):
return v
@@ -71,14 +71,14 @@ def __repr__(self):
return f"{self.__class__.__name__}.{self._name_}"

def __getattr__(self, name):
-        # Be able to fetch WeightEntry attributes directly
-        for f in fields(WeightEntry):
+        # Be able to fetch Weights attributes directly
+        for f in fields(Weights):
if f.name == name:
return object.__getattribute__(self.value, name)
return super().__getattr__(name)


-def get_weight(fn: Callable, weight_name: str) -> Weights:
+def get_weight(fn: Callable, weight_name: str) -> WeightsEnum:
"""
Gets the weight enum of a specific model builder method and weight name combination.

@@ -87,32 +87,32 @@ def get_weight(fn: Callable, weight_name: str) -> Weights:
weight_name (str): The name of the weight enum entry of the specific model.

Returns:
-        Weights: The requested weight enum.
+        WeightsEnum: The requested weight enum.
"""
sig = signature(fn)
if "weights" not in sig.parameters:
raise ValueError("The method is missing the 'weights' parameter.")

ann = signature(fn).parameters["weights"].annotation
-    weights_class = None
-    if isinstance(ann, type) and issubclass(ann, Weights):
-        weights_class = ann
+    weights_enum = None
+    if isinstance(ann, type) and issubclass(ann, WeightsEnum):
+        weights_enum = ann
else:
# handle cases like Union[Optional, T]
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
for t in ann.__args__: # type: ignore[union-attr]
-            if isinstance(t, type) and issubclass(t, Weights):
+            if isinstance(t, type) and issubclass(t, WeightsEnum):
# ensure the name exists. handles builders with multiple types of weights like in quantization
try:
t.from_str(weight_name)
except ValueError:
continue
-                weights_class = t
+                weights_enum = t
break

-    if weights_class is None:
+    if weights_enum is None:
raise ValueError(
"The weight class for the specific method couldn't be retrieved. Make sure the typing info is correct."
)

-    return weights_class.from_str(weight_name)
+    return weights_enum.from_str(weight_name)
4 changes: 2 additions & 2 deletions torchvision/prototype/models/_utils.py
@@ -1,10 +1,10 @@
import warnings
from typing import Any, Dict, Optional, TypeVar

-from ._api import Weights
+from ._api import WeightsEnum


-W = TypeVar("W", bound=Weights)
+W = TypeVar("W", bound=WeightsEnum)
V = TypeVar("V")


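Rebinding `W` to `WeightsEnum` keeps the deprecation helpers generic over whichever concrete enum a builder passes in. A hypothetical signature sketch (the helper bodies and parameter names are not shown in this diff):

from typing import Any, Dict, Optional, TypeVar

from torchvision.prototype.models._api import WeightsEnum

W = TypeVar("W", bound=WeightsEnum)

def _deprecated_param(kwargs: Dict[str, Any], deprecated: str, new: str, default: Optional[W]) -> Optional[W]:
    # Returning Optional[W] preserves the caller's concrete enum type,
    # e.g. Optional[AlexNet_Weights] rather than bare WeightsEnum.
    ...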
14 changes: 7 additions & 7 deletions torchvision/prototype/models/alexnet.py
@@ -5,16 +5,16 @@
from torchvision.transforms.functional import InterpolationMode

from ...models.alexnet import AlexNet
-from ._api import Weights, WeightEntry
+from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param


__all__ = ["AlexNet", "AlexNetWeights", "alexnet"]
__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]


-class AlexNetWeights(Weights):
-    ImageNet1K_RefV1 = WeightEntry(
+class AlexNet_Weights(WeightsEnum):
+    ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
@@ -29,12 +29,12 @@ class AlexNetWeights(Weights):
)


-def alexnet(weights: Optional[AlexNetWeights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
+def alexnet(weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
-        weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNetWeights.ImageNet1K_RefV1)
-    weights = AlexNetWeights.verify(weights)
+        weights = _deprecated_param(kwargs, "pretrained", "weights", AlexNet_Weights.ImageNet1K_V1)
+    weights = AlexNet_Weights.verify(weights)

if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
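Caller-side effect of the alexnet changes, based on the signatures in this diff (the prototype import path is assumed); both spellings build the same model, with `pretrained=True` routed through `_deprecated_param`:

from torchvision.prototype.models import alexnet, AlexNet_Weights

model = alexnet(weights=AlexNet_Weights.ImageNet1K_V1)  # new-style
model = alexnet(pretrained=True)  # deprecated path, remapped to ImageNet1K_V1 above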
54 changes: 27 additions & 27 deletions torchvision/prototype/models/densenet.py
@@ -7,25 +7,25 @@
from torchvision.transforms.functional import InterpolationMode

from ...models.densenet import DenseNet
-from ._api import Weights, WeightEntry
+from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _deprecated_param, _deprecated_positional, _ovewrite_named_param


__all__ = [
"DenseNet",
"DenseNet121Weights",
"DenseNet161Weights",
"DenseNet169Weights",
"DenseNet201Weights",
"DenseNet121_Weights",
"DenseNet161_Weights",
"DenseNet169_Weights",
"DenseNet201_Weights",
"densenet121",
"densenet161",
"densenet169",
"densenet201",
]


-def _load_state_dict(model: nn.Module, weights: Weights, progress: bool) -> None:
+def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None:
# '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
@@ -48,7 +48,7 @@ def _densenet(
growth_rate: int,
block_config: Tuple[int, int, int, int],
num_init_features: int,
-    weights: Optional[Weights],
+    weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> DenseNet:
@@ -71,8 +71,8 @@ def _densenet(
}


-class DenseNet121Weights(Weights):
-    ImageNet1K_Community = WeightEntry(
+class DenseNet121_Weights(WeightsEnum):
+    ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
@@ -84,8 +84,8 @@ class DenseNet121Weights(Weights):
)


-class DenseNet161Weights(Weights):
-    ImageNet1K_Community = WeightEntry(
+class DenseNet161_Weights(WeightsEnum):
+    ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
@@ -97,8 +97,8 @@ class DenseNet161Weights(Weights):
)


-class DenseNet169Weights(Weights):
-    ImageNet1K_Community = WeightEntry(
+class DenseNet169_Weights(WeightsEnum):
+    ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
@@ -110,8 +110,8 @@ class DenseNet169Weights(Weights):
)


-class DenseNet201Weights(Weights):
-    ImageNet1K_Community = WeightEntry(
+class DenseNet201_Weights(WeightsEnum):
+    ImageNet1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet201-c1103571.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
@@ -123,41 +123,41 @@ class DenseNet201Weights(Weights):
)


-def densenet121(weights: Optional[DenseNet121Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
+def densenet121(weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
-        weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121Weights.ImageNet1K_Community)
-    weights = DenseNet121Weights.verify(weights)
+        weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet121_Weights.ImageNet1K_V1)
+    weights = DenseNet121_Weights.verify(weights)

return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)


-def densenet161(weights: Optional[DenseNet161Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
+def densenet161(weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
-        weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161Weights.ImageNet1K_Community)
-    weights = DenseNet161Weights.verify(weights)
+        weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet161_Weights.ImageNet1K_V1)
+    weights = DenseNet161_Weights.verify(weights)

return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)


-def densenet169(weights: Optional[DenseNet169Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
+def densenet169(weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
-        weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169Weights.ImageNet1K_Community)
-    weights = DenseNet169Weights.verify(weights)
+        weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet169_Weights.ImageNet1K_V1)
+    weights = DenseNet169_Weights.verify(weights)

return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)


-def densenet201(weights: Optional[DenseNet201Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
+def densenet201(weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
if type(weights) == bool and weights:
_deprecated_positional(kwargs, "pretrained", "weights", True)
if "pretrained" in kwargs:
-        weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201Weights.ImageNet1K_Community)
-    weights = DenseNet201Weights.verify(weights)
+        weights = _deprecated_param(kwargs, "pretrained", "weights", DenseNet201_Weights.ImageNet1K_V1)
+    weights = DenseNet201_Weights.verify(weights)

return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
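The same migration applies to all four densenet builders; a before/after sketch for user code, grounded in the deprecation shims above (prototype import path assumed):

from torchvision.prototype.models import densenet121, DenseNet121_Weights

model = densenet121(pretrained=True)                            # old spelling, via _deprecated_param
model = densenet121(weights=DenseNet121_Weights.ImageNet1K_V1)  # new spelling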