
Commit 99ec261

vfdev-5 and NicolasHug authored
Resize V2 relies on interpolate's native uint8 handling (#7557)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent fc838ad commit 99ec261

5 files changed: +111 -22 lines

test/common_utils.py (+27 -10)

```diff
@@ -465,11 +465,15 @@ def load(self, device):
 class ImageLoader(TensorLoader):
     spatial_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)
+    memory_format: torch.memory_format = torch.contiguous_format
 
     def __post_init__(self):
         self.spatial_size = self.shape[-2:]
         self.num_channels = self.shape[-3]
 
+    def load(self, device):
+        return self.fn(self.shape, self.dtype, device, memory_format=self.memory_format)
+
 
 NUM_CHANNELS_MAP = {
     "GRAY": 1,
@@ -493,18 +497,21 @@ def make_image_loader(
     extra_dims=(),
     dtype=torch.float32,
     constant_alpha=True,
+    memory_format=torch.contiguous_format,
 ):
     size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)
 
-    def fn(shape, dtype, device):
+    def fn(shape, dtype, device, memory_format):
         max_value = get_max_value(dtype)
-        data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device)
+        data = torch.testing.make_tensor(
+            shape, low=0, high=max_value, dtype=dtype, device=device, memory_format=memory_format
+        )
         if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha:
             data[..., -1, :, :] = max_value
         return datapoints.Image(data)
 
-    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype)
+    return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, memory_format=memory_format)
 
 
 make_image = from_loader(make_image_loader)
@@ -530,11 +537,13 @@ def make_image_loaders(
 make_images = from_loaders(make_image_loaders)
 
 
-def make_image_loader_for_interpolation(size="random", *, color_space="RGB", dtype=torch.uint8):
+def make_image_loader_for_interpolation(
+    size="random", *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
+):
     size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)
 
-    def fn(shape, dtype, device):
+    def fn(shape, dtype, device, memory_format):
         height, width = shape[-2:]
 
         image_pil = (
@@ -550,19 +559,25 @@ def fn(shape, dtype, device):
             )
         )
 
-        image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype)
+        image_tensor = to_image_tensor(image_pil)
+        if memory_format == torch.contiguous_format:
+            image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True)
+        else:
+            image_tensor = image_tensor.to(device=device)
+        image_tensor = convert_dtype_image_tensor(image_tensor, dtype=dtype)
 
         return datapoints.Image(image_tensor)
 
-    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype)
+    return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, memory_format=memory_format)
 
 
 def make_image_loaders_for_interpolation(
     sizes=((233, 147),),
     color_spaces=("RGB",),
     dtypes=(torch.uint8,),
+    memory_formats=(torch.contiguous_format, torch.channels_last),
 ):
-    for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes):
+    for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes, memory_format=memory_formats):
         yield make_image_loader_for_interpolation(**params)
 
 
@@ -744,8 +759,10 @@ def make_video_loader(
     size = _parse_spatial_size(size)
     num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames
 
-    def fn(shape, dtype, device):
-        video = make_image(size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device)
+    def fn(shape, dtype, device, memory_format):
+        video = make_image(
+            size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device, memory_format=memory_format
+        )
         return datapoints.Video(video)
 
     return VideoLoader(fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype)
```
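The loaders above thread a `memory_format` argument through to the tensor factories. Note that for the 3D CHW images used by the interpolation loaders, "channels last" means channels-last-like strides obtained by permuting an HWC buffer rather than `torch.channels_last` proper, which is only defined for 4D tensors. A minimal sketch of the stride bookkeeping involved (the variable names are illustrative, not from the test suite):

```python
import torch

# Channels-first CHW storage: strides are (H * W, W, 1).
chw = torch.arange(3 * 4 * 5, dtype=torch.uint8).reshape(3, 4, 5)
print(chw.stride())  # (20, 5, 1)

# PIL-style data lives in HWC order; permuting to CHW gives "channels-last-like"
# strides (1, W * C, C) without copying. This is what the interpolation loader
# relies on when memory_format is not torch.contiguous_format.
hwc = torch.arange(4 * 5 * 3, dtype=torch.uint8).reshape(4, 5, 3)
chw_view = hwc.permute(2, 0, 1)  # shape (3, 4, 5)
print(chw_view.stride())  # (1, 15, 3)

# An explicit copy restores channels-first strides, matching the
# .to(memory_format=torch.contiguous_format, copy=True) branch above.
print(chw_view.contiguous().stride())  # (20, 5, 1)
```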

test/test_transforms_v2_consistency.py (+10 -3)

```diff
@@ -98,6 +98,8 @@ def __init__(
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        closeness_kwargs=dict(rtol=0, atol=1),
     ),
     ConsistencyConfig(
         v2_transforms.CenterCrop,
@@ -313,6 +315,8 @@ def __init__(
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        closeness_kwargs=dict(rtol=0, atol=1),
     ),
     ConsistencyConfig(
         v2_transforms.RandomErasing,
@@ -783,7 +787,8 @@ def test_compose(self):
             ]
         )
 
-        check_call_consistency(prototype_transform, legacy_transform)
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
 
     @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
     @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
@@ -807,7 +812,8 @@ def test_random_apply(self, p, sequence_type):
             p=p,
         )
 
-        check_call_consistency(prototype_transform, legacy_transform)
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
 
         if sequence_type is nn.ModuleList:
             # quick and dirty test that it is jit-scriptable
@@ -832,7 +838,8 @@ def test_random_choice(self, probabilities):
             p=probabilities,
         )
 
-        check_call_consistency(prototype_transform, legacy_transform)
+        # atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
+        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
 
 
 class TestToTensorTransforms:
```
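The `rtol=0, atol=1` tolerance reflects that the native uint8 kernels quantize intermediate results slightly differently from the legacy uint8 -> float32 -> interpolate -> round -> uint8 round trip, so individual pixels can differ by one intensity level. A sketch of the comparison, assuming a PyTorch build (roughly 2.0+) whose CPU `interpolate()` accepts uint8 input for bilinear mode:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
img_u8 = torch.randint(0, 256, (1, 3, 32, 32), dtype=torch.uint8)

# Native uint8 bilinear path on CPU (what Resize v2 now relies on).
out_native = F.interpolate(img_u8, size=(28, 31), mode="bilinear", antialias=True)

# Legacy path: uint8 -> float32 -> interpolate -> round -> clamp -> uint8.
out_float = (
    F.interpolate(img_u8.float(), size=(28, 31), mode="bilinear", antialias=True)
    .round_()
    .clamp_(0, 255)
    .to(torch.uint8)
)

# Individual pixels may differ by one intensity level, matching rtol=0, atol=1.
print((out_native.int() - out_float.int()).abs().max())
```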

test/test_transforms_v2_functional.py (+30)

```diff
@@ -1365,3 +1365,33 @@ def test_correctness_uniform_temporal_subsample(device):
 
     out_video = F.uniform_temporal_subsample(video, 8)
     assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]
+
+
+# TODO: We can remove this test and the related torchvision workaround
+# once the underlying pytorch issue is fixed: https://github.com/pytorch/pytorch/issues/68430
+@make_info_args_kwargs_parametrization(
+    [info for info in KERNEL_INFOS if info.kernel is F.resize_image_tensor],
+    args_kwargs_fn=lambda info: info.reference_inputs_fn(),
+)
+def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
+    (input, *other_args), kwargs = args_kwargs.load("cpu")
+
+    output = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
+
+    error_msg_fn = parametrized_error_message(input, *other_args, **kwargs)
+    assert input.ndim == 3, error_msg_fn
+    input_stride = input.stride()
+    output_stride = output.stride()
+    # Here we check the output memory format according to the input:
+    # if input_stride is (..., 1) then the input is most likely channels first, and thus
+    # the output strides should match channels-first strides (H * W, W, 1);
+    # if input_stride is (1, ...) then the input is most likely channels last, and thus
+    # the output strides should match channels-last strides (1, W * C, C)
+    if input_stride[-1] == 1:
+        expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
+        assert expected_stride == output_stride, error_msg_fn("")
+    elif input_stride[0] == 1:
+        expected_stride = (1, output.shape[0] * output.shape[-1], output.shape[0])
+        assert expected_stride == output_stride, error_msg_fn("")
+    else:
+        assert False, error_msg_fn("")
```
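The stride checks above encode a simple heuristic for 3D images: a trailing stride of 1 means channels-first storage, a leading stride of 1 means channels-last-like storage. A standalone sketch of the same classification (`classify_3d_layout` is a hypothetical helper, not part of the test suite):

```python
import torch

def classify_3d_layout(t: torch.Tensor) -> str:
    # Mirrors the heuristic of the test above: trailing stride 1 means
    # channels-first (CHW) storage; leading stride 1 means the data is
    # stored channels-last (HWC in memory, viewed as CHW).
    if t.stride()[-1] == 1:
        return "channels_first"
    if t.stride()[0] == 1:
        return "channels_last"
    return "ambiguous"

chw = torch.empty(3, 8, 8, dtype=torch.uint8)                        # strides (64, 8, 1)
hwc_view = torch.empty(8, 8, 3, dtype=torch.uint8).permute(2, 0, 1)  # strides (1, 24, 3)
print(classify_3d_layout(chw))       # channels_first
print(classify_3d_layout(hwc_view))  # channels_last
```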

test/transforms_v2_kernel_infos.py (+11 -7)

```diff
@@ -1569,31 +1569,35 @@ def reference_inputs_equalize_image_tensor():
     # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
     # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
     # the information gain is low if we already provide something really close to the expected value.
-    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor):
+    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor, memory_format):
         if dtype.is_floating_point:
             low = low_factor
             high = high_factor
         else:
             max_value = torch.iinfo(dtype).max
             low = int(low_factor * max_value)
             high = int(high_factor * max_value)
-        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high)
+        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high).to(
+            memory_format=memory_format, copy=True
+        )
 
-    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta):
+    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta, memory_format):
         image = torch.distributions.Beta(alpha, beta).sample(shape)
         if not dtype.is_floating_point:
             image.mul_(torch.iinfo(dtype).max).round_()
-        return image.to(dtype=dtype, device=device)
+        return image.to(dtype=dtype, device=device, memory_format=memory_format, copy=True)
 
     spatial_size = (256, 256)
     for dtype, color_space, fn in itertools.product(
         [torch.uint8],
         ["GRAY", "RGB"],
         [
-            lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
-            lambda shape, dtype, device: torch.full(
-                shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
+            lambda shape, dtype, device, memory_format: torch.zeros(shape, dtype=dtype, device=device).to(
+                memory_format=memory_format, copy=True
             ),
+            lambda shape, dtype, device, memory_format: torch.full(
+                shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
+            ).to(memory_format=memory_format, copy=True),
             *[
                 functools.partial(make_uniform_band_image, low_factor=low_factor, high_factor=high_factor)
                 for low_factor, high_factor in [
```
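The `.to(memory_format=..., copy=True)` pattern used above guarantees a fresh tensor in the requested layout; without `copy=True`, `.to()` can return the input itself when nothing needs to change, so loaders could silently share storage. A small sketch on a 4D tensor (`torch.channels_last` is only defined for 4D inputs):

```python
import torch

x = torch.zeros(1, 3, 4, 4, dtype=torch.uint8)  # NCHW, contiguous

# Without copy=True, .to() may return the very same tensor when the layout
# already matches.
same = x.to(memory_format=torch.contiguous_format)
print(same.data_ptr() == x.data_ptr())  # True

# copy=True always materializes a fresh tensor in the requested layout.
fresh = x.to(memory_format=torch.channels_last, copy=True)
print(fresh.data_ptr() == x.data_ptr())  # False
print(fresh.stride())  # (48, 1, 12, 3) for shape (1, 3, 4, 4)
```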

torchvision/transforms/v2/functional/_geometry.py (+33 -2)

```diff
@@ -176,16 +176,47 @@ def resize_image_tensor(
         antialias = False
 
     shape = image.shape
+    numel = image.numel()
     num_channels, old_height, old_width = shape[-3:]
     new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
 
     if (new_height, new_width) == (old_height, old_width):
         return image
-    elif image.numel() > 0:
+    elif numel > 0:
         image = image.reshape(-1, num_channels, old_height, old_width)
 
         dtype = image.dtype
-        need_cast = dtype not in (torch.float32, torch.float64)
+        acceptable_dtypes = [torch.float32, torch.float64]
+        if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
+            # uint8 dtype can be included for cpu and cuda input in nearest mode
+            acceptable_dtypes.append(torch.uint8)
+        elif interpolation == InterpolationMode.BILINEAR and image.device.type == "cpu":
+            # uint8 dtype support for bilinear mode is limited to cpu and,
+            # according to our benchmarks, non-AVX CPUs should prefer the u8->f32->interpolate->u8 path
+            if "AVX2" in torch.backends.cpu.get_cpu_capability():
+                acceptable_dtypes.append(torch.uint8)
+
+        # TODO: Remove when https://github.com/pytorch/pytorch/pull/101136 is landed
+        if dtype == torch.uint8 and not (
+            image.is_contiguous() or image.is_contiguous(memory_format=torch.channels_last)
+        ):
+            image = image.contiguous(memory_format=torch.channels_last)
+
+        strides = image.stride()
+        if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
+            # There is a weird behaviour in torch core where the output tensor of `interpolate()` can be allocated as
+            # contiguous even though the input is unambiguously channels_last (https://github.com/pytorch/pytorch/issues/68430).
+            # In particular, this happens for the typical torchvision use-case of single CHW images where we fake the batch dim
+            # to become 1CHW. Below, we restride those tensors to trick torch core into properly allocating the output as
+            # channels_last, thus preserving the memory format of the input. This is not just for format consistency:
+            # for uint8 bilinear images, this also avoids an extra copy (re-packing) of the output and saves time.
+            # TODO: when https://github.com/pytorch/pytorch/issues/68430 is fixed (possibly by https://github.com/pytorch/pytorch/pull/100373),
+            # we should be able to remove this hack.
+            new_strides = list(strides)
+            new_strides[0] = numel
+            image = image.as_strided((1, num_channels, old_height, old_width), new_strides)
+
+        need_cast = dtype not in acceptable_dtypes
         if need_cast:
             image = image.to(dtype=torch.float32)
```

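To see why the restriding hack matters, consider the typical single-image case: a CHW tensor with channels-last-like strides gets a fake batch dim, and for batch size 1 the batch stride is ambiguous, so `interpolate()` may allocate a contiguous output. Forcing the batch stride to `numel()` disambiguates the layout before the call. A sketch under the assumption of a PyTorch build whose CPU `interpolate()` supports uint8 bilinear input (roughly 2.1+):

```python
import torch
import torch.nn.functional as F

# Single CHW image with channels-last-like strides, as produced by permuting an HWC buffer.
img = torch.randint(0, 256, (8, 8, 3), dtype=torch.uint8).permute(2, 0, 1)  # shape (3, 8, 8), strides (1, 24, 3)
batch = img.unsqueeze(0)  # shape (1, 3, 8, 8); the batch stride is ambiguous for batch size 1

# Restride so the batch stride equals numel(), unambiguously marking the tensor
# as channels_last before interpolate() allocates its output.
new_strides = list(batch.stride())
new_strides[0] = batch.numel()
batch = batch.as_strided(batch.shape, new_strides)

out = F.interpolate(batch, size=(4, 4), mode="bilinear", antialias=True)
print(out.stride())  # channels_last strides for (1, 3, 4, 4): (48, 1, 12, 3)
```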