Skip to content

Commit b56f17a

Browse files
authored
Added antialias option to transforms.functional.resize (#3761)
* WIP Added antialias option to transforms.functional.resize * Updates according to the review * Excluded these C++ files for iOS build * Added support for mixed downsampling/upsampling * Fixed heap overflow caused by explicit loop unrolling * Applied PR review suggestions - used pytest parametrize instead unittest - cast to scalar_t ptr - removed interpolate aa files for ios/android keeping original cmake version
1 parent d6fee5a commit b56f17a

10 files changed

+661
-8
lines changed

android/ops/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@ file(GLOB VISION_SRCS
1414
../../torchvision/csrc/ops/*.h
1515
../../torchvision/csrc/ops/*.cpp)
1616

17+
# Remove interpolate_aa sources as they are temporary code
18+
# see https://github.com/pytorch/vision/pull/3761
19+
# and IndexingUtils.h is unavailable on Android build
20+
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp")
21+
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/interpolate_aa.cpp")
22+
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/interpolate_aa.h")
23+
1724
add_library(${TARGET} SHARED
1825
${VISION_SRCS}
1926
)

ios/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@ file(GLOB VISION_SRCS
1111
../torchvision/csrc/ops/*.h
1212
../torchvision/csrc/ops/*.cpp)
1313

14+
# Remove interpolate_aa sources as they are temporary code
15+
# see https://github.com/pytorch/vision/pull/3761
16+
# and using TensorIterator unavailable with iOS
17+
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp")
18+
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/interpolate_aa.cpp")
19+
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/interpolate_aa.h")
20+
1421
add_library(${TARGET} STATIC
1522
${VISION_SRCS}
1623
)

test/test_functional_tensor.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,5 +1018,52 @@ def test_perspective_interpolation_warning(tester):
10181018
tester.assertTrue(res1.equal(res2))
10191019

10201020

1021+
@pytest.mark.parametrize('device', ["cpu", ])
1022+
@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
1023+
@pytest.mark.parametrize('size', [[96, 72], [96, 420], [420, 72]])
1024+
@pytest.mark.parametrize('interpolation', [BILINEAR, ])
1025+
def test_resize_antialias(device, dt, size, interpolation, tester):
1026+
1027+
if dt == torch.float16 and device == "cpu":
1028+
# skip float16 on CPU case
1029+
return
1030+
1031+
script_fn = torch.jit.script(F.resize)
1032+
tensor, pil_img = tester._create_data(320, 290, device=device)
1033+
1034+
if dt is not None:
1035+
# This is a trivial cast to float of uint8 data to test all cases
1036+
tensor = tensor.to(dt)
1037+
1038+
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
1039+
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)
1040+
1041+
tester.assertEqual(
1042+
resized_tensor.size()[1:], resized_pil_img.size[::-1],
1043+
msg=f"{size}, {interpolation}, {dt}"
1044+
)
1045+
1046+
resized_tensor_f = resized_tensor
1047+
# we need to cast to uint8 to compare with PIL image
1048+
if resized_tensor_f.dtype == torch.uint8:
1049+
resized_tensor_f = resized_tensor_f.to(torch.float)
1050+
1051+
tester.approxEqualTensorToPIL(
1052+
resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}"
1053+
)
1054+
tester.approxEqualTensorToPIL(
1055+
resized_tensor_f, resized_pil_img, tol=1.0 + 1e-5, agg_method="max",
1056+
msg=f"{size}, {interpolation}, {dt}"
1057+
)
1058+
1059+
if isinstance(size, int):
1060+
script_size = [size, ]
1061+
else:
1062+
script_size = size
1063+
1064+
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, antialias=True)
1065+
tester.assertTrue(resized_tensor.equal(resize_result), msg=f"{size}, {interpolation}, {dt}")
1066+
1067+
10211068
if __name__ == '__main__':
10221069
unittest.main()

test/test_transforms.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,10 @@ def test_resize(self):
348348

349349
self.assertEqual((owidth, oheight), result.size)
350350

351+
with self.assertWarnsRegex(UserWarning, r"Anti-alias option is always applied for PIL Image input"):
352+
t = transforms.Resize(osize, antialias=False)
353+
t(img)
354+
351355
def test_random_crop(self):
352356
height = random.randint(10, 32) * 2
353357
width = random.randint(10, 32) * 2

0 commit comments

Comments
 (0)