9
9
10
10
import unittest
11
11
12
- from common_utils import TransformsTester , get_tmp_dir
12
+ from common_utils import TransformsTester , get_tmp_dir , int_dtypes , float_dtypes
13
13
14
14
15
15
class Tester (TransformsTester ):
@@ -27,26 +27,26 @@ def _test_functional_op(self, func, fn_kwargs):
27
27
transformed_pil_img = f (pil_img , ** fn_kwargs )
28
28
self .compareTensorToPIL (transformed_tensor , transformed_pil_img )
29
29
30
- def _test_transform_vs_scripted (self , transform , s_transform , tensor ):
30
+ def _test_transform_vs_scripted (self , transform , s_transform , tensor , msg = None ):
31
31
torch .manual_seed (12 )
32
32
out1 = transform (tensor )
33
33
torch .manual_seed (12 )
34
34
out2 = s_transform (tensor )
35
- self .assertTrue (out1 .equal (out2 ))
35
+ self .assertTrue (out1 .equal (out2 ), msg = msg )
36
36
37
- def _test_transform_vs_scripted_on_batch (self , transform , s_transform , batch_tensors ):
37
+ def _test_transform_vs_scripted_on_batch (self , transform , s_transform , batch_tensors , msg = None ):
38
38
torch .manual_seed (12 )
39
39
transformed_batch = transform (batch_tensors )
40
40
41
41
for i in range (len (batch_tensors )):
42
42
img_tensor = batch_tensors [i , ...]
43
43
torch .manual_seed (12 )
44
44
transformed_img = transform (img_tensor )
45
- self .assertTrue (transformed_img .equal (transformed_batch [i , ...]))
45
+ self .assertTrue (transformed_img .equal (transformed_batch [i , ...]), msg = msg )
46
46
47
47
torch .manual_seed (12 )
48
48
s_transformed_batch = s_transform (batch_tensors )
49
- self .assertTrue (transformed_batch .equal (s_transformed_batch ))
49
+ self .assertTrue (transformed_batch .equal (s_transformed_batch ), msg = msg )
50
50
51
51
def _test_class_op (self , method , meth_kwargs = None , test_exact_match = True , ** match_kwargs ):
52
52
if meth_kwargs is None :
@@ -492,6 +492,32 @@ def test_random_erasing(self):
492
492
self ._test_transform_vs_scripted (fn , scripted_fn , tensor )
493
493
self ._test_transform_vs_scripted_on_batch (fn , scripted_fn , batch_tensors )
494
494
495
+ def test_convert_image_dtype (self ):
496
+ tensor , _ = self ._create_data (26 , 34 , device = self .device )
497
+ batch_tensors = torch .rand (4 , 3 , 44 , 56 , device = self .device )
498
+
499
+ for in_dtype in int_dtypes () + float_dtypes ():
500
+ in_tensor = tensor .to (in_dtype )
501
+ in_batch_tensors = batch_tensors .to (in_dtype )
502
+ for out_dtype in int_dtypes () + float_dtypes ():
503
+
504
+ fn = T .ConvertImageDtype (dtype = out_dtype )
505
+ scripted_fn = torch .jit .script (fn )
506
+
507
+ if (in_dtype == torch .float32 and out_dtype in (torch .int32 , torch .int64 )) or \
508
+ (in_dtype == torch .float64 and out_dtype == torch .int64 ):
509
+ with self .assertRaisesRegex (RuntimeError , r"cannot be performed safely" ):
510
+ self ._test_transform_vs_scripted (fn , scripted_fn , in_tensor )
511
+ with self .assertRaisesRegex (RuntimeError , r"cannot be performed safely" ):
512
+ self ._test_transform_vs_scripted_on_batch (fn , scripted_fn , in_batch_tensors )
513
+ continue
514
+
515
+ self ._test_transform_vs_scripted (fn , scripted_fn , in_tensor )
516
+ self ._test_transform_vs_scripted_on_batch (fn , scripted_fn , in_batch_tensors )
517
+
518
+ with get_tmp_dir () as tmp_dir :
519
+ scripted_fn .save (os .path .join (tmp_dir , "t_convert_dtype.pt" ))
520
+
495
521
496
522
@unittest .skipIf (not torch .cuda .is_available (), reason = "Skip if no CUDA device" )
497
523
class CUDATester (Tester ):
0 commit comments