Skip to content

Group E test_transforms.py port to pytest #4026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,9 @@ def freeze_rng_state():


def cycle_over(objs):
for idx, obj in enumerate(objs):
yield obj, objs[:idx] + objs[idx + 1:]
for idx, obj1 in enumerate(objs):
for obj2 in objs[:idx] + objs[idx + 1:]:
yield obj1, obj2


def int_dtypes():
Expand Down
237 changes: 113 additions & 124 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import itertools
import os
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import torchvision.transforms.functional_tensor as F_t
from torch._utils_internal import get_file_path_2
from numpy.testing import assert_array_almost_equal
import unittest
import math
import random
import numpy as np
Expand All @@ -30,126 +27,118 @@
os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')


class Tester(unittest.TestCase):

def test_convert_image_dtype_float_to_float(self):
for input_dtype, output_dtypes in cycle_over(float_dtypes()):
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
for output_dtype in output_dtypes:
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0

self.assertAlmostEqual(actual_min, desired_min)
self.assertAlmostEqual(actual_max, desired_max)

def test_convert_image_dtype_float_to_int(self):
for input_dtype in float_dtypes():
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
for output_dtype in int_dtypes():
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or (
input_dtype == torch.float64 and output_dtype == torch.int64
):
with self.assertRaises(RuntimeError):
transform(input_image)
else:
output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, torch.iinfo(output_dtype).max

self.assertEqual(actual_min, desired_min)
self.assertEqual(actual_max, desired_max)

def test_convert_image_dtype_int_to_float(self):
for input_dtype in int_dtypes():
input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype)
for output_dtype in float_dtypes():
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0

self.assertAlmostEqual(actual_min, desired_min)
self.assertGreaterEqual(actual_min, desired_min)
self.assertAlmostEqual(actual_max, desired_max)
self.assertLessEqual(actual_max, desired_max)

def test_convert_image_dtype_int_to_int(self):
for input_dtype, output_dtypes in cycle_over(int_dtypes()):
input_max = torch.iinfo(input_dtype).max
input_image = torch.tensor((0, input_max), dtype=input_dtype)
for output_dtype in output_dtypes:
output_max = torch.iinfo(output_dtype).max

with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

torch.testing.assert_close(
output_image_script,
output_image,
rtol=0.0,
atol=1e-6,
msg="{} vs {}".format(output_image_script, output_image),
)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, output_max

# see https://github.com/pytorch/vision/pull/2078#issuecomment-641036236 for details
if input_max >= output_max:
error_term = 0
else:
error_term = 1 - (torch.iinfo(output_dtype).max + 1) // (torch.iinfo(input_dtype).max + 1)

self.assertEqual(actual_min, desired_min)
self.assertEqual(actual_max, desired_max + error_term)

def test_convert_image_dtype_int_to_int_consistency(self):
for input_dtype, output_dtypes in cycle_over(int_dtypes()):
input_max = torch.iinfo(input_dtype).max
input_image = torch.tensor((0, input_max), dtype=input_dtype)
for output_dtype in output_dtypes:
output_max = torch.iinfo(output_dtype).max
if output_max <= input_max:
continue

with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
inverse_transfrom = transforms.ConvertImageDtype(input_dtype)
output_image = inverse_transfrom(transform(input_image))

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, input_max

self.assertEqual(actual_min, desired_min)
self.assertEqual(actual_max, desired_max)
class TestConvertImageDtype:
@pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(float_dtypes()))
def test_float_to_float(self, input_dtype, output_dtype):
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0

assert abs(actual_min - desired_min) < 1e-7
assert abs(actual_max - desired_max) < 1e-7

@pytest.mark.parametrize('input_dtype', float_dtypes())
@pytest.mark.parametrize('output_dtype', int_dtypes())
def test_float_to_int(self, input_dtype, output_dtype):
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or (
input_dtype == torch.float64 and output_dtype == torch.int64
):
with pytest.raises(RuntimeError):
transform(input_image)
else:
output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, torch.iinfo(output_dtype).max

assert actual_min == desired_min
assert actual_max == desired_max

@pytest.mark.parametrize('input_dtype', int_dtypes())
@pytest.mark.parametrize('output_dtype', float_dtypes())
def test_int_to_float(self, input_dtype, output_dtype):
input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype)
transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0

assert abs(actual_min - desired_min) < 1e-7
assert actual_min >= desired_min
assert abs(actual_max - desired_max) < 1e-7
assert actual_max <= desired_max

@pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(int_dtypes()))
def test_dtype_int_to_int(self, input_dtype, output_dtype):
input_max = torch.iinfo(input_dtype).max
input_image = torch.tensor((0, input_max), dtype=input_dtype)
output_max = torch.iinfo(output_dtype).max

transform = transforms.ConvertImageDtype(output_dtype)
transform_script = torch.jit.script(F.convert_image_dtype)

output_image = transform(input_image)
output_image_script = transform_script(input_image, output_dtype)

torch.testing.assert_close(
output_image_script,
output_image,
rtol=0.0,
atol=1e-6,
msg="{} vs {}".format(output_image_script, output_image),
)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, output_max

# see https://github.com/pytorch/vision/pull/2078#issuecomment-641036236 for details
if input_max >= output_max:
error_term = 0
else:
error_term = 1 - (torch.iinfo(output_dtype).max + 1) // (torch.iinfo(input_dtype).max + 1)

assert actual_min == desired_min
assert actual_max == (desired_max + error_term)

@pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(int_dtypes()))
def test_int_to_int_consistency(self, input_dtype, output_dtype):
input_max = torch.iinfo(input_dtype).max
input_image = torch.tensor((0, input_max), dtype=input_dtype)

output_max = torch.iinfo(output_dtype).max
if output_max <= input_max:
return

transform = transforms.ConvertImageDtype(output_dtype)
inverse_transfrom = transforms.ConvertImageDtype(input_dtype)
output_image = inverse_transfrom(transform(input_image))

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, input_max

assert actual_min == desired_min
assert actual_max == desired_max


@pytest.mark.skipif(accimage is None, reason="accimage not available")
Expand Down Expand Up @@ -2120,4 +2109,4 @@ def test_random_affine():


if __name__ == '__main__':
unittest.main()
pytest.main([__file__])