@@ -232,7 +232,7 @@ def affine_image_tensor(
232
232
scale : float ,
233
233
shear : List [float ],
234
234
interpolation : InterpolationMode = InterpolationMode .NEAREST ,
235
- fill : Optional [List [float ]] = None ,
235
+ fill : Optional [Union [ int , float , List [float ] ]] = None ,
236
236
center : Optional [List [float ]] = None ,
237
237
) -> torch .Tensor :
238
238
if img .numel () == 0 :
@@ -405,7 +405,9 @@ def affine_mask(
405
405
return output
406
406
407
407
408
- def _convert_fill_arg (fill : Optional [Union [int , float , Sequence [int ], Sequence [float ]]]) -> Optional [List [float ]]:
408
+ def _convert_fill_arg (
409
+ fill : Optional [Union [int , float , Sequence [int ], Sequence [float ]]]
410
+ ) -> Optional [Union [int , float , List [float ]]]:
409
411
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
410
412
# So, we can't reassign fill to 0
411
413
# if fill is None:
@@ -416,9 +418,6 @@ def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[f
416
418
# This cast does Sequence -> List[float] to please mypy and torch.jit.script
417
419
if not isinstance (fill , (int , float )):
418
420
fill = [float (v ) for v in list (fill )]
419
- else :
420
- # It is OK to cast int to float as later we use inpt.dtype
421
- fill = [float (fill )]
422
421
return fill
423
422
424
423
@@ -591,7 +590,23 @@ def rotate(
591
590
def pad_image_tensor (
592
591
img : torch .Tensor ,
593
592
padding : Union [int , List [int ]],
594
- fill : Optional [Union [int , float ]] = 0 ,
593
+ fill : Optional [Union [int , float , List [float ]]] = None ,
594
+ padding_mode : str = "constant" ,
595
+ ) -> torch .Tensor :
596
+ if fill is None :
597
+ # This is a JIT workaround
598
+ return _pad_with_scalar_fill (img , padding , fill = None , padding_mode = padding_mode )
599
+ elif isinstance (fill , (int , float )) or len (fill ) == 1 :
600
+ fill_number = fill [0 ] if isinstance (fill , list ) else fill
601
+ return _pad_with_scalar_fill (img , padding , fill = fill_number , padding_mode = padding_mode )
602
+ else :
603
+ return _pad_with_vector_fill (img , padding , fill = fill , padding_mode = padding_mode )
604
+
605
+
606
+ def _pad_with_scalar_fill (
607
+ img : torch .Tensor ,
608
+ padding : Union [int , List [int ]],
609
+ fill : Union [int , float , None ],
595
610
padding_mode : str = "constant" ,
596
611
) -> torch .Tensor :
597
612
num_channels , height , width = img .shape [- 3 :]
@@ -614,13 +629,13 @@ def pad_image_tensor(
614
629
def _pad_with_vector_fill (
615
630
img : torch .Tensor ,
616
631
padding : Union [int , List [int ]],
617
- fill : Sequence [float ] = [ 0.0 ],
632
+ fill : List [float ],
618
633
padding_mode : str = "constant" ,
619
634
) -> torch .Tensor :
620
635
if padding_mode != "constant" :
621
636
raise ValueError (f"Padding mode '{ padding_mode } ' is not supported if fill is not scalar" )
622
637
623
- output = pad_image_tensor (img , padding , fill = 0 , padding_mode = "constant" )
638
+ output = _pad_with_scalar_fill (img , padding , fill = 0 , padding_mode = "constant" )
624
639
left , right , top , bottom = _parse_pad_padding (padding )
625
640
fill = torch .tensor (fill , dtype = img .dtype , device = img .device ).view (- 1 , 1 , 1 )
626
641
@@ -639,8 +654,14 @@ def pad_mask(
639
654
mask : torch .Tensor ,
640
655
padding : Union [int , List [int ]],
641
656
padding_mode : str = "constant" ,
642
- fill : Optional [Union [int , float ]] = 0 ,
657
+ fill : Optional [Union [int , float , List [ float ]]] = None ,
643
658
) -> torch .Tensor :
659
+ if fill is None :
660
+ fill = 0
661
+
662
+ if isinstance (fill , list ):
663
+ raise ValueError ("Non-scalar fill value is not supported" )
664
+
644
665
if mask .ndim < 3 :
645
666
mask = mask .unsqueeze (0 )
646
667
needs_squeeze = True
@@ -693,10 +714,9 @@ def pad(
693
714
if not isinstance (padding , int ):
694
715
padding = list (padding )
695
716
696
- # TODO: PyTorch's pad supports only scalars on fill. So we need to overwrite the colour
697
- if isinstance (fill , (int , float )) or fill is None :
698
- return pad_image_tensor (inpt , padding , fill = fill , padding_mode = padding_mode )
699
- return _pad_with_vector_fill (inpt , padding , fill = fill , padding_mode = padding_mode )
717
+ fill = _convert_fill_arg (fill )
718
+
719
+ return pad_image_tensor (inpt , padding , fill = fill , padding_mode = padding_mode )
700
720
701
721
702
722
crop_image_tensor = _FT .crop
@@ -739,7 +759,7 @@ def perspective_image_tensor(
739
759
img : torch .Tensor ,
740
760
perspective_coeffs : List [float ],
741
761
interpolation : InterpolationMode = InterpolationMode .BILINEAR ,
742
- fill : Optional [List [float ]] = None ,
762
+ fill : Optional [Union [ int , float , List [float ] ]] = None ,
743
763
) -> torch .Tensor :
744
764
return _FT .perspective (img , perspective_coeffs , interpolation = interpolation .value , fill = fill )
745
765
@@ -878,7 +898,7 @@ def elastic_image_tensor(
878
898
img : torch .Tensor ,
879
899
displacement : torch .Tensor ,
880
900
interpolation : InterpolationMode = InterpolationMode .BILINEAR ,
881
- fill : Optional [List [float ]] = None ,
901
+ fill : Optional [Union [ int , float , List [float ] ]] = None ,
882
902
) -> torch .Tensor :
883
903
return _FT .elastic_transform (img , displacement , interpolation = interpolation .value , fill = fill )
884
904
0 commit comments