Commit fd2d42a

NicolasHug authored and facebook-github-bot committed

[fbsync] Fix some annotations in transforms v2 for JIT v1 compatibility (#7252)

Reviewed By: vmoens
Differential Revision: D44416629
fbshipit-source-id: ab4950cc6c3d313355f29c069838fb96fe9a2dbf

1 parent ad660ec commit fd2d42a

File tree: 10 files changed (+84, -113 lines)


test/prototype_transforms_dispatcher_infos.py
Lines changed: 23 additions & 44 deletions

@@ -3,7 +3,7 @@
 import pytest
 import torchvision.prototype.transforms.functional as F
 from prototype_common_utils import InfoBase, TestMark
-from prototype_transforms_kernel_infos import KERNEL_INFOS
+from prototype_transforms_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition
 from torchvision.prototype import datapoints


 __all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]

@@ -96,25 +96,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
     )


-def xfail_jit_tuple_instead_of_list(name, *, reason=None):
-    return xfail_jit(
-        reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting",
-        condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), tuple),
-    )
-
-
-def is_list_of_ints(args_kwargs):
-    fill = args_kwargs.kwargs.get("fill")
-    return isinstance(fill, list) and any(isinstance(scalar_fill, int) for scalar_fill in fill)
-
-
-def xfail_jit_list_of_ints(name, *, reason=None):
-    return xfail_jit(
-        reason or f"Passing a list of integers for `{name}` is not supported when scripting",
-        condition=is_list_of_ints,
-    )
-
-
 skip_dispatch_datapoint = TestMark(
     ("TestDispatchers", "test_dispatch_datapoint"),
     pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."),

@@ -130,6 +111,13 @@ def xfail_jit_list_of_ints(name, *, reason=None):
 multi_crop_skips.append(skip_dispatch_datapoint)


+def xfails_pil(reason, *, condition=None):
+    return [
+        TestMark(("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition=condition)
+        for test_name in ["test_dispatch_pil", "test_pil_output_type"]
+    ]
+
+
 def fill_sequence_needs_broadcast(args_kwargs):
     (image_loader, *_), kwargs = args_kwargs
     try:

@@ -143,11 +131,8 @@ def fill_sequence_needs_broadcast(args_kwargs):
     return image_loader.num_channels > 1


-xfail_dispatch_pil_if_fill_sequence_needs_broadcast = TestMark(
-    ("TestDispatchers", "test_dispatch_pil"),
-    pytest.mark.xfail(
-        reason="PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger."
-    ),
+xfails_pil_if_fill_sequence_needs_broadcast = xfails_pil(
+    "PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger.",
     condition=fill_sequence_needs_broadcast,
 )

@@ -186,11 +171,9 @@ def fill_sequence_needs_broadcast(args_kwargs):
         },
         pil_kernel_info=PILKernelInfo(F.affine_image_pil),
         test_marks=[
-            xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
+            *xfails_pil_if_fill_sequence_needs_broadcast,
             xfail_jit_python_scalar_arg("shear"),
-            xfail_jit_tuple_instead_of_list("fill"),
-            # TODO: check if this is a regression since it seems that should be supported if `int` is ok
-            xfail_jit_list_of_ints("fill"),
+            xfail_jit_python_scalar_arg("fill"),
         ],
     ),
     DispatcherInfo(

@@ -213,9 +196,8 @@ def fill_sequence_needs_broadcast(args_kwargs):
         },
         pil_kernel_info=PILKernelInfo(F.rotate_image_pil),
         test_marks=[
-            xfail_jit_tuple_instead_of_list("fill"),
-            # TODO: check if this is a regression since it seems that should be supported if `int` is ok
-            xfail_jit_list_of_ints("fill"),
+            xfail_jit_python_scalar_arg("fill"),
+            *xfails_pil_if_fill_sequence_needs_broadcast,
         ],
     ),
     DispatcherInfo(

@@ -248,21 +230,16 @@ def fill_sequence_needs_broadcast(args_kwargs):
         },
         pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
         test_marks=[
-            TestMark(
-                ("TestDispatchers", "test_dispatch_pil"),
-                pytest.mark.xfail(
-                    reason=(
-                        "PIL kernel doesn't support sequences of length 1 for argument `fill` and "
-                        "`padding_mode='constant'`, if the number of color channels is larger."
-                    )
+            *xfails_pil(
+                reason=(
+                    "PIL kernel doesn't support sequences of length 1 for argument `fill` and "
+                    "`padding_mode='constant'`, if the number of color channels is larger."
                 ),
                 condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
                 and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
             ),
-            xfail_jit_tuple_instead_of_list("padding"),
-            xfail_jit_tuple_instead_of_list("fill"),
-            # TODO: check if this is a regression since it seems that should be supported if `int` is ok
-            xfail_jit_list_of_ints("fill"),
+            xfail_jit("F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition),
+            xfail_jit_python_scalar_arg("padding"),
         ],
     ),
     DispatcherInfo(

@@ -275,7 +252,8 @@ def fill_sequence_needs_broadcast(args_kwargs):
         },
         pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
         test_marks=[
-            xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
+            *xfails_pil_if_fill_sequence_needs_broadcast,
+            xfail_jit_python_scalar_arg("fill"),
         ],
     ),
     DispatcherInfo(

@@ -287,6 +265,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
             datapoints.Mask: F.elastic_mask,
         },
         pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
+        test_marks=[xfail_jit_python_scalar_arg("fill")],
     ),
     DispatcherInfo(
         F.center_crop,
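
Note on the new helper: xfails_pil fans a single xfail out to both PIL dispatcher tests (test_dispatch_pil and test_pil_output_type) instead of only the first, which is why call sites splat it into test_marks with a leading *. A minimal self-contained sketch of the idea, with a plain tuple standing in for the project-internal TestMark:

    import pytest

    def xfails_pil_sketch(reason, *, condition=None):
        # One (test id, mark, condition) triple per PIL test, mirroring the
        # list comprehension in the helper above.
        return [
            (("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition)
            for test_name in ["test_dispatch_pil", "test_pil_output_type"]
        ]

    marks = xfails_pil_sketch("PIL kernel can't broadcast a length-1 fill sequence")
    assert len(marks) == 2  # one mark per PIL dispatcher test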

test/prototype_transforms_kernel_infos.py
Lines changed: 38 additions & 45 deletions

@@ -153,26 +153,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
     )


-def xfail_jit_tuple_instead_of_list(name, *, reason=None):
-    reason = reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting"
-    return xfail_jit(
-        reason or f"Passing a tuple instead of a list for `{name}` is not supported when scripting",
-        condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), tuple),
-    )
-
-
-def is_list_of_ints(args_kwargs):
-    fill = args_kwargs.kwargs.get("fill")
-    return isinstance(fill, list) and any(isinstance(scalar_fill, int) for scalar_fill in fill)
-
-
-def xfail_jit_list_of_ints(name, *, reason=None):
-    return xfail_jit(
-        reason or f"Passing a list of integers for `{name}` is not supported when scripting",
-        condition=is_list_of_ints,
-    )
-
-
 KERNEL_INFOS = []
@@ -450,21 +430,21 @@ def _full_affine_params(**partial_params):
 ]


-def get_fills(*, num_channels, dtype, vector=True):
+def get_fills(*, num_channels, dtype):
     yield None

-    max_value = get_max_value(dtype)
-    # This intentionally gives us a float and an int scalar fill value
-    yield max_value / 2
-    yield max_value
+    int_value = get_max_value(dtype)
+    float_value = int_value / 2
+    yield int_value
+    yield float_value

-    if not vector:
-        return
+    for vector_type in [list, tuple]:
+        yield vector_type([int_value])
+        yield vector_type([float_value])

-    if dtype.is_floating_point:
-        yield [0.1 + c / 10 for c in range(num_channels)]
-    else:
-        yield [12.0 + c for c in range(num_channels)]
+        if num_channels > 1:
+            yield vector_type(float_value * c / 10 for c in range(num_channels))
+            yield vector_type(int_value if c % 2 == 0 else 0 for c in range(num_channels))


 def float32_vs_uint8_fill_adapter(other_args, kwargs):
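
To see what the rewritten generator exercises, here is a self-contained sketch with get_max_value inlined as 255 (assuming torch.uint8, where that helper returns the dtype maximum) and num_channels=3:

    def get_fills_sketch(*, num_channels, max_value=255):
        # Mirrors the new get_fills above: scalar fills first, then list and
        # tuple variants, plus per-channel vectors when num_channels > 1.
        yield None
        int_value = max_value
        float_value = int_value / 2
        yield int_value
        yield float_value
        for vector_type in [list, tuple]:
            yield vector_type([int_value])
            yield vector_type([float_value])
            if num_channels > 1:
                yield vector_type(float_value * c / 10 for c in range(num_channels))
                yield vector_type(int_value if c % 2 == 0 else 0 for c in range(num_channels))

    print(list(get_fills_sketch(num_channels=3)))
    # [None, 255, 127.5, [255], [127.5], [0.0, 12.75, 25.5], [255, 0, 255],
    #  (255,), (127.5,), (0.0, 12.75, 25.5), (255, 0, 255)]

Unlike the removed vector flag, vector fills are now always produced; callers that cannot handle them, such as the PIL reference-input generators below, filter them out explicitly.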
@@ -644,9 +624,7 @@ def sample_inputs_affine_video():
         closeness_kwargs=pil_reference_pixel_difference(10, mae=True),
         test_marks=[
             xfail_jit_python_scalar_arg("shear"),
-            xfail_jit_tuple_instead_of_list("fill"),
-            # TODO: check if this is a regression since it seems that should be supported if `int` is ok
-            xfail_jit_list_of_ints("fill"),
+            xfail_jit_python_scalar_arg("fill"),
         ],
     ),
     KernelInfo(

@@ -873,9 +851,7 @@ def sample_inputs_rotate_video():
         float32_vs_uint8=True,
         closeness_kwargs=pil_reference_pixel_difference(1, mae=True),
         test_marks=[
-            xfail_jit_tuple_instead_of_list("fill"),
-            # TODO: check if this is a regression since it seems that should be supported if `int` is ok
-            xfail_jit_list_of_ints("fill"),
+            xfail_jit_python_scalar_arg("fill"),
         ],
     ),
     KernelInfo(
@@ -1122,12 +1098,14 @@ def reference_inputs_pad_image_tensor():
     for image_loader, params in itertools.product(
         make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _PAD_PARAMS
     ):
-        # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
         for fill in get_fills(
             num_channels=image_loader.num_channels,
             dtype=image_loader.dtype,
-            vector=params["padding_mode"] == "constant",
         ):
+            # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
+            if isinstance(fill, (list, tuple)):
+                continue
+
             yield ArgsKwargs(image_loader, fill=fill, **params)
@@ -1195,6 +1173,16 @@ def reference_inputs_pad_bounding_box():
     )


+def pad_xfail_jit_fill_condition(args_kwargs):
+    fill = args_kwargs.kwargs.get("fill")
+    if not isinstance(fill, (list, tuple)):
+        return False
+    elif isinstance(fill, tuple):
+        return True
+    else:  # isinstance(fill, list):
+        return all(isinstance(f, int) for f in fill)
+
+
 KERNEL_INFOS.extend(
     [
         KernelInfo(
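
The condition pins down exactly which fill values still fail under torch.jit.script: tuples always fail, and lists fail only when every element is an int, so a list of floats is the one vector form scripting accepts. A quick standalone check, with a trivial stub for the project-internal ArgsKwargs:

    class ArgsKwargsStub:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    def pad_xfail_jit_fill_condition(args_kwargs):
        # Same logic as the function in the hunk above.
        fill = args_kwargs.kwargs.get("fill")
        if not isinstance(fill, (list, tuple)):
            return False
        elif isinstance(fill, tuple):
            return True
        else:  # isinstance(fill, list)
            return all(isinstance(f, int) for f in fill)

    assert not pad_xfail_jit_fill_condition(ArgsKwargsStub(fill=1))           # scalar: fine
    assert pad_xfail_jit_fill_condition(ArgsKwargsStub(fill=(1.0, 2.0)))      # tuple: xfail
    assert pad_xfail_jit_fill_condition(ArgsKwargsStub(fill=[1, 2]))          # list of ints: xfail
    assert not pad_xfail_jit_fill_condition(ArgsKwargsStub(fill=[1.0, 2.0]))  # list of floats: fine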
@@ -1205,10 +1193,10 @@ def reference_inputs_pad_bounding_box():
             float32_vs_uint8=float32_vs_uint8_fill_adapter,
             closeness_kwargs=float32_vs_uint8_pixel_difference(),
             test_marks=[
-                xfail_jit_tuple_instead_of_list("padding"),
-                xfail_jit_tuple_instead_of_list("fill"),
-                # TODO: check if this is a regression since it seems that should be supported if `int` is ok
-                xfail_jit_list_of_ints("fill"),
+                xfail_jit_python_scalar_arg("padding"),
+                xfail_jit(
+                    "F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition
+                ),
             ],
         ),
         KernelInfo(

@@ -1217,7 +1205,7 @@ def reference_inputs_pad_bounding_box():
             reference_fn=reference_pad_bounding_box,
             reference_inputs_fn=reference_inputs_pad_bounding_box,
             test_marks=[
-                xfail_jit_tuple_instead_of_list("padding"),
+                xfail_jit_python_scalar_arg("padding"),
             ],
         ),
         KernelInfo(
@@ -1261,8 +1249,11 @@ def reference_inputs_perspective_image_tensor():
             F.InterpolationMode.BILINEAR,
         ],
     ):
-        # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
         for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
+            # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
+            if isinstance(fill, (list, tuple)):
+                continue
+
             yield ArgsKwargs(
                 image_loader,
                 startpoints=None,

@@ -1327,6 +1318,7 @@ def sample_inputs_perspective_video():
             **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
             **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
         },
+        test_marks=[xfail_jit_python_scalar_arg("fill")],
     ),
     KernelInfo(
         F.perspective_bounding_box,

@@ -1418,6 +1410,7 @@ def sample_inputs_elastic_video():
             **float32_vs_uint8_pixel_difference(6, mae=True),
             **cuda_vs_cpu_pixel_difference(),
         },
+        test_marks=[xfail_jit_python_scalar_arg("fill")],
     ),
     KernelInfo(
         F.elastic_bounding_box,

torchvision/prototype/datapoints/_bounding_box.py
Lines changed: 1 addition & 1 deletion

@@ -118,7 +118,7 @@ def resized_crop(
     def pad(
         self,
         padding: Union[int, Sequence[int]],
-        fill: FillTypeJIT = None,
+        fill: Optional[Union[int, float, List[float]]] = None,
         padding_mode: str = "constant",
     ) -> BoundingBox:
         output, spatial_size = self._F.pad_bounding_box(

torchvision/prototype/datapoints/_datapoint.py
Lines changed: 3 additions & 3 deletions

@@ -12,7 +12,7 @@

 D = TypeVar("D", bound="Datapoint")
 FillType = Union[int, float, Sequence[int], Sequence[float], None]
-FillTypeJIT = Union[int, float, List[float], None]
+FillTypeJIT = Optional[List[float]]


 class Datapoint(torch.Tensor):

@@ -169,8 +169,8 @@ def resized_crop(

     def pad(
         self,
-        padding: Union[int, List[int]],
-        fill: FillTypeJIT = None,
+        padding: List[int],
+        fill: Optional[Union[int, float, List[float]]] = None,
         padding_mode: str = "constant",
     ) -> Datapoint:
         return self
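
Narrowing FillTypeJIT to Optional[List[float]] lines the alias up with the fill annotation style of the scripted v1 transforms, which is the compatibility the commit title refers to. A minimal sketch of a signature that torch.jit.script accepts with the new alias (pad_like is a hypothetical stand-in, not a real kernel):

    from typing import List, Optional

    import torch

    def pad_like(fill: Optional[List[float]] = None) -> List[float]:
        # FillTypeJIT = Optional[List[float]] resolves to exactly this annotation.
        if fill is None:
            return [0.0]
        return fill

    scripted = torch.jit.script(pad_like)
    print(scripted([1.0, 2.0]))  # [1.0, 2.0]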

torchvision/prototype/datapoints/_image.py
Lines changed: 2 additions & 2 deletions

@@ -103,8 +103,8 @@ def resized_crop(

     def pad(
         self,
-        padding: Union[int, List[int]],
-        fill: FillTypeJIT = None,
+        padding: List[int],
+        fill: Optional[Union[int, float, List[float]]] = None,
         padding_mode: str = "constant",
     ) -> Image:
         output = self._F.pad_image_tensor(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)

torchvision/prototype/datapoints/_mask.py
Lines changed: 2 additions & 2 deletions

@@ -83,8 +83,8 @@ def resized_crop(

     def pad(
         self,
-        padding: Union[int, List[int]],
-        fill: FillTypeJIT = None,
+        padding: List[int],
+        fill: Optional[Union[int, float, List[float]]] = None,
         padding_mode: str = "constant",
     ) -> Mask:
         output = self._F.pad_mask(self.as_subclass(torch.Tensor), padding, padding_mode=padding_mode, fill=fill)

torchvision/prototype/datapoints/_video.py
Lines changed: 2 additions & 2 deletions

@@ -102,8 +102,8 @@ def resized_crop(

     def pad(
         self,
-        padding: Union[int, List[int]],
-        fill: FillTypeJIT = None,
+        padding: List[int],
+        fill: Optional[Union[int, float, List[float]]] = None,
         padding_mode: str = "constant",
     ) -> Video:
         output = self._F.pad_video(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)

torchvision/prototype/transforms/_geometry.py
Lines changed: 1 addition & 1 deletion

@@ -270,7 +270,7 @@ def __init__(

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = self.fill[type(inpt)]
-        return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
+        return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]


 class RandomZoomOut(_RandomApplyTransform):

torchvision/prototype/transforms/_utils.py
Lines changed: 1 addition & 2 deletions

@@ -60,10 +60,9 @@ def _convert_fill_arg(fill: datapoints.FillType) -> datapoints.FillTypeJIT:
     if fill is None:
         return fill

-    # This cast does Sequence -> List[float] to please mypy and torch.jit.script
     if not isinstance(fill, (int, float)):
         fill = [float(v) for v in list(fill)]
-    return fill
+    return fill  # type: ignore[return-value]


 def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillTypeJIT]:
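
After this change _convert_fill_arg converts any sequence to a List[float] while still passing scalars through untouched; that scalar pass-through is what the new # type: ignore[return-value] silences, since the declared FillTypeJIT return type no longer admits scalars. A standalone sketch with the FillType alias inlined:

    from typing import List, Optional, Sequence, Union

    FillType = Union[int, float, Sequence[int], Sequence[float], None]

    def convert_fill(fill: FillType) -> Optional[List[float]]:  # i.e. FillTypeJIT
        if fill is None:
            return fill
        if not isinstance(fill, (int, float)):
            fill = [float(v) for v in list(fill)]
        # Scalars reach this return unchanged, which mypy flags against
        # Optional[List[float]] -- hence the ignore comment in the diff above.
        return fill  # type: ignore[return-value]

    print(convert_fill((1, 2)))  # [1.0, 2.0]
    print(convert_fill(5))       # 5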
