diff --git a/test/common_utils.py b/test/common_utils.py index 385bc670a2b..920e999863f 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -369,3 +369,16 @@ def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, agg_meth err < tol, msg="{}: err={}, tol={}: \n{}\nvs\n{}".format(msg, err, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10]) ) + + +def cycle_over(objs): + for idx, obj in enumerate(objs): + yield obj, objs[:idx] + objs[idx + 1:] + + +def int_dtypes(): + return torch.testing.integral_types() + + +def float_dtypes(): + return torch.testing.floating_types() diff --git a/test/test_transforms.py b/test/test_transforms.py index 76d70b82ed6..c97c985ca46 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -20,24 +20,11 @@ except ImportError: stats = None -GRACE_HOPPER = get_file_path_2( - os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg') - - -def cycle_over(objs): - objs = list(objs) - for idx, obj in enumerate(objs): - yield obj, objs[:idx] + objs[idx + 1:] +from common_utils import cycle_over, int_dtypes, float_dtypes -def int_dtypes(): - yield from iter( - (torch.uint8, torch.int8, torch.int16, torch.short, torch.int32, torch.int, torch.int64, torch.long,) - ) - - -def float_dtypes(): - yield from iter((torch.float32, torch.float, torch.float64, torch.double)) +GRACE_HOPPER = get_file_path_2( + os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg') class Tester(unittest.TestCase): diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 0455b46d4b2..7ec603a9a90 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -9,7 +9,7 @@ import unittest -from common_utils import TransformsTester, get_tmp_dir +from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes class Tester(TransformsTester): @@ -27,14 +27,14 @@ def _test_functional_op(self, func, fn_kwargs): transformed_pil_img = f(pil_img, **fn_kwargs) self.compareTensorToPIL(transformed_tensor, transformed_pil_img) - def _test_transform_vs_scripted(self, transform, s_transform, tensor): + def _test_transform_vs_scripted(self, transform, s_transform, tensor, msg=None): torch.manual_seed(12) out1 = transform(tensor) torch.manual_seed(12) out2 = s_transform(tensor) - self.assertTrue(out1.equal(out2)) + self.assertTrue(out1.equal(out2), msg=msg) - def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors): + def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors, msg=None): torch.manual_seed(12) transformed_batch = transform(batch_tensors) @@ -42,11 +42,11 @@ def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_ten img_tensor = batch_tensors[i, ...] torch.manual_seed(12) transformed_img = transform(img_tensor) - self.assertTrue(transformed_img.equal(transformed_batch[i, ...])) + self.assertTrue(transformed_img.equal(transformed_batch[i, ...]), msg=msg) torch.manual_seed(12) s_transformed_batch = s_transform(batch_tensors) - self.assertTrue(transformed_batch.equal(s_transformed_batch)) + self.assertTrue(transformed_batch.equal(s_transformed_batch), msg=msg) def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs): if meth_kwargs is None: @@ -492,6 +492,32 @@ def test_random_erasing(self): self._test_transform_vs_scripted(fn, scripted_fn, tensor) self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors) + def test_convert_image_dtype(self): + tensor, _ = self._create_data(26, 34, device=self.device) + batch_tensors = torch.rand(4, 3, 44, 56, device=self.device) + + for in_dtype in int_dtypes() + float_dtypes(): + in_tensor = tensor.to(in_dtype) + in_batch_tensors = batch_tensors.to(in_dtype) + for out_dtype in int_dtypes() + float_dtypes(): + + fn = T.ConvertImageDtype(dtype=out_dtype) + scripted_fn = torch.jit.script(fn) + + if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or \ + (in_dtype == torch.float64 and out_dtype == torch.int64): + with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"): + self._test_transform_vs_scripted(fn, scripted_fn, in_tensor) + with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"): + self._test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors) + continue + + self._test_transform_vs_scripted(fn, scripted_fn, in_tensor) + self._test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors) + + with get_tmp_dir() as tmp_dir: + scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt")) + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") class CUDATester(Tester):