@@ -552,24 +552,25 @@ def _test_affine_translations(self, tensor, pil_img, scripted_affine):
    def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
        # 4) Test rotation + translation + scale + shear
        test_configs = [
-            (45, [5, 6], 1.0, [0.0, 0.0]),
-            (33, (5, -4), 1.0, [0.0, 0.0]),
-            (45, [-5, 4], 1.2, [0.0, 0.0]),
-            (33, (-4, -8), 2.0, [0.0, 0.0]),
-            (85, (10, -10), 0.7, [0.0, 0.0]),
-            (0, [0, 0], 1.0, [35.0, ]),
-            (-25, [0, 0], 1.2, [0.0, 15.0]),
-            (-45, [-10, 0], 0.7, [2.0, 5.0]),
-            (-45, [-10, -10], 1.2, [4.0, 5.0]),
-            (-90, [0, 0], 1.0, [0.0, 0.0]),
+            (45.5, [5, 6], 1.0, [0.0, 0.0], None),
+            (33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]),
+            (45, [-5, 4], 1.2, [0.0, 0.0], (1, 2, 3)),
+            (33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]),
+            (85, (10, -10), 0.7, [0.0, 0.0], [1, ]),
+            (0, [0, 0], 1.0, [35.0, ], (2.0, )),
+            (-25, [0, 0], 1.2, [0.0, 15.0], None),
+            (-45, [-10, 0], 0.7, [2.0, 5.0], None),
+            (-45, [-10, -10], 1.2, [4.0, 5.0], None),
+            (-90, [0, 0], 1.0, [0.0, 0.0], None),
        ]
        for r in [NEAREST, ]:
-            for a, t, s, sh in test_configs:
-                out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r)
+            for a, t, s, sh, f in test_configs:
+                f_pil = int(f[0]) if f is not None and len(f) == 1 else f
+                out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r, fill=f_pil)
                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

                for fn in [F.affine, scripted_affine]:
-                    out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r).cpu()
+                    out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r, fill=f).cpu()

                    if out_tensor.dtype != torch.uint8:
                        out_tensor = out_tensor.to(torch.uint8)
@@ -582,7 +583,7 @@ def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
                        ratio_diff_pixels,
                        tol,
                        msg="{}: {}\n{} vs \n{}".format(
-                            (r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                            (r, a, t, s, sh, f), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                        )
                    )
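Note (illustrative sketch, not part of the diff): the new fill configurations above include one-element sequences such as [1, ] and (2.0, ). The test unwraps those to a plain int (f_pil) before calling the PIL code path, while the tensor code path receives the original value (or None) unchanged. A minimal standalone sketch of that pattern, using hypothetical stand-in inputs in place of the test fixtures:

    import numpy as np
    import torch
    from PIL import Image
    from torchvision.transforms import functional as F

    # hypothetical stand-ins for the test fixtures (the names mirror the fixtures above)
    np_img = np.random.randint(0, 256, (26, 26, 3), dtype=np.uint8)
    pil_img = Image.fromarray(np_img)
    tensor = torch.from_numpy(np_img).permute(2, 0, 1).contiguous()

    f = (2.0, )                                                 # one of the single-element fill configs
    f_pil = int(f[0]) if f is not None and len(f) == 1 else f   # PIL path gets a plain int (2)
    out_pil_img = F.affine(pil_img, angle=45.5, translate=[5, 6], scale=1.0, shear=[0.0, 0.0], fill=f_pil)
    out_tensor = F.affine(tensor, angle=45.5, translate=[5, 6], scale=1.0, shear=[0.0, 0.0], fill=f)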
@@ -643,35 +644,36 @@ def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
            for a in range(-180, 180, 17):
                for e in [True, False]:
                    for c in centers:
-
-                        out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c)
-                        out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
-                        for fn in [F.rotate, scripted_rotate]:
-                            out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c).cpu()
-
-                            if out_tensor.dtype != torch.uint8:
-                                out_tensor = out_tensor.to(torch.uint8)
-
-                            self.assertEqual(
-                                out_tensor.shape,
-                                out_pil_tensor.shape,
-                                msg="{}: {} vs {}".format(
-                                    (img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
-                                )
-                            )
-                            num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
-                            ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-                            # Tolerance : less than 3% of different pixels
-                            self.assertLess(
-                                ratio_diff_pixels,
-                                0.03,
-                                msg="{}: {}\n{} vs \n{}".format(
-                                    (img_size, r, dt, a, e, c),
+                        for f in [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )]:
+                            f_pil = int(f[0]) if f is not None and len(f) == 1 else f
+                            out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c, fill=f_pil)
+                            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+                            for fn in [F.rotate, scripted_rotate]:
+                                out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c, fill=f).cpu()
+
+                                if out_tensor.dtype != torch.uint8:
+                                    out_tensor = out_tensor.to(torch.uint8)
+
+                                self.assertEqual(
+                                    out_tensor.shape,
+                                    out_pil_tensor.shape,
+                                    msg="{}: {} vs {}".format(
+                                        (img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
+                                    ))
+
+                                num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                                ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                                # Tolerance : less than 3% of different pixels
+                                self.assertLess(
                                    ratio_diff_pixels,
-                                    out_tensor[0, :7, :7],
-                                    out_pil_tensor[0, :7, :7]
+                                    0.03,
+                                    msg="{}: {}\n{} vs \n{}".format(
+                                        (img_size, r, dt, a, e, c, f),
+                                        ratio_diff_pixels,
+                                        out_tensor[0, :7, :7],
+                                        out_pil_tensor[0, :7, :7]
+                                    )
                                )
-                            )

    def test_rotate(self):
        # Tests on square image
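Note (illustrative, not part of the diff): the mismatch check above counts channel-wise differences, divides by 3 to get the number of differing pixel positions, and normalises by the image area before comparing against the 3% tolerance (5% in the perspective test below). Assuming, for example, a 32x32 test image, 0.03 * 32 * 32 allows roughly 30 differing positions. A standalone restatement of the computation, assuming two uint8 tensors of shape (3, H, W):

    def diff_ratio(out_tensor, out_pil_tensor):
        # channel-wise mismatches, averaged over the 3 channels
        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
        # normalised by the image area H * W
        return num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]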
@@ -721,30 +723,33 @@ def test_rotate(self):

    def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
        dt = tensor.dtype
-        for r in [NEAREST, ]:
-            for spoints, epoints in test_configs:
-                out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
-                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+        for f in [None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1, ], (2.0, )]:
+            for r in [NEAREST, ]:
+                for spoints, epoints in test_configs:
+                    f_pil = int(f[0]) if f is not None and len(f) == 1 else f
+                    out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r,
+                                                fill=f_pil)
+                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

-                for fn in [F.perspective, scripted_transform]:
-                    out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu()
+                    for fn in [F.perspective, scripted_transform]:
+                        out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r, fill=f).cpu()

-                    if out_tensor.dtype != torch.uint8:
-                        out_tensor = out_tensor.to(torch.uint8)
+                        if out_tensor.dtype != torch.uint8:
+                            out_tensor = out_tensor.to(torch.uint8)

-                    num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
-                    ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-                    # Tolerance : less than 5% of different pixels
-                    self.assertLess(
-                        ratio_diff_pixels,
-                        0.05,
-                        msg="{}: {}\n{} vs \n{}".format(
-                            (r, dt, spoints, epoints),
+                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                        # Tolerance : less than 5% of different pixels
+                        self.assertLess(
                            ratio_diff_pixels,
-                            out_tensor[0, :7, :7],
-                            out_pil_tensor[0, :7, :7]
+                            0.05,
+                            msg="{}: {}\n{} vs \n{}".format(
+                                (f, r, dt, spoints, epoints),
+                                ratio_diff_pixels,
+                                out_tensor[0, :7, :7],
+                                out_pil_tensor[0, :7, :7]
+                            )
                        )
-                    )

    def test_perspective(self):
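Note (illustrative, not part of the diff): test_configs for the perspective test is built elsewhere in the file, so the corner points below are made-up example values. They only show the call shape exercised here, with the same single-element fill normalisation for the PIL path as in the affine and rotate tests:

    spoints = [[0, 0], [25, 0], [25, 25], [0, 25]]    # hypothetical source corners
    epoints = [[2, 1], [23, 3], [24, 24], [1, 22]]    # hypothetical warped corners
    f = [1, ]
    f_pil = int(f[0]) if f is not None and len(f) == 1 else f
    out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, fill=f_pil)
    out_tensor = F.perspective(tensor, startpoints=spoints, endpoints=epoints, fill=f)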