diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 81bae521b35..d76c90340a2 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -102,18 +102,20 @@ def affine_mask(): @register_kernel_info_from_sample_inputs_fn def rotate_image_tensor(): - for image, angle, expand, center, fill in itertools.product( + for image, angle, expand, center in itertools.product( make_images(), [-87, 15, 90], # angle [True, False], # expand [None, [12, 23]], # center - [None, [128], [12.0]], # fill ): if center is not None and expand: # Skip warning: The provided center argument is ignored if expand is True continue - yield ArgsKwargs(image, angle=angle, expand=expand, center=center, fill=fill) + yield ArgsKwargs(image, angle=angle, expand=expand, center=center, fill=None) + + for fill in [None, 128.0, 128, [12.0], [1.0, 2.0, 3.0]]: + yield ArgsKwargs(image, angle=23, expand=False, center=None, fill=fill) @register_kernel_info_from_sample_inputs_fn diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index a0ed43056ea..e7ca7463b79 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -467,7 +467,7 @@ def rotate_image_tensor( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Optional[List[float]] = None, + fill: Optional[Union[int, float, List[float]]] = None, center: Optional[List[float]] = None, ) -> torch.Tensor: num_channels, height, width = img.shape[-3:] diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index bdc02ae6bcc..23d1a4b0edf 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -475,7 +475,7 @@ def _assert_grid_transform_inputs( img: Tensor, matrix: Optional[List[float]], interpolation: str, - fill: Optional[List[float]], + fill: Optional[Union[int, float, List[float]]], supported_interpolation_modes: List[str], coeffs: Optional[List[float]] = None, ) -> None: @@ -499,7 +499,7 @@ def _assert_grid_transform_inputs( # Check fill num_channels = get_dimensions(img)[0] - if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels): + if fill is not None and isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels): msg = ( "The number of elements in 'fill' cannot broadcast to match the number of " "channels of the image ({} != {})" @@ -539,7 +539,9 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp return img -def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[List[float]]) -> Tensor: +def _apply_grid_transform( + img: Tensor, grid: Tensor, mode: str, fill: Optional[Union[int, float, List[float]]] +) -> Tensor: img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype]) @@ -559,8 +561,8 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L mask = img[:, -1:, :, :] # N * 1 * H * W img = img[:, :-1, :, :] # N * C * H * W mask = mask.expand_as(img) - len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1 - fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img) + fill_list, len_fill = (fill, len(fill)) if isinstance(fill, (tuple, list)) else ([float(fill)], 1) + fill_img = torch.tensor(fill_list, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img) if mode == "nearest": mask = mask < 0.5 img[mask] = fill_img[mask] @@ -648,7 +650,7 @@ def rotate( matrix: List[float], interpolation: str = "nearest", expand: bool = False, - fill: Optional[List[float]] = None, + fill: Optional[Union[int, float, List[float]]] = None, ) -> Tensor: _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) w, h = img.shape[-1], img.shape[-2]