File tree Expand file tree Collapse file tree 1 file changed +12
-5
lines changed Expand file tree Collapse file tree 1 file changed +12
-5
lines changed Original file line number Diff line number Diff line change @@ -124,15 +124,22 @@ def convert_image_dtype(
124
124
125
125
Returns:
126
126
(torch.Tensor): Converted image
127
+
128
+ Raises:
129
+ TypeError: When trying to cast :class:`torch.float32` to :class:`torch.int32`
130
+ or :class:`torch.int64` as well as for trying to cast
131
+ :class:`torch.float64` to :class:`torch.int64`. These conversions are
132
+ unsafe since the floating point ``dtype`` cannot store consecutive XXX. which might lead to overflow errors
127
133
"""
128
134
def float_to_float (image : torch .Tensor , dtype : torch .dtype ) -> torch .Tensor :
129
135
return image .to (dtype )
130
136
131
- def float_to_int (image : torch .Tensor , dtype : torch .dtype ) -> torch .Tensor :
132
- max = float (torch .iinfo (dtype ).max )
133
- image = image * (max + 1.0 )
134
- image = torch .clamp (image , max )
135
- return image .to (dtype )
137
+ def float_to_int (image : torch .Tensor , dtype : torch .dtype , eps = 1e-3 ) -> torch .Tensor :
138
+ if (image .dtype == torch .float32 and dtype in (torch .int32 , torch .int64 )) or (image .dtype == torch .float64 and dtype == torch .int64 ):
139
+ msg = (f"The cast from { image .dtype } to { dtype } cannot be performed safely, "
140
+ f"since { image .dtype } cannot " )
141
+ raise TypeError (msg )
142
+ return image .mul (torch .iinfo (dtype ).max + 1 - eps ).to (dtype )
136
143
137
144
def int_to_float (image : torch .Tensor , dtype : torch .dtype ) -> torch .Tensor :
138
145
max = torch .iinfo (image .dtype ).max
You can’t perform that action at this time.
0 commit comments