diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 52101368800..7bc83d9ffe1 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -126,7 +126,7 @@ def test_auto_augment(self, transform, input): ( transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]), itertools.chain.from_iterable( - fn(color_spaces=["rgb"], dtypes=[torch.float32]) + fn(color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]) for fn in [ make_images, make_vanilla_tensor_images, diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 409a855e23f..7802d9626b0 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -14,8 +14,6 @@ def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32): size = size or torch.randint(16, 33, (2,)).tolist() - if isinstance(color_space, str): - color_space = features.ColorSpace[color_space] num_channels = { features.ColorSpace.GRAYSCALE: 1, features.ColorSpace.RGB: 3, diff --git a/torchvision/_utils.py b/torchvision/_utils.py new file mode 100644 index 00000000000..adc27c3a9fc --- /dev/null +++ b/torchvision/_utils.py @@ -0,0 +1,17 @@ +import enum + + +class StrEnumMeta(enum.EnumMeta): + auto = enum.auto + + def from_str(self, member: str): + try: + return self[member] + except KeyError: + # TODO: use `add_suggestion` from torchvision.prototype.utils._internal to improve the error message as + # soon as it is migrated. + raise ValueError(f"Unknown value '{member}' for {self.__name__}.") from None + + +class StrEnum(enum.Enum, metaclass=StrEnumMeta): + pass diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index 5b60d7ee55c..658a3691e90 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -3,7 +3,7 @@ from typing import Any, Tuple, Union, Optional import torch -from torchvision.prototype.utils._internal import StrEnum +from torchvision._utils import StrEnum from ._feature import _Feature @@ -30,7 +30,7 @@ def __new__( bounding_box = super().__new__(cls, data, dtype=dtype, device=device) if isinstance(format, str): - format = BoundingBoxFormat[format] + format = BoundingBoxFormat.from_str(format.upper()) bounding_box._metadata.update(dict(format=format, image_size=image_size)) diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 5ecc4cbedb7..511036b0677 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -4,7 +4,7 @@ from typing import Any, Optional, Union, Tuple, cast import torch -from torchvision.prototype.utils._internal import StrEnum +from torchvision._utils import StrEnum from torchvision.transforms.functional import to_pil_image from torchvision.utils import draw_bounding_boxes from torchvision.utils import make_grid @@ -14,9 +14,9 @@ class ColorSpace(StrEnum): - OTHER = 0 - GRAYSCALE = 1 - RGB = 3 + OTHER = StrEnum.auto() + GRAYSCALE = StrEnum.auto() + RGB = StrEnum.auto() class Image(_Feature): @@ -37,7 +37,7 @@ def __new__( if color_space == ColorSpace.OTHER: warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.") elif isinstance(color_space, str): - color_space = ColorSpace[color_space] + color_space = ColorSpace.from_str(color_space.upper()) image._metadata.update(dict(color_space=color_space)) diff --git a/torchvision/prototype/models/_api.py b/torchvision/prototype/models/_api.py index 4ba0ee05f08..85b280a7dfc 100644 --- a/torchvision/prototype/models/_api.py +++ b/torchvision/prototype/models/_api.py @@ -3,9 +3,10 @@ import sys from collections import OrderedDict from dataclasses import dataclass, fields -from enum import Enum from typing import Any, Callable, Dict +from torchvision._utils import StrEnum + from ..._internally_replaced_utils import load_state_dict_from_url @@ -34,7 +35,7 @@ class Weights: meta: Dict[str, Any] -class WeightsEnum(Enum): +class WeightsEnum(StrEnum): """ This class is the parent class of all model weights. Each model building method receives an optional `weights` parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type @@ -58,12 +59,6 @@ def verify(cls, obj: Any) -> Any: ) return obj - @classmethod - def from_str(cls, value: str) -> "WeightsEnum": - if value in cls.__members__: - return cls.__members__[value] - raise ValueError(f"Invalid value {value} for enum {cls.__name__}.") - def get_state_dict(self, progress: bool) -> OrderedDict: return load_state_dict_from_url(self.url, progress=progress) diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index 366a19f2bbc..864bff9ce02 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -1,6 +1,5 @@ import collections.abc import difflib -import enum import functools import inspect import io @@ -31,7 +30,6 @@ import torch __all__ = [ - "StrEnum", "sequence_to_str", "add_suggestion", "FrozenMapping", @@ -45,17 +43,6 @@ ] -class StrEnumMeta(enum.EnumMeta): - auto = enum.auto - - def __getitem__(self, item): - return super().__getitem__(item.upper() if isinstance(item, str) else item) - - -class StrEnum(enum.Enum, metaclass=StrEnumMeta): - pass - - def sequence_to_str(seq: Sequence, separate_last: str = "") -> str: if not seq: return ""