Skip to content

Commit f71c430

Browse files
authored
simplify dispatcher if-elif (#7084)
1 parent 69ae61a commit f71c430

File tree

12 files changed

+64
-107
lines changed

12 files changed

+64
-107
lines changed

mypy.ini

-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ no_implicit_optional = True
3232

3333
; warnings
3434
warn_unused_ignores = True
35-
warn_return_any = True
3635

3736
; miscellaneous strictness flags
3837
allow_redefinition = True

torchvision/prototype/transforms/_type_conversion.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class ToImageTensor(Transform):
4646
def _transform(
4747
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
4848
) -> datapoints.Image:
49-
return F.to_image_tensor(inpt) # type: ignore[no-any-return]
49+
return F.to_image_tensor(inpt)
5050

5151

5252
class ToImagePIL(Transform):

torchvision/prototype/transforms/functional/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# TODO: Add _log_api_usage_once() in all mid-level kernels. If they remain not jit-scriptable we can use decorators
22

33
from torchvision.transforms import InterpolationMode # usort: skip
4+
5+
from ._utils import is_simple_tensor # usort: skip
6+
47
from ._meta import (
58
clamp_bounding_box,
69
convert_format_bounding_box,

torchvision/prototype/transforms/functional/_augment.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
88
from torchvision.utils import _log_api_usage_once
99

10+
from ._utils import is_simple_tensor
11+
1012

1113
def erase_image_tensor(
1214
image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
@@ -45,9 +47,7 @@ def erase(
4547
if not torch.jit.is_scripting():
4648
_log_api_usage_once(erase)
4749

48-
if isinstance(inpt, torch.Tensor) and (
49-
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
50-
):
50+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
5151
return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
5252
elif isinstance(inpt, datapoints.Image):
5353
output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)

torchvision/prototype/transforms/functional/_color.py

+11-28
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from torchvision.utils import _log_api_usage_once
99

1010
from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
11+
from ._utils import is_simple_tensor
1112

1213

1314
def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
@@ -43,9 +44,7 @@ def adjust_brightness(inpt: datapoints.InputTypeJIT, brightness_factor: float) -
4344
if not torch.jit.is_scripting():
4445
_log_api_usage_once(adjust_brightness)
4546

46-
if isinstance(inpt, torch.Tensor) and (
47-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
48-
):
47+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
4948
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
5049
elif isinstance(inpt, datapoints._datapoint.Datapoint):
5150
return inpt.adjust_brightness(brightness_factor=brightness_factor)
@@ -131,9 +130,7 @@ def adjust_contrast(inpt: datapoints.InputTypeJIT, contrast_factor: float) -> da
131130
if not torch.jit.is_scripting():
132131
_log_api_usage_once(adjust_contrast)
133132

134-
if isinstance(inpt, torch.Tensor) and (
135-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
136-
):
133+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
137134
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
138135
elif isinstance(inpt, datapoints._datapoint.Datapoint):
139136
return inpt.adjust_contrast(contrast_factor=contrast_factor)
@@ -326,9 +323,7 @@ def adjust_hue(inpt: datapoints.InputTypeJIT, hue_factor: float) -> datapoints.I
326323
if not torch.jit.is_scripting():
327324
_log_api_usage_once(adjust_hue)
328325

329-
if isinstance(inpt, torch.Tensor) and (
330-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
331-
):
326+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
332327
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
333328
elif isinstance(inpt, datapoints._datapoint.Datapoint):
334329
return inpt.adjust_hue(hue_factor=hue_factor)
@@ -371,9 +366,7 @@ def adjust_gamma(inpt: datapoints.InputTypeJIT, gamma: float, gain: float = 1) -
371366
if not torch.jit.is_scripting():
372367
_log_api_usage_once(adjust_gamma)
373368

374-
if isinstance(inpt, torch.Tensor) and (
375-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
376-
):
369+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
377370
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
378371
elif isinstance(inpt, datapoints._datapoint.Datapoint):
379372
return inpt.adjust_gamma(gamma=gamma, gain=gain)
@@ -410,9 +403,7 @@ def posterize(inpt: datapoints.InputTypeJIT, bits: int) -> datapoints.InputTypeJ
410403
if not torch.jit.is_scripting():
411404
_log_api_usage_once(posterize)
412405

413-
if isinstance(inpt, torch.Tensor) and (
414-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
415-
):
406+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
416407
return posterize_image_tensor(inpt, bits=bits)
417408
elif isinstance(inpt, datapoints._datapoint.Datapoint):
418409
return inpt.posterize(bits=bits)
@@ -443,9 +434,7 @@ def solarize(inpt: datapoints.InputTypeJIT, threshold: float) -> datapoints.Inpu
443434
if not torch.jit.is_scripting():
444435
_log_api_usage_once(solarize)
445436

446-
if isinstance(inpt, torch.Tensor) and (
447-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
448-
):
437+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
449438
return solarize_image_tensor(inpt, threshold=threshold)
450439
elif isinstance(inpt, datapoints._datapoint.Datapoint):
451440
return inpt.solarize(threshold=threshold)
@@ -498,9 +487,7 @@ def autocontrast(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
498487
if not torch.jit.is_scripting():
499488
_log_api_usage_once(autocontrast)
500489

501-
if isinstance(inpt, torch.Tensor) and (
502-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
503-
):
490+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
504491
return autocontrast_image_tensor(inpt)
505492
elif isinstance(inpt, datapoints._datapoint.Datapoint):
506493
return inpt.autocontrast()
@@ -593,9 +580,7 @@ def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
593580
if not torch.jit.is_scripting():
594581
_log_api_usage_once(equalize)
595582

596-
if isinstance(inpt, torch.Tensor) and (
597-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
598-
):
583+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
599584
return equalize_image_tensor(inpt)
600585
elif isinstance(inpt, datapoints._datapoint.Datapoint):
601586
return inpt.equalize()
@@ -610,7 +595,7 @@ def equalize(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
610595

611596
def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
612597
if image.is_floating_point():
613-
return 1.0 - image # type: ignore[no-any-return]
598+
return 1.0 - image
614599
elif image.dtype == torch.uint8:
615600
return image.bitwise_not()
616601
else: # signed integer dtypes
@@ -629,9 +614,7 @@ def invert(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
629614
if not torch.jit.is_scripting():
630615
_log_api_usage_once(invert)
631616

632-
if isinstance(inpt, torch.Tensor) and (
633-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
634-
):
617+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
635618
return invert_image_tensor(inpt)
636619
elif isinstance(inpt, datapoints._datapoint.Datapoint):
637620
return inpt.invert()

torchvision/prototype/transforms/functional/_deprecated.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from torchvision.prototype import datapoints
88
from torchvision.transforms import functional as _F
99

10+
from ._utils import is_simple_tensor
11+
1012

1113
@torch.jit.unused
1214
def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
@@ -25,14 +27,14 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
2527
def rgb_to_grayscale(
2628
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
2729
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
28-
if not torch.jit.is_scripting() and isinstance(inpt, (datapoints.Image, datapoints.Video)):
29-
inpt = inpt.as_subclass(torch.Tensor)
30-
old_color_space = None
31-
elif isinstance(inpt, torch.Tensor):
30+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
3231
old_color_space = datapoints._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type]
3332
else:
3433
old_color_space = None
3534

35+
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
36+
inpt = inpt.as_subclass(torch.Tensor)
37+
3638
call = ", num_output_channels=3" if num_output_channels == 3 else ""
3739
replacement = (
3840
f"convert_color_space(..., color_space=datapoints.ColorSpace.GRAY"

torchvision/prototype/transforms/functional/_geometry.py

+15-39
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
2525

26+
from ._utils import is_simple_tensor
27+
2628

2729
def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
2830
return image.flip(-1)
@@ -60,9 +62,7 @@ def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
6062
if not torch.jit.is_scripting():
6163
_log_api_usage_once(horizontal_flip)
6264

63-
if isinstance(inpt, torch.Tensor) and (
64-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
65-
):
65+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
6666
return horizontal_flip_image_tensor(inpt)
6767
elif isinstance(inpt, datapoints._datapoint.Datapoint):
6868
return inpt.horizontal_flip()
@@ -111,9 +111,7 @@ def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
111111
if not torch.jit.is_scripting():
112112
_log_api_usage_once(vertical_flip)
113113

114-
if isinstance(inpt, torch.Tensor) and (
115-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
116-
):
114+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
117115
return vertical_flip_image_tensor(inpt)
118116
elif isinstance(inpt, datapoints._datapoint.Datapoint):
119117
return inpt.vertical_flip()
@@ -241,9 +239,7 @@ def resize(
241239
) -> datapoints.InputTypeJIT:
242240
if not torch.jit.is_scripting():
243241
_log_api_usage_once(resize)
244-
if isinstance(inpt, torch.Tensor) and (
245-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
246-
):
242+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
247243
return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
248244
elif isinstance(inpt, datapoints._datapoint.Datapoint):
249245
return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
@@ -744,9 +740,7 @@ def affine(
744740
_log_api_usage_once(affine)
745741

746742
# TODO: consider deprecating integers from angle and shear on the future
747-
if isinstance(inpt, torch.Tensor) and (
748-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
749-
):
743+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
750744
return affine_image_tensor(
751745
inpt,
752746
angle,
@@ -929,9 +923,7 @@ def rotate(
929923
if not torch.jit.is_scripting():
930924
_log_api_usage_once(rotate)
931925

932-
if isinstance(inpt, torch.Tensor) and (
933-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
934-
):
926+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
935927
return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
936928
elif isinstance(inpt, datapoints._datapoint.Datapoint):
937929
return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
@@ -1139,9 +1131,7 @@ def pad(
11391131
if not torch.jit.is_scripting():
11401132
_log_api_usage_once(pad)
11411133

1142-
if isinstance(inpt, torch.Tensor) and (
1143-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
1144-
):
1134+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
11451135
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
11461136

11471137
elif isinstance(inpt, datapoints._datapoint.Datapoint):
@@ -1219,9 +1209,7 @@ def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width:
12191209
if not torch.jit.is_scripting():
12201210
_log_api_usage_once(crop)
12211211

1222-
if isinstance(inpt, torch.Tensor) and (
1223-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
1224-
):
1212+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
12251213
return crop_image_tensor(inpt, top, left, height, width)
12261214
elif isinstance(inpt, datapoints._datapoint.Datapoint):
12271215
return inpt.crop(top, left, height, width)
@@ -1476,9 +1464,7 @@ def perspective(
14761464
) -> datapoints.InputTypeJIT:
14771465
if not torch.jit.is_scripting():
14781466
_log_api_usage_once(perspective)
1479-
if isinstance(inpt, torch.Tensor) and (
1480-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
1481-
):
1467+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
14821468
return perspective_image_tensor(
14831469
inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
14841470
)
@@ -1639,9 +1625,7 @@ def elastic(
16391625
if not torch.jit.is_scripting():
16401626
_log_api_usage_once(elastic)
16411627

1642-
if isinstance(inpt, torch.Tensor) and (
1643-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
1644-
):
1628+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
16451629
return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
16461630
elif isinstance(inpt, datapoints._datapoint.Datapoint):
16471631
return inpt.elastic(displacement, interpolation=interpolation, fill=fill)
@@ -1754,9 +1738,7 @@ def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapo
17541738
if not torch.jit.is_scripting():
17551739
_log_api_usage_once(center_crop)
17561740

1757-
if isinstance(inpt, torch.Tensor) and (
1758-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
1759-
):
1741+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
17601742
return center_crop_image_tensor(inpt, output_size)
17611743
elif isinstance(inpt, datapoints._datapoint.Datapoint):
17621744
return inpt.center_crop(output_size)
@@ -1850,9 +1832,7 @@ def resized_crop(
18501832
if not torch.jit.is_scripting():
18511833
_log_api_usage_once(resized_crop)
18521834

1853-
if isinstance(inpt, torch.Tensor) and (
1854-
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
1855-
):
1835+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
18561836
return resized_crop_image_tensor(
18571837
inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
18581838
)
@@ -1935,9 +1915,7 @@ def five_crop(
19351915

19361916
# TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
19371917
# `ten_crop`
1938-
if isinstance(inpt, torch.Tensor) and (
1939-
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
1940-
):
1918+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
19411919
return five_crop_image_tensor(inpt, size)
19421920
elif isinstance(inpt, datapoints.Image):
19431921
output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)
@@ -1991,9 +1969,7 @@ def ten_crop(
19911969
if not torch.jit.is_scripting():
19921970
_log_api_usage_once(ten_crop)
19931971

1994-
if isinstance(inpt, torch.Tensor) and (
1995-
torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
1996-
):
1972+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
19971973
return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
19981974
elif isinstance(inpt, datapoints.Image):
19991975
output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)

0 commit comments

Comments
 (0)