@@ -23,6 +23,8 @@
 
 from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
 
+from ._utils import is_simple_tensor
+
 
 def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
     return image.flip(-1)
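The new `is_simple_tensor` helper is imported from `._utils`; its body is not shown in this diff. To be a drop-in replacement for the condition removed at each call site below, it would have to look roughly like the following sketch (the `Any` annotation and the absolute import paths are assumptions):

    from typing import Any

    import torch

    from torchvision.prototype import datapoints


    # Sketch: True for a plain torch.Tensor that is NOT one of the datapoint
    # subclasses (Image, Video, BoundingBox, Mask, ...) -- exactly the check
    # the removed lines spell out inline at every dispatcher.
    def is_simple_tensor(inpt: Any) -> bool:
        return isinstance(inpt, torch.Tensor) and not isinstance(
            inpt, datapoints._datapoint.Datapoint
        )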
@@ -60,9 +62,7 @@ def horizontal_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(horizontal_flip)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return horizontal_flip_image_tensor(inpt)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.horizontal_flip()
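Each dispatcher now reads as two explicit branches: scripted or plain-tensor inputs go straight to the tensor kernel, while datapoint inputs dispatch through their method. A minimal usage sketch, assuming the prototype namespace layout implied by the diff:

    import torch

    from torchvision.prototype import datapoints
    from torchvision.prototype.transforms import functional as F

    plain = torch.rand(3, 32, 32)    # simple tensor -> kernel path
    image = datapoints.Image(plain)  # datapoint     -> method path

    # is_simple_tensor(plain) is True, so horizontal_flip calls
    # horizontal_flip_image_tensor directly and returns a plain tensor.
    out_plain = F.horizontal_flip(plain)

    # Image is a Datapoint subclass, so is_simple_tensor(image) is False and
    # dispatch goes through image.horizontal_flip(), preserving the subclass.
    out_image = F.horizontal_flip(image)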
@@ -111,9 +111,7 @@ def vertical_flip(inpt: datapoints.InputTypeJIT) -> datapoints.InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(vertical_flip)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return vertical_flip_image_tensor(inpt)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.vertical_flip()
@@ -241,9 +239,7 @@ def resize(
 ) -> datapoints.InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(resize)
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias)
@@ -744,9 +740,7 @@ def affine(
         _log_api_usage_once(affine)
 
     # TODO: consider deprecating integers from angle and shear on the future
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return affine_image_tensor(
             inpt,
             angle,
@@ -929,9 +923,7 @@ def rotate(
     if not torch.jit.is_scripting():
         _log_api_usage_once(rotate)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
@@ -1139,9 +1131,7 @@ def pad(
     if not torch.jit.is_scripting():
         _log_api_usage_once(pad)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
 
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
@@ -1219,9 +1209,7 @@ def crop(inpt: datapoints.InputTypeJIT, top: int, left: int, height: int, width:
     if not torch.jit.is_scripting():
         _log_api_usage_once(crop)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return crop_image_tensor(inpt, top, left, height, width)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.crop(top, left, height, width)
@@ -1476,9 +1464,7 @@ def perspective(
 ) -> datapoints.InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(perspective)
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return perspective_image_tensor(
             inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
         )
@@ -1639,9 +1625,7 @@ def elastic(
     if not torch.jit.is_scripting():
         _log_api_usage_once(elastic)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.elastic(displacement, interpolation=interpolation, fill=fill)
@@ -1754,9 +1738,7 @@ def center_crop(inpt: datapoints.InputTypeJIT, output_size: List[int]) -> datapo
     if not torch.jit.is_scripting():
         _log_api_usage_once(center_crop)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return center_crop_image_tensor(inpt, output_size)
     elif isinstance(inpt, datapoints._datapoint.Datapoint):
         return inpt.center_crop(output_size)
@@ -1850,9 +1832,7 @@ def resized_crop(
     if not torch.jit.is_scripting():
         _log_api_usage_once(resized_crop)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint)
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return resized_crop_image_tensor(
             inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
         )
@@ -1935,9 +1915,7 @@ def five_crop(
 
     # TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
     # `ten_crop`
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return five_crop_image_tensor(inpt, size)
     elif isinstance(inpt, datapoints.Image):
         output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)
@@ -1991,9 +1969,7 @@ def ten_crop(
     if not torch.jit.is_scripting():
         _log_api_usage_once(ten_crop)
 
-    if isinstance(inpt, torch.Tensor) and (
-        torch.jit.is_scripting() or not isinstance(inpt, (datapoints.Image, datapoints.Video))
-    ):
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
     elif isinstance(inpt, datapoints.Image):
         output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
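One behavioral nuance worth flagging: in `five_crop` and `ten_crop` the removed condition excluded only `datapoints.Image` and `datapoints.Video`, not every `Datapoint`. Switching to `is_simple_tensor` means a non-image datapoint such as a bounding box no longer takes the tensor-kernel path and falls through to the `elif`/else chain instead. A sketch of the difference (the `bbox` value and its constructor arguments are hypothetical, for illustration only):

    import torch

    from torchvision.prototype import datapoints

    # A datapoint that is neither Image nor Video.
    bbox = datapoints.BoundingBox(
        torch.tensor([[0, 0, 10, 10]]),
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(32, 32),
    )

    # Old five_crop/ten_crop condition: True, so bbox would have been routed
    # through the image kernel.
    old = isinstance(bbox, torch.Tensor) and not isinstance(
        bbox, (datapoints.Image, datapoints.Video)
    )

    # New condition (inlined definition of is_simple_tensor): False, so bbox
    # now falls through to the elif/else chain of the dispatcher.
    new = isinstance(bbox, torch.Tensor) and not isinstance(
        bbox, datapoints._datapoint.Datapoint
    )

    assert old is True and new is False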