Skip to content

Commit adfb096

Browse files
committed
dirty progress
1 parent b6becf1 commit adfb096

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

torchvision/transforms/functional.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,22 @@ def convert_image_dtype(
124124
125125
Returns:
126126
(torch.Tensor): Converted image
127+
128+
Raises:
129+
TypeError: When trying to cast :class:`torch.float32` to :class:`torch.int32`
130+
or :class:`torch.int64` as well as for trying to cast
131+
:class:`torch.float64` to :class:`torch.int64`. These conversions are
132+
unsafe since the floating point ``dtype`` cannot store consecutive XXX. which might lead to overflow errors
127133
"""
128134
def float_to_float(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
129135
return image.to(dtype)
130136

131-
def float_to_int(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
132-
max = float(torch.iinfo(dtype).max)
133-
image = image * (max + 1.0)
134-
image = torch.clamp(image, max)
135-
return image.to(dtype)
137+
def float_to_int(image: torch.Tensor, dtype: torch.dtype, eps=1e-3) -> torch.Tensor:
138+
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (image.dtype == torch.float64 and dtype == torch.int64):
139+
msg = (f"The cast from {image.dtype} to {dtype} cannot be performed safely, "
140+
f"since {image.dtype} cannot ")
141+
raise TypeError(msg)
142+
return image.mul(torch.iinfo(dtype).max + 1 - eps).to(dtype)
136143

137144
def int_to_float(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
138145
max = torch.iinfo(image.dtype).max

0 commit comments

Comments
 (0)