
Commit 28e2fbf

add int to int and cleanup
1 parent adfb096 commit 28e2fbf

File tree

3 files changed: +143 -75 lines changed


test/test_transforms.py

Lines changed: 98 additions & 39 deletions
@@ -23,6 +23,20 @@
     os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')


+def cycle_over(objs):
+    objs = list(objs)
+    for idx, obj in enumerate(objs):
+        yield obj, objs[:idx] + objs[idx + 1:]
+
+def int_dtypes():
+    yield from iter(
+        (torch.uint8, torch.int8, torch.int16, torch.short, torch.int32, torch.int, torch.int64, torch.long,)
+    )
+
+def float_dtypes():
+    yield from iter((torch.float32, torch.float, torch.float64, torch.double))
+
+
 class Tester(unittest.TestCase):

     def test_crop(self):
@@ -510,54 +524,99 @@ def test_to_tensor(self):
         output = trans(img)
         self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

-    def test_convert_image_dtype(self):
-        def cycle_over(objs):
-            objs = list(objs)
-            for idx, obj in enumerate(objs):
-                yield obj, objs[:idx] + objs[idx + 1:]
-
-        # dtype_max_value = {
-        #     dtype: 1.0
-        #     for dtype in (torch.float32, torch.float, torch.float64, torch.double)#, torch.bool,)
-        #     # torch.float16 and torch.half are disabled for now since they do not support torch.max
-        #     # See https://github.com/pytorch/pytorch/issues/28623#issuecomment-611379051
-        #     # (torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.half, torch.bool, )
-        # }
-        dtype_max_value = {}
-        dtype_max_value.update(
-            {
-                dtype: torch.iinfo(dtype).max
-                for dtype in (
-                    torch.uint8,
-                    torch.int8,
-                    torch.int16,
-                    torch.short,
-                    torch.int32,
-                    torch.int,
-                    torch.int64,
-                    torch.long,
-                )
-            }
-        )
+    def test_convert_image_dtype_float_to_float(self):
+        for input_dtype, output_dtypes in cycle_over(float_dtypes()):
+            input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
+            for output_dtype in output_dtypes:
+                with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
+                    transform = transforms.ConvertImageDtype(output_dtype)
+                    output_image = transform(input_image)
+
+                    actual_min, actual_max = output_image.tolist()
+                    desired_min, desired_max = 0.0, 1.0
+
+                    self.assertAlmostEqual(actual_min, desired_min)
+                    self.assertAlmostEqual(actual_max, desired_max)
+
+    def test_convert_image_dtype_float_to_int(self):
+        for input_dtype in float_dtypes():
+            input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
+            for output_dtype in int_dtypes():
+                with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
+                    transform = transforms.ConvertImageDtype(output_dtype)
+
+                    if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or (
+                        input_dtype == torch.float64 and output_dtype == torch.int64
+                    ):
+                        with self.assertRaises(RuntimeError):
+                            transform(input_image)
+                    else:
+                        output_image = transform(input_image)

-        for input_dtype, output_dtypes in cycle_over(dtype_max_value.keys()):
-            input_image = torch.ones(1, dtype=input_dtype) * dtype_max_value[input_dtype]
+                        actual_min, actual_max = output_image.tolist()
+                        desired_min, desired_max = 0, torch.iinfo(output_dtype).max

+                        self.assertEqual(actual_min, desired_min)
+                        self.assertEqual(actual_max, desired_max)
+
+    def test_convert_image_dtype_int_to_float(self):
+        for input_dtype in int_dtypes():
+            input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype)
+            for output_dtype in float_dtypes():
+                with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
+                    transform = transforms.ConvertImageDtype(output_dtype)
+                    output_image = transform(input_image)
+
+                    actual_min, actual_max = output_image.tolist()
+                    desired_min, desired_max = 0.0, 1.0
+
+                    self.assertAlmostEqual(actual_min, desired_min)
+                    self.assertGreaterEqual(actual_min, desired_min)
+                    self.assertAlmostEqual(actual_max, desired_max)
+                    self.assertLessEqual(actual_max, desired_max)
+
+    def test_convert_image_dtype_int_to_int(self):
+        for input_dtype, output_dtypes in cycle_over(int_dtypes()):
+            input_max = torch.iinfo(input_dtype).max
+            input_image = torch.tensor((0, input_max), dtype=input_dtype)
             for output_dtype in output_dtypes:
+                output_max = torch.iinfo(output_dtype).max
+
                 with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
                     transform = transforms.ConvertImageDtype(output_dtype)
                     output_image = transform(input_image)

-                    actual = output_image.dtype
-                    desired = output_dtype
-                    self.assertEqual(actual, desired)
+                    actual_min, actual_max = output_image.tolist()
+                    desired_min, desired_max = 0, output_max

-                    actual = torch.max(output_image).item()
-                    desired = dtype_max_value[output_dtype]
-                    if output_dtype.is_floating_point:
-                        self.assertAlmostEqual(actual, desired)
+                    # see https://github.com/pytorch/vision/pull/2078#issuecomment-641036236 for details
+                    if input_max >= output_max:
+                        error_term = 0
                     else:
-                        self.assertEqual(actual, desired)
+                        error_term = 1 - (torch.iinfo(output_dtype).max + 1) // (torch.iinfo(input_dtype).max + 1)
+
+                    self.assertEqual(actual_min, desired_min)
+                    self.assertEqual(actual_max, desired_max + error_term)
+
+    def test_convert_image_dtype_int_to_int_consistency(self):
+        for input_dtype, output_dtypes in cycle_over(int_dtypes()):
+            input_max = torch.iinfo(input_dtype).max
+            input_image = torch.tensor((0, input_max), dtype=input_dtype)
+            for output_dtype in output_dtypes:
+                output_max = torch.iinfo(output_dtype).max
+                if output_max <= input_max:
+                    continue
+
+                with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
+                    transform = transforms.ConvertImageDtype(output_dtype)
+                    inverse_transfrom = transforms.ConvertImageDtype(input_dtype)
+                    output_image = inverse_transfrom(transform(input_image))
+
+                    actual_min, actual_max = output_image.tolist()
+                    desired_min, desired_max = 0, input_max
+
+                    self.assertEqual(actual_min, desired_min)
+                    self.assertEqual(actual_max, desired_max)

     @unittest.skipIf(accimage is None, 'accimage not available')
     def test_accimage_to_tensor(self):
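
The error_term in test_convert_image_dtype_int_to_int reflects that upscaling multiplies by (output_max + 1) // (input_max + 1) instead of remapping the maximum value exactly. A minimal sketch of that arithmetic for the uint8 to int16 case (illustration only, not part of the commit):

import torch

input_max = torch.iinfo(torch.uint8).max    # 255
output_max = torch.iinfo(torch.int16).max   # 32767

factor = (output_max + 1) // (input_max + 1)          # 32768 // 256 = 128
upscaled_max = input_max * factor                     # 255 * 128 = 32640, not 32767
error_term = 1 - (output_max + 1) // (input_max + 1)  # 1 - 128 = -127
assert upscaled_max == output_max + error_term        # 32640 == 32767 - 127

# scaling back down divides by the same factor, so the round trip is exact,
# which is what test_convert_image_dtype_int_to_int_consistency checks
assert upscaled_max // factor == input_max            # 32640 // 128 == 255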

torchvision/transforms/functional.py

Lines changed: 34 additions & 36 deletions
@@ -113,9 +113,7 @@ def pil_to_tensor(pic):
     return img


-def convert_image_dtype(
-    image: torch.Tensor, dtype: torch.dtype = torch.float
-) -> torch.Tensor:
+def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
     """Convert a tensor image to the given ``dtype`` and scale the values accordingly

     Args:
@@ -125,28 +123,42 @@ def convert_image_dtype(
     Returns:
         (torch.Tensor): Converted image

+    .. note::
+
+        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
+        If converted back and forth, this mismatch has no effect.
+
     Raises:
-        TypeError: When trying to cast :class:`torch.float32` to :class:`torch.int32`
-            or :class:`torch.int64` as well as for trying to cast
-            :class:`torch.float64` to :class:`torch.int64`. These conversions are
-            unsafe since the floating point ``dtype`` cannot store consecutive XXX. which might lead to overflow errors
+        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
+            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
+            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
+            of the integer ``dtype``.
     """
-    def float_to_float(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
-        return image.to(dtype)
-
-    def float_to_int(image: torch.Tensor, dtype: torch.dtype, eps=1e-3) -> torch.Tensor:
-        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (image.dtype == torch.float64 and dtype == torch.int64):
-            msg = (f"The cast from {image.dtype} to {dtype} cannot be performed safely, "
-                   f"since {image.dtype} cannot ")
-            raise TypeError(msg)
-        return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)
+    if image.dtype == dtype:
+        return image

-    def int_to_float(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
-        max = torch.iinfo(image.dtype).max
-        image = image.to(dtype)
-        return image / max
+    if image.dtype.is_floating_point:
+        # float to float
+        if dtype.is_floating_point:
+            return image.to(dtype)
+
+        # float to int
+        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
+            image.dtype == torch.float64 and dtype == torch.int64
+        ):
+            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
+            raise RuntimeError(msg)

-    def int_to_int(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+        eps = 1e-3
+        return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)
+    else:
+        # int to float
+        if dtype.is_floating_point:
+            max = torch.iinfo(image.dtype).max
+            image = image.to(dtype)
+            return image / max
+
+        # int to int
         input_max = torch.iinfo(image.dtype).max
         output_max = torch.iinfo(dtype).max

@@ -157,21 +169,7 @@ def int_to_int(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
         else:
             factor = (output_max + 1) // (input_max + 1)
         image = image.to(dtype)
-        return (image + 1) * factor - 1
-
-    if image.dtype == dtype:
-        return image
-
-    if image.dtype.is_floating_point:
-        if dtype.is_floating_point:
-            return float_to_float(image, dtype)
-        else:
-            return float_to_int(image, dtype)
-    else:
-        if dtype.is_floating_point:
-            return int_to_float(image, dtype)
-        else:
-            return int_to_int(image, dtype)
+        return image * factor


 def to_pil_image(pic, mode=None):
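
With the rewritten convert_image_dtype above, the conversion rules can be exercised directly. A short usage sketch (illustration only, assuming a torchvision build that contains this commit):

import torch
from torchvision.transforms import functional as F

# int to float: values in [0, 255] are rescaled to [0.0, 1.0]
uint8_image = torch.tensor((0, 255), dtype=torch.uint8)
float_image = F.convert_image_dtype(uint8_image, torch.float32)  # tensor([0., 1.])

# float to int: values in [0.0, 1.0] are rescaled to [0, torch.iinfo(dtype).max]
restored = F.convert_image_dtype(float_image, torch.uint8)       # tensor([0, 255], dtype=torch.uint8)

# unsafe casts raise RuntimeError, since float32 cannot represent every
# integer up to torch.iinfo(torch.int32).max
try:
    F.convert_image_dtype(float_image, torch.int32)
except RuntimeError as exc:
    print(exc)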

torchvision/transforms/transforms.py

Lines changed: 11 additions & 0 deletions
@@ -121,7 +121,18 @@ class ConvertImageDtype(object):
     Args:
         dtype (torch.dtype): Desired data type of the output

+    .. note::
+
+        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
+        If converted back and forth, this mismatch has no effect.
+
+    Raises:
+        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
+            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
+            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
+            of the integer ``dtype``.
     """
+
     def __init__(self, dtype: torch.dtype) -> None:
         self.dtype = dtype

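ConvertImageDtype wraps the same conversion as a transform object. A minimal usage sketch (illustration only, assuming a torchvision build that contains this commit):

import torch
from torchvision import transforms

# a random uint8 image tensor with values in [0, 255]
image = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)

# rescale to float32 in [0.0, 1.0], e.g. as the last step of a pipeline
to_float = transforms.ConvertImageDtype(torch.float32)
float_image = to_float(image)

# converting back to uint8 restores the original values
to_uint8 = transforms.ConvertImageDtype(torch.uint8)
assert torch.equal(to_uint8(float_image), image)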