Skip to content

Commit adad1df

Browse files
authored
fix prototype features and functional transforms after review (#5377)
* fix prototype functional transforms after review * address features review
1 parent 1d748ae commit adad1df

File tree

10 files changed

+32
-712
lines changed

10 files changed

+32
-712
lines changed

torchvision/prototype/features/_bounding_box.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Any, Tuple, Union, Optional
24

35
import torch
@@ -15,7 +17,6 @@ class BoundingBoxFormat(StrEnum):
1517

1618

1719
class BoundingBox(Feature):
18-
formats = BoundingBoxFormat
1920
format: BoundingBoxFormat
2021
image_size: Tuple[int, int]
2122

@@ -27,7 +28,7 @@ def __new__(
2728
device: Optional[torch.device] = None,
2829
format: Union[BoundingBoxFormat, str],
2930
image_size: Tuple[int, int],
30-
):
31+
) -> BoundingBox:
3132
bounding_box = super().__new__(cls, data, dtype=dtype, device=device)
3233

3334
if isinstance(format, str):
@@ -37,7 +38,7 @@ def __new__(
3738

3839
return bounding_box
3940

40-
def to_format(self, format: Union[str, BoundingBoxFormat]) -> "BoundingBox":
41+
def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
4142
# import at runtime to avoid cyclic imports
4243
from torchvision.prototype.transforms.functional import convert_bounding_box_format
4344

torchvision/prototype/features/_image.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import warnings
24
from typing import Any, Optional, Union, Tuple, cast
35

@@ -20,7 +22,6 @@ class ColorSpace(StrEnum):
2022

2123

2224
class Image(Feature):
23-
color_spaces = ColorSpace
2425
color_space: ColorSpace
2526

2627
def __new__(
@@ -79,5 +80,5 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace:
7980
def show(self) -> None:
8081
to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show()
8182

82-
def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> "Image":
83+
def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image:
8384
return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs))

torchvision/prototype/features/_label.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def __new__(
2626

2727
@classmethod
2828
def from_category(cls, category: str, *, categories: Sequence[str]):
29-
categories = list(categories)
3029
return cls(categories.index(category), categories=categories)
3130

3231
def to_categories(self):
@@ -45,7 +44,7 @@ def __new__(
4544
*,
4645
dtype: Optional[torch.dtype] = None,
4746
device: Optional[torch.device] = None,
48-
like: Optional["Label"] = None,
47+
like: Optional[Label] = None,
4948
categories: Optional[Sequence[str]] = None,
5049
):
5150
one_hot_label = super().__new__(cls, data, dtype=dtype, device=device)
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
11
from . import functional
22
from .functional import InterpolationMode # usort: skip
33

4-
from ._transform import Transform
5-
from ._container import Compose, RandomApply, RandomChoice, RandomOrder # usort: skip
6-
from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop
7-
from ._misc import Identity, Normalize
84
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval

torchvision/prototype/transforms/_container.py

-90
This file was deleted.

torchvision/prototype/transforms/_geometry.py

-138
This file was deleted.

torchvision/prototype/transforms/_misc.py

-39
This file was deleted.

0 commit comments

Comments
 (0)