Skip to content

Commit 8324c48

Browse files
NicolasHug and vfdev-5
authored
Add uint8 bicubic support to ResizeV2 (#7668)
Co-authored-by: vfdev-5 <[email protected]>
1 parent e44bba1 commit 8324c48

File tree

3 files changed

+49
-17
lines changed

3 files changed

+49
-17
lines changed

test/test_transforms_v2_consistency.py

+18-4
Original file line number | Diff line number | Diff line change
@@ -87,10 +87,8 @@ def __init__(
8787
ArgsKwargs([32]),
8888
ArgsKwargs((32, 29)),
8989
ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
90-
ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC),
9190
ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST),
9291
ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR),
93-
ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC),
9492
NotScriptableArgsKwargs(31, max_size=32),
9593
ArgsKwargs([31], max_size=32),
9694
NotScriptableArgsKwargs(30, max_size=100),
@@ -101,6 +99,15 @@ def __init__(
10199
# atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
102100
closeness_kwargs=dict(rtol=0, atol=1),
103101
),
102+
ConsistencyConfig(
103+
v2_transforms.Resize,
104+
legacy_transforms.Resize,
105+
[
106+
ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC, antialias=True),
107+
ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC, antialias=True),
108+
],
109+
closeness_kwargs=dict(rtol=0, atol=21),
110+
),
104111
ConsistencyConfig(
105112
v2_transforms.CenterCrop,
106113
legacy_transforms.CenterCrop,
@@ -309,15 +316,22 @@ def __init__(
309316
ArgsKwargs(17, scale=(0.3, 0.7)),
310317
ArgsKwargs(25, ratio=(0.5, 1.5)),
311318
ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
312-
ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC),
313319
ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST),
314-
ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC),
315320
ArgsKwargs((29, 32), antialias=False),
316321
ArgsKwargs((28, 31), antialias=True),
317322
],
318323
# atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
319324
closeness_kwargs=dict(rtol=0, atol=1),
320325
),
326+
ConsistencyConfig(
327+
v2_transforms.RandomResizedCrop,
328+
legacy_transforms.RandomResizedCrop,
329+
[
330+
ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC, antialias=True),
331+
ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC, antialias=True),
332+
],
333+
closeness_kwargs=dict(rtol=0, atol=21),
334+
),
321335
ConsistencyConfig(
322336
v2_transforms.RandomErasing,
323337
legacy_transforms.RandomErasing,

test/transforms_v2_kernel_infos.py

+23-5
Original file line number | Diff line number | Diff line change
@@ -257,17 +257,20 @@ def sample_inputs_resize_image_tensor():
257257

258258
for image_loader, interpolation in itertools.product(
259259
make_image_loaders(sizes=["random"], color_spaces=["RGB"]),
260-
[
261-
F.InterpolationMode.NEAREST,
262-
F.InterpolationMode.BILINEAR,
263-
F.InterpolationMode.BICUBIC,
264-
],
260+
[F.InterpolationMode.NEAREST, F.InterpolationMode.BILINEAR],
265261
):
266262
yield ArgsKwargs(image_loader, size=[min(image_loader.spatial_size) + 1], interpolation=interpolation)
267263

268264
yield ArgsKwargs(make_image_loader(size=(11, 17)), size=20, max_size=25)
269265

270266

267+
def sample_inputs_resize_image_tensor_bicubic():
268+
for image_loader, interpolation in itertools.product(
269+
make_image_loaders(sizes=["random"], color_spaces=["RGB"]), [F.InterpolationMode.BICUBIC]
270+
):
271+
yield ArgsKwargs(image_loader, size=[min(image_loader.spatial_size) + 1], interpolation=interpolation)
272+
273+
271274
@pil_reference_wrapper
272275
def reference_resize_image_tensor(*args, **kwargs):
273276
if not kwargs.pop("antialias", False) and kwargs.get("interpolation", F.InterpolationMode.BILINEAR) in {
@@ -364,6 +367,21 @@ def reference_inputs_resize_bounding_box():
364367
xfail_jit_python_scalar_arg("size"),
365368
],
366369
),
370+
KernelInfo(
371+
F.resize_image_tensor,
372+
sample_inputs_fn=sample_inputs_resize_image_tensor_bicubic,
373+
reference_fn=reference_resize_image_tensor,
374+
reference_inputs_fn=reference_inputs_resize_image_tensor,
375+
float32_vs_uint8=True,
376+
closeness_kwargs={
377+
**pil_reference_pixel_difference(10, mae=True),
378+
**cuda_vs_cpu_pixel_difference(atol=30),
379+
**float32_vs_uint8_pixel_difference(1, mae=True),
380+
},
381+
test_marks=[
382+
xfail_jit_python_scalar_arg("size"),
383+
],
384+
),
367385
KernelInfo(
368386
F.resize_bounding_box,
369387
sample_inputs_fn=sample_inputs_resize_bounding_box,

torchvision/transforms/v2/functional/_geometry.py

+8-8
Original file line number | Diff line number | Diff line change
@@ -190,14 +190,13 @@ def resize_image_tensor(
190190
if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
191191
# uint8 dtype can be included for cpu and cuda input if nearest mode
192192
acceptable_dtypes.append(torch.uint8)
193-
elif (
194-
interpolation == InterpolationMode.BILINEAR
195-
and image.device.type == "cpu"
196-
and "AVX2" in torch.backends.cpu.get_cpu_capability()
197-
):
198-
# uint8 dtype support for bilinear mode is limited to cpu and
199-
# according to our benchmarks non-AVX CPUs should prefer u8->f32->interpolate->u8 path
200-
acceptable_dtypes.append(torch.uint8)
193+
elif image.device.type == "cpu":
194+
# uint8 dtype support for bilinear and bicubic is limited to cpu and
195+
# according to our benchmarks, non-AVX CPUs should still prefer u8->f32->interpolate->u8 path for bilinear
196+
if (interpolation == InterpolationMode.BILINEAR and "AVX2" in torch.backends.cpu.get_cpu_capability()) or (
197+
interpolation == InterpolationMode.BICUBIC
198+
):
199+
acceptable_dtypes.append(torch.uint8)
201200

202201
strides = image.stride()
203202
if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
@@ -227,6 +226,7 @@ def resize_image_tensor(
227226

228227
if need_cast:
229228
if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
229+
# This path is hit on non-AVX archs, or on GPU.
230230
image = image.clamp_(min=0, max=255)
231231
if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
232232
image = image.round_()

0 commit comments

Comments
 (0)