Skip to content

Commit c2e8a00

Browse files
pmeierPhilip Meier
and
Philip Meier
authored
Add convert_image_dtype to functionals (#2078)
* add convert_image_dtype to functionals * add ConvertImageDtype transform * add test * remove underscores from numbers since they are not compatible with python<3.6 * address review comments 1/3 * fix torch.bool * use torch.iinfo in test * fix flake8 * remove double conversion * fix flake9 * bug fix * add error messages to test * disable torch.float16 and torch.half for now * add docstring * add test for consistency * move nested function to top * test in CI * dirty progress * add int to int and cleanup * lint Co-authored-by: Philip Meier <[email protected]>
1 parent 54da5db commit c2e8a00

File tree

3 files changed

+198
-4
lines changed

3 files changed

+198
-4
lines changed

test/test_transforms.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,22 @@
2323
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
2424

2525

26+
def cycle_over(objs):
    """Yield every element of ``objs`` paired with a list of all the other elements.

    Args:
        objs (iterable): Elements to cycle over. The iterable is consumed once up front.

    Yields:
        tuple: ``(element, others)`` where ``others`` is a new list holding every element except the one at the
        current position, in their original order.
    """
    items = list(objs)
    for position, current in enumerate(items):
        before, after = items[:position], items[position + 1:]
        yield current, before + after
30+
31+
32+
def int_dtypes():
    """Yield the integer dtypes exercised by the dtype conversion tests.

    The sequence lists both the explicit names and their short aliases (e.g. ``torch.int16`` and ``torch.short``).
    """
    for dtype in (
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.short,
        torch.int32,
        torch.int,
        torch.int64,
        torch.long,
    ):
        yield dtype
36+
37+
38+
def float_dtypes():
    """Yield the floating point dtypes exercised by the dtype conversion tests.

    The sequence lists both the explicit names and their aliases (``torch.float`` and ``torch.double``).
    """
    for dtype in (torch.float32, torch.float, torch.float64, torch.double):
        yield dtype
40+
41+
2642
class Tester(unittest.TestCase):
2743

2844
def test_crop(self):
@@ -510,6 +526,100 @@ def test_to_tensor(self):
510526
output = trans(img)
511527
self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))
512528

529+
def test_convert_image_dtype_float_to_float(self):
    """Check that 0.0 and 1.0 survive a conversion between floating point dtypes."""
    expected_low, expected_high = 0.0, 1.0

    for source_dtype, target_dtypes in cycle_over(float_dtypes()):
        source = torch.tensor((0.0, 1.0), dtype=source_dtype)

        for target_dtype in target_dtypes:
            with self.subTest(input_dtype=source_dtype, output_dtype=target_dtype):
                convert = transforms.ConvertImageDtype(target_dtype)
                low, high = convert(source).tolist()

                self.assertAlmostEqual(low, expected_low)
                self.assertAlmostEqual(high, expected_high)
542+
543+
def test_convert_image_dtype_float_to_int(self):
    """Check that float 0.0 / 1.0 map to 0 / the integer maximum, or that the cast is rejected."""
    # Integer targets that must raise a RuntimeError for a given float source dtype.
    rejected_targets = {
        torch.float32: (torch.int32, torch.int64),
        torch.float64: (torch.int64,),
    }

    for source_dtype in float_dtypes():
        source = torch.tensor((0.0, 1.0), dtype=source_dtype)

        for target_dtype in int_dtypes():
            with self.subTest(input_dtype=source_dtype, output_dtype=target_dtype):
                convert = transforms.ConvertImageDtype(target_dtype)
                must_fail = target_dtype in rejected_targets.get(source_dtype, ())

                if must_fail:
                    with self.assertRaises(RuntimeError):
                        convert(source)
                else:
                    low, high = convert(source).tolist()

                    self.assertEqual(low, 0)
                    self.assertEqual(high, torch.iinfo(target_dtype).max)
563+
564+
def test_convert_image_dtype_int_to_float(self):
    """Check that 0 / the integer maximum map to (almost) 0.0 / 1.0 without leaving ``[0.0, 1.0]``."""
    expected_low, expected_high = 0.0, 1.0

    for source_dtype in int_dtypes():
        source_max = torch.iinfo(source_dtype).max
        source = torch.tensor((0, source_max), dtype=source_dtype)

        for target_dtype in float_dtypes():
            with self.subTest(input_dtype=source_dtype, output_dtype=target_dtype):
                convert = transforms.ConvertImageDtype(target_dtype)
                low, high = convert(source).tolist()

                # The result has to be close to the bound and must not cross it.
                self.assertAlmostEqual(low, expected_low)
                self.assertGreaterEqual(low, expected_low)
                self.assertAlmostEqual(high, expected_high)
                self.assertLessEqual(high, expected_high)
579+
580+
def test_convert_image_dtype_int_to_int(self):
    """Check the end points of a conversion between integer dtypes.

    When the output dtype is larger than the input dtype the maximum is not hit exactly; the expected shortfall
    is accounted for by an error term.
    """
    for source_dtype, target_dtypes in cycle_over(int_dtypes()):
        input_max = torch.iinfo(source_dtype).max
        source = torch.tensor((0, input_max), dtype=source_dtype)

        for target_dtype in target_dtypes:
            output_max = torch.iinfo(target_dtype).max

            with self.subTest(input_dtype=source_dtype, output_dtype=target_dtype):
                convert = transforms.ConvertImageDtype(target_dtype)
                low, high = convert(source).tolist()

                # see https://github.com/pytorch/vision/pull/2078#issuecomment-641036236 for details
                if input_max < output_max:
                    error_term = 1 - (output_max + 1) // (input_max + 1)
                else:
                    error_term = 0

                self.assertEqual(low, 0)
                self.assertEqual(high, output_max + error_term)
602+
603+
def test_convert_image_dtype_int_to_int_consistency(self):
    """Check that converting to a larger integer dtype and back restores the original end points."""
    for source_dtype, target_dtypes in cycle_over(int_dtypes()):
        input_max = torch.iinfo(source_dtype).max
        source = torch.tensor((0, input_max), dtype=source_dtype)

        # Only round trips through a strictly larger integer dtype are checked.
        larger_dtypes = [
            candidate for candidate in target_dtypes if torch.iinfo(candidate).max > input_max
        ]

        for target_dtype in larger_dtypes:
            with self.subTest(input_dtype=source_dtype, output_dtype=target_dtype):
                forward = transforms.ConvertImageDtype(target_dtype)
                backward = transforms.ConvertImageDtype(source_dtype)
                low, high = backward(forward(source)).tolist()

                self.assertEqual(low, 0)
                self.assertEqual(high, input_max)
622+
513623
@unittest.skipIf(accimage is None, 'accimage not available')
514624
def test_accimage_to_tensor(self):
515625
trans = transforms.ToTensor()

torchvision/transforms/functional.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,65 @@ def pil_to_tensor(pic):
113113
return img
114114

115115

116+
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly

    Floating point values are scaled as if they span ``[0.0, 1.0]`` and integer values as if they span
    ``[0, torch.iinfo(dtype).max]``.

    Args:
        image (torch.Tensor): Image to be converted
        dtype (torch.dtype): Desired data type of the output

    Returns:
        (torch.Tensor): Converted image. If ``image`` already has the requested ``dtype`` the input tensor itself
        is returned.

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """
    if image.dtype == dtype:
        return image

    if image.dtype.is_floating_point:
        # float to float
        if dtype.is_floating_point:
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
            raise RuntimeError(msg)

        # Scale by ``max + 1 - eps`` rather than ``max``: the cast below truncates, so 1.0 still lands on ``max``
        # while values slightly below 1.0 can reach it as well.
        eps = 1e-3
        return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)

    # From here on the input is an integer image.
    input_max = torch.iinfo(image.dtype).max

    # int to float
    if dtype.is_floating_point:
        return image.to(dtype) / input_max

    # int to int
    output_max = torch.iinfo(dtype).max

    if input_max > output_max:
        # Shrink while still in the wider input dtype, then cast down.
        factor = (input_max + 1) // (output_max + 1)
        return (image // factor).to(dtype)

    # Cast up first so that the multiplication happens in the wider output dtype.
    factor = (output_max + 1) // (input_max + 1)
    return image.to(dtype) * factor
173+
174+
116175
def to_pil_image(pic, mode=None):
117176
"""Convert a tensor or an ndarray to PIL Image.
118177

torchvision/transforms/transforms.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
from . import functional as F
1616

1717

18-
__all__ = ["Compose", "ToTensor", "PILToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
19-
"Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
20-
"RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
21-
"ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
18+
__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale",
19+
"CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
20+
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
21+
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
2222
"RandomPerspective", "RandomErasing"]
2323

2424
_pil_interpolation_to_str = {
@@ -115,6 +115,31 @@ def __repr__(self):
115115
return self.__class__.__name__ + '()'
116116

117117

118+
class ConvertImageDtype(object):
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly

    Args:
        dtype (torch.dtype): Desired data type of the output

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        # Target dtype handed to the functional implementation on every call.
        self.dtype = dtype

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """
        Args:
            image (torch.Tensor): Image to be converted

        Returns:
            (torch.Tensor): Result of ``F.convert_image_dtype`` for ``image`` and the stored ``dtype``
        """
        return F.convert_image_dtype(image, self.dtype)

    def __repr__(self):
        # Show the configured dtype so that printed transforms / pipelines are self-describing.
        return self.__class__.__name__ + '(dtype={0})'.format(self.dtype)
141+
142+
118143
class ToPILImage(object):
119144
"""Convert a tensor or an ndarray to PIL Image.
120145

0 commit comments

Comments
 (0)