
Commit bd19fb8

[proto] Added mid-level ops and feature-based ops (#6219)

* Added mid-level ops and feature-based ops
* Fixed a deadlock in the dataloader caused by circular imports
* Added a workaround for non-scalar fill support in pad
* Removed comments
* Added int/float support for fill in the pad op
* Updated type hints and removed the bypass option from mid-level methods
* Minor nit fixes

1 parent b3b7448 commit bd19fb8

File tree

13 files changed: +1049 −144 lines changed

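The pattern this commit introduces: each mid-level op (resize, pad, horizontal_flip, ...) dispatches to a feature-specific low-level kernel, and each feature class gains methods that invoke the matching kernel on itself. Only four of the thirteen changed files appear below, and the mid-level op definitions are not among them, so the following is a minimal sketch of the dispatch pattern, not the committed implementation; the tensor fallback is approximated with `torch.flip` to keep it self-contained.

```python
# Minimal sketch of the mid-level dispatch pattern (illustrative only --
# the actual mid-level ops live in files not shown in this excerpt).
import torch
from torchvision.prototype import features


def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
    if isinstance(inpt, features._Feature):
        # Feature-based dispatch: BoundingBox overrides horizontal_flip() to call
        # horizontal_flip_bounding_box(); the _Feature base class returns self,
        # so features without a meaningful flip pass through unchanged.
        return inpt.horizontal_flip()
    # Plain tensors fall back to the low-level image kernel
    # (approximated here with torch.flip to keep the sketch self-contained).
    return torch.flip(inpt, dims=(-1,))
```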

.github/workflows/prototype-tests.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -41,4 +41,4 @@ jobs:
 
       - name: Run prototype tests
         shell: bash
-        run: pytest --durations=20 test/test_prototype_*.py
+        run: pytest -vvv --durations=20 test/test_prototype_*.py
```

test/test_prototype_transforms_functional.py

Lines changed: 56 additions & 0 deletions
```diff
@@ -365,6 +365,18 @@ def rotate_segmentation_mask():
     )
 
 
+@register_kernel_info_from_sample_inputs_fn
+def crop_image_tensor():
+    for image, top, left, height, width in itertools.product(make_images(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]):
+        yield SampleInput(
+            image,
+            top=top,
+            left=left,
+            height=height,
+            width=width,
+        )
+
+
 @register_kernel_info_from_sample_inputs_fn
 def crop_bounding_box():
     for bounding_box, top, left in itertools.product(make_bounding_boxes(), [-8, 0, 9], [-8, 0, 9]):
@@ -414,6 +426,17 @@ def resized_crop_segmentation_mask():
         yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def pad_image_tensor():
+    for image, padding, fill, padding_mode in itertools.product(
+        make_images(),
+        [[1], [1, 1], [1, 1, 2, 2]],  # padding
+        [12, 12.0],  # fill
+        ["constant", "symmetric", "edge", "reflect"],  # padding_mode
+    ):
+        yield SampleInput(image, padding=padding, fill=fill, padding_mode=padding_mode)
+
+
 @register_kernel_info_from_sample_inputs_fn
 def pad_segmentation_mask():
     for mask, padding, padding_mode in itertools.product(
@@ -499,6 +522,39 @@ def test_scriptable(kernel):
     jit.script(kernel)
 
 
+# The test below checks each mid-level op against the low-level kernels it calls,
+# e.g. resize -> resize_image_tensor, resize_bounding_box, etc.
+# TODO: rewrite this test, as the sample args may include more or fewer params
+# than the mid-level function accepts.
+@pytest.mark.parametrize(
+    "func",
+    [
+        pytest.param(func, id=name)
+        for name, func in F.__dict__.items()
+        if not name.startswith("_")
+        and callable(func)
+        and all(
+            feature_type not in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label", "pil"}
+        )
+        and name not in {"to_image_tensor", "InterpolationMode", "decode_video_with_av", "crop", "rotate"}
+        # We skip 'crop' because its sample inputs lack 'height' and 'width'
+        # We skip 'rotate' because the expand=True case is not yet implemented for bounding boxes
+    ],
+)
+def test_functional_mid_level(func):
+    finfos = [finfo for finfo in FUNCTIONAL_INFOS if f"{func.__name__}_" in finfo.name]
+    for finfo in finfos:
+        for sample_input in finfo.sample_inputs():
+            expected = finfo(sample_input)
+            kwargs = dict(sample_input.kwargs)
+            for key in ["format", "image_size"]:
+                if key in kwargs:
+                    del kwargs[key]
+            output = func(*sample_input.args, **kwargs)
+            torch.testing.assert_close(output, expected, msg=f"finfo={finfo}, output={output}, expected={expected}")
+            break
+
+
 @pytest.mark.parametrize(
     ("functional_info", "sample_input"),
     [
```
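The new `test_functional_mid_level` asserts that a mid-level op reproduces its low-level kernel's output once feature-specific kwargs (`format`, `image_size`), which the mid-level op infers from the feature itself, are dropped. A sketch of the equivalence being checked, using `pad` and assuming `F.pad` is among the mid-level ops added by this commit (the tensor values are arbitrary):

```python
# Sketch of the equivalence test_functional_mid_level verifies, using pad.
# Assumes the prototype functional API at this commit; values are arbitrary.
import torch
from torchvision.prototype.transforms import functional as F

image = torch.rand(3, 32, 32)

# Low-level kernel, called with explicit parameters:
expected = F.pad_image_tensor(image, padding=[1, 1, 2, 2], fill=12.0, padding_mode="constant")

# Mid-level op: same arguments; for features such as BoundingBox it would
# additionally infer 'format' and 'image_size' from the input itself.
output = F.pad(image, padding=[1, 1, 2, 2], fill=12.0, padding_mode="constant")

torch.testing.assert_close(output, expected)
```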

torchvision/prototype/features/_bounding_box.py

Lines changed: 141 additions & 1 deletion
```diff
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
-from typing import Any, Tuple, Union, Optional
+from typing import Any, List, Tuple, Union, Optional, Sequence
 
 import torch
 from torchvision._utils import StrEnum
+from torchvision.transforms import InterpolationMode
 
 from ._feature import _Feature
 
@@ -69,3 +70,142 @@ def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
         return BoundingBox.new_like(
             self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
         )
+
+    def horizontal_flip(self) -> BoundingBox:
+        from torchvision.prototype.transforms import functional as _F
+
+        output = _F.horizontal_flip_bounding_box(self, format=self.format, image_size=self.image_size)
+        return BoundingBox.new_like(self, output)
+
+    def vertical_flip(self) -> BoundingBox:
+        from torchvision.prototype.transforms import functional as _F
+
+        output = _F.vertical_flip_bounding_box(self, format=self.format, image_size=self.image_size)
+        return BoundingBox.new_like(self, output)
+
+    def resize(  # type: ignore[override]
+        self,
+        size: List[int],
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        max_size: Optional[int] = None,
+        antialias: bool = False,
+    ) -> BoundingBox:
+        from torchvision.prototype.transforms import functional as _F
+
+        output = _F.resize_bounding_box(self, size, image_size=self.image_size, max_size=max_size)
+        image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1])
+        return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype)
+
+    def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
+        from torchvision.prototype.transforms import functional as _F
+
+        output = _F.crop_bounding_box(self, self.format, top, left)
+        return BoundingBox.new_like(self, output, image_size=(height, width))
+
+    def center_crop(self, output_size: List[int]) -> BoundingBox:
+        from torchvision.prototype.transforms import functional as _F
+
+        output = _F.center_crop_bounding_box(
+            self, format=self.format, output_size=output_size, image_size=self.image_size
+        )
+        image_size = (output_size[0], output_size[0]) if len(output_size) == 1 else (output_size[0], output_size[1])
+        return BoundingBox.new_like(self, output, image_size=image_size)
+
+    def resized_crop(
+        self,
+        top: int,
+        left: int,
+        height: int,
+        width: int,
+        size: List[int],
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        antialias: bool = False,
+    ) -> BoundingBox:
+        from torchvision.prototype.transforms import functional as _F
+
+        output = _F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size)
+        image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1])
+        return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype)
+
+    def pad(
+        self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
+    ) -> BoundingBox:
+        from torchvision.prototype.transforms import functional as _F
+
+        if padding_mode not in ["constant"]:
+            raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")
+
+        output = _F.pad_bounding_box(self, padding, format=self.format)
+
+        # Update output image size:
+        # TODO: remove the import below and make _parse_pad_padding available
+        from torchvision.transforms.functional_tensor import _parse_pad_padding
+
+        left, top, right, bottom = _parse_pad_padding(padding)
+        height, width = self.image_size
+        height += top + bottom
+        width += left + right
+
+        return BoundingBox.new_like(self, output, image_size=(height, width))
+
+    def rotate(
+        self,
+        angle: float,
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        expand: bool = False,
+        fill: Optional[List[float]] = None,
+        center: Optional[List[float]] = None,
+    ) -> BoundingBox:
+        from torchvision.prototype.transforms import functional as _F
+
+        output = _F.rotate_bounding_box(
+            self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center
+        )
+        # TODO: update output image size if expand is True
+        if expand:
+            raise RuntimeError("Not yet implemented")
+        return BoundingBox.new_like(self, output, dtype=output.dtype)
+
+    def affine(
+        self,
+        angle: float,
+        translate: List[float],
+        scale: float,
+        shear: List[float],
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        fill: Optional[List[float]] = None,
+        center: Optional[List[float]] = None,
+    ) -> BoundingBox:
+        from torchvision.prototype.transforms import functional as _F
+
+        output = _F.affine_bounding_box(
+            self,
+            self.format,
+            self.image_size,
+            angle,
+            translate=translate,
+            scale=scale,
+            shear=shear,
+            center=center,
+        )
+        return BoundingBox.new_like(self, output, dtype=output.dtype)
+
+    def perspective(
+        self,
+        perspective_coeffs: List[float],
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        fill: Optional[List[float]] = None,
+    ) -> BoundingBox:
+        from torchvision.prototype.transforms import functional as _F
+
+        output = _F.perspective_bounding_box(self, self.format, perspective_coeffs)
+        return BoundingBox.new_like(self, output, dtype=output.dtype)
+
+    def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> BoundingBox:
+        raise TypeError("Erase transformation does not support bounding boxes")
+
+    def mixup(self, lam: float) -> BoundingBox:
+        raise TypeError("Mixup transformation does not support bounding boxes")
+
+    def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> BoundingBox:
+        raise TypeError("Cutmix transformation does not support bounding boxes")
```

torchvision/prototype/features/_feature.py

Lines changed: 114 additions & 2 deletions
```diff
@@ -1,8 +1,8 @@
-from typing import Any, cast, TypeVar, Union, Optional, Type, Callable, Tuple, Sequence, Mapping
+from typing import Any, cast, TypeVar, Union, Optional, Type, Callable, List, Tuple, Sequence, Mapping
 
 import torch
 from torch._C import _TensorBase, DisableTorchFunction
-
+from torchvision.transforms import InterpolationMode
 
 F = TypeVar("F", bound="_Feature")
 
@@ -83,3 +83,115 @@ def __torch_function__(
             return cls.new_like(args[0], output, dtype=output.dtype, device=output.device)
         else:
             return output
+
+    def horizontal_flip(self) -> Any:
+        return self
+
+    def vertical_flip(self) -> Any:
+        return self
+
+    # TODO: We have to ignore the override mypy error, as torch.Tensor has a built-in deprecated op: Tensor.resize
+    # https://github.com/pytorch/pytorch/blob/e8727994eb7cdb2ab642749d6549bc497563aa06/torch/_tensor.py#L588-L593
+    def resize(  # type: ignore[override]
+        self,
+        size: List[int],
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        max_size: Optional[int] = None,
+        antialias: bool = False,
+    ) -> Any:
+        return self
+
+    def crop(self, top: int, left: int, height: int, width: int) -> Any:
+        return self
+
+    def center_crop(self, output_size: List[int]) -> Any:
+        return self
+
+    def resized_crop(
+        self,
+        top: int,
+        left: int,
+        height: int,
+        width: int,
+        size: List[int],
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        antialias: bool = False,
+    ) -> Any:
+        return self
+
+    def pad(
+        self, padding: List[int], fill: Union[int, float, Sequence[float]] = 0, padding_mode: str = "constant"
+    ) -> Any:
+        return self
+
+    def rotate(
+        self,
+        angle: float,
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        expand: bool = False,
+        fill: Optional[List[float]] = None,
+        center: Optional[List[float]] = None,
+    ) -> Any:
+        return self
+
+    def affine(
+        self,
+        angle: float,
+        translate: List[float],
+        scale: float,
+        shear: List[float],
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        fill: Optional[List[float]] = None,
+        center: Optional[List[float]] = None,
+    ) -> Any:
+        return self
+
+    def perspective(
+        self,
+        perspective_coeffs: List[float],
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        fill: Optional[List[float]] = None,
+    ) -> Any:
+        return self
+
+    def adjust_brightness(self, brightness_factor: float) -> Any:
+        return self
+
+    def adjust_saturation(self, saturation_factor: float) -> Any:
+        return self
+
+    def adjust_contrast(self, contrast_factor: float) -> Any:
+        return self
+
+    def adjust_sharpness(self, sharpness_factor: float) -> Any:
+        return self
+
+    def adjust_hue(self, hue_factor: float) -> Any:
+        return self
+
+    def adjust_gamma(self, gamma: float, gain: float = 1) -> Any:
+        return self
+
+    def posterize(self, bits: int) -> Any:
+        return self
+
+    def solarize(self, threshold: float) -> Any:
+        return self
+
+    def autocontrast(self) -> Any:
+        return self
+
+    def equalize(self) -> Any:
+        return self
+
+    def invert(self) -> Any:
+        return self
+
+    def erase(self, i: int, j: int, h: int, w: int, v: torch.Tensor) -> Any:
+        return self
+
+    def mixup(self, lam: float) -> Any:
+        return self
+
+    def cutmix(self, box: Tuple[int, int, int, int], lam_adjusted: float) -> Any:
+        return self
```
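Because `_Feature` defaults every op to `return self`, a transform can call the same method across a heterogeneous sample without isinstance checks: features that override the method (e.g. `BoundingBox.horizontal_flip`) transform themselves, and everything else passes through unchanged. A minimal sketch of that dispatch; the `Label` usage and the `Image.horizontal_flip` override are assumptions about files outside this excerpt:

```python
# Why the no-op defaults matter: one method call applied uniformly across a
# heterogeneous sample; 'Label' and the Image override are assumed here, since
# the corresponding files are not part of this excerpt.
import torch
from torchvision.prototype import features

sample = {
    "image": features.Image(torch.rand(3, 32, 32)),
    "boxes": features.BoundingBox(
        torch.tensor([[1.0, 1.0, 5.0, 5.0]]),
        format=features.BoundingBoxFormat.XYXY,
        image_size=(32, 32),
    ),
    "label": features.Label(torch.tensor(3)),
}

# horizontal_flip is overridden on Image and BoundingBox, but falls back to the
# _Feature base implementation (return self) on Label -- no type checks needed.
flipped = {key: value.horizontal_flip() for key, value in sample.items()}
```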
