Skip to content

Commit 71d2bb0

Browse files
authored
improve StrEnum (#5512)
* improve StrEnum * use StrEnum for model weights * fix test * migrate StrEnum to main area
1 parent e6d82f7 commit 71d2bb0

File tree

7 files changed

+28
-31
lines changed

7 files changed

+28
-31
lines changed

test/test_prototype_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def test_auto_augment(self, transform, input):
126126
(
127127
transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
128128
itertools.chain.from_iterable(
129-
fn(color_spaces=["rgb"], dtypes=[torch.float32])
129+
fn(color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32])
130130
for fn in [
131131
make_images,
132132
make_vanilla_tensor_images,

test/test_prototype_transforms_functional.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32):
1515
size = size or torch.randint(16, 33, (2,)).tolist()
1616

17-
if isinstance(color_space, str):
18-
color_space = features.ColorSpace[color_space]
1917
num_channels = {
2018
features.ColorSpace.GRAYSCALE: 1,
2119
features.ColorSpace.RGB: 3,

torchvision/_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import enum
2+
3+
4+
class StrEnumMeta(enum.EnumMeta):
5+
auto = enum.auto
6+
7+
def from_str(self, member: str):
8+
try:
9+
return self[member]
10+
except KeyError:
11+
# TODO: use `add_suggestion` from torchvision.prototype.utils._internal to improve the error message as
12+
# soon as it is migrated.
13+
raise ValueError(f"Unknown value '{member}' for {self.__name__}.") from None
14+
15+
16+
class StrEnum(enum.Enum, metaclass=StrEnumMeta):
17+
pass

torchvision/prototype/features/_bounding_box.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Tuple, Union, Optional
44

55
import torch
6-
from torchvision.prototype.utils._internal import StrEnum
6+
from torchvision._utils import StrEnum
77

88
from ._feature import _Feature
99

@@ -30,7 +30,7 @@ def __new__(
3030
bounding_box = super().__new__(cls, data, dtype=dtype, device=device)
3131

3232
if isinstance(format, str):
33-
format = BoundingBoxFormat[format]
33+
format = BoundingBoxFormat.from_str(format.upper())
3434

3535
bounding_box._metadata.update(dict(format=format, image_size=image_size))
3636

torchvision/prototype/features/_image.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Any, Optional, Union, Tuple, cast
55

66
import torch
7-
from torchvision.prototype.utils._internal import StrEnum
7+
from torchvision._utils import StrEnum
88
from torchvision.transforms.functional import to_pil_image
99
from torchvision.utils import draw_bounding_boxes
1010
from torchvision.utils import make_grid
@@ -14,9 +14,9 @@
1414

1515

1616
class ColorSpace(StrEnum):
17-
OTHER = 0
18-
GRAYSCALE = 1
19-
RGB = 3
17+
OTHER = StrEnum.auto()
18+
GRAYSCALE = StrEnum.auto()
19+
RGB = StrEnum.auto()
2020

2121

2222
class Image(_Feature):
@@ -37,7 +37,7 @@ def __new__(
3737
if color_space == ColorSpace.OTHER:
3838
warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
3939
elif isinstance(color_space, str):
40-
color_space = ColorSpace[color_space]
40+
color_space = ColorSpace.from_str(color_space.upper())
4141

4242
image._metadata.update(dict(color_space=color_space))
4343

torchvision/prototype/models/_api.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import sys
44
from collections import OrderedDict
55
from dataclasses import dataclass, fields
6-
from enum import Enum
76
from typing import Any, Callable, Dict
87

8+
from torchvision._utils import StrEnum
9+
910
from ..._internally_replaced_utils import load_state_dict_from_url
1011

1112

@@ -34,7 +35,7 @@ class Weights:
3435
meta: Dict[str, Any]
3536

3637

37-
class WeightsEnum(Enum):
38+
class WeightsEnum(StrEnum):
3839
"""
3940
This class is the parent class of all model weights. Each model building method receives an optional `weights`
4041
parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
@@ -58,12 +59,6 @@ def verify(cls, obj: Any) -> Any:
5859
)
5960
return obj
6061

61-
@classmethod
62-
def from_str(cls, value: str) -> "WeightsEnum":
63-
if value in cls.__members__:
64-
return cls.__members__[value]
65-
raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")
66-
6762
def get_state_dict(self, progress: bool) -> OrderedDict:
6863
return load_state_dict_from_url(self.url, progress=progress)
6964

torchvision/prototype/utils/_internal.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import collections.abc
22
import difflib
3-
import enum
43
import functools
54
import inspect
65
import io
@@ -31,7 +30,6 @@
3130
import torch
3231

3332
__all__ = [
34-
"StrEnum",
3533
"sequence_to_str",
3634
"add_suggestion",
3735
"FrozenMapping",
@@ -45,17 +43,6 @@
4543
]
4644

4745

48-
class StrEnumMeta(enum.EnumMeta):
49-
auto = enum.auto
50-
51-
def __getitem__(self, item):
52-
return super().__getitem__(item.upper() if isinstance(item, str) else item)
53-
54-
55-
class StrEnum(enum.Enum, metaclass=StrEnumMeta):
56-
pass
57-
58-
5946
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
6047
if not seq:
6148
return ""

0 commit comments

Comments
 (0)