diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 83e74e3730e..80be193c13d 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -378,6 +378,28 @@ def test__transform(self, padding, fill, padding_mode, mocker): fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode) + @pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}]) + def test__transform_image_mask(self, fill, mocker): + transform = transforms.Pad(1, fill=fill, padding_mode="constant") + + fn = mocker.patch("torchvision.prototype.transforms.functional.pad") + image = features.Image(torch.rand(3, 32, 32)) + mask = features.Mask(torch.randint(0, 5, size=(32, 32))) + inpt = [image, mask] + _ = transform(inpt) + + if isinstance(fill, int): + calls = [ + mocker.call(image, padding=1, fill=fill, padding_mode="constant"), + mocker.call(mask, padding=1, fill=0, padding_mode="constant"), + ] + else: + calls = [ + mocker.call(image, padding=1, fill=fill[type(image)], padding_mode="constant"), + mocker.call(mask, padding=1, fill=fill[type(mask)], padding_mode="constant"), + ] + fn.assert_has_calls(calls) + class TestRandomZoomOut: def test_assertions(self): @@ -400,7 +422,6 @@ def test__get_params(self, fill, side_range, mocker): params = transform._get_params(image) - assert params["fill"] == fill assert len(params["padding"]) == 4 assert 0 <= params["padding"][0] <= (side_range[1] - 1) * w assert 0 <= params["padding"][1] <= (side_range[1] - 1) * h @@ -426,7 +447,34 @@ def test__transform(self, fill, side_range, mocker): torch.rand(1) # random apply changes random state params = transform._get_params(inpt) - fn.assert_called_once_with(inpt, **params) + fn.assert_called_once_with(inpt, **params, fill=fill) + + @pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}]) + def test__transform_image_mask(self, fill, mocker): + transform = transforms.RandomZoomOut(fill=fill, p=1.0) + + fn = mocker.patch("torchvision.prototype.transforms.functional.pad") + image = features.Image(torch.rand(3, 32, 32)) + mask = features.Mask(torch.randint(0, 5, size=(32, 32))) + inpt = [image, mask] + + torch.manual_seed(12) + _ = transform(inpt) + torch.manual_seed(12) + torch.rand(1) # random apply changes random state + params = transform._get_params(inpt) + + if isinstance(fill, int): + calls = [ + mocker.call(image, **params, fill=fill), + mocker.call(mask, **params, fill=0), + ] + else: + calls = [ + mocker.call(image, **params, fill=fill[type(image)]), + mocker.call(mask, **params, fill=fill[type(mask)]), + ] + fn.assert_has_calls(calls) class TestRandomRotation: diff --git a/torchvision/prototype/features/_mask.py b/torchvision/prototype/features/_mask.py index b78eeb1cf12..022915798e1 100644 --- a/torchvision/prototype/features/_mask.py +++ b/torchvision/prototype/features/_mask.py @@ -58,7 +58,14 @@ def pad( if not isinstance(padding, int): padding = list(padding) - output = self._F.pad_mask(self, padding, padding_mode=padding_mode) + if isinstance(fill, (int, float)) or fill is None: + if fill is None: + fill = 0 + output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill) + else: + # Let's raise an error for vector fill on masks + raise ValueError("Non-scalar fill value is not supported") + return Mask.new_like(self, output) def rotate( diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 265eed752a6..d351fd66f38 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -1,7 +1,8 @@ import math import numbers import warnings -from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union +from collections import defaultdict +from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type, Union import PIL.Image import torch @@ -16,6 +17,7 @@ DType = Union[torch.Tensor, PIL.Image.Image, features._Feature] +FillType = Union[int, float, Sequence[int], Sequence[float]] class RandomHorizontalFlip(_RandomApplyTransform): @@ -196,9 +198,21 @@ def forward(self, *inputs: Any) -> Any: return super().forward(*inputs) -def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) -> None: - if not isinstance(fill, (numbers.Number, tuple, list)): - raise TypeError("Got inappropriate fill arg") +def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: + if isinstance(fill, dict): + for key, value in fill.items(): + # Check key for type + _check_fill_arg(value) + else: + if not isinstance(fill, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate fill arg") + + +def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]: + if isinstance(fill, dict): + return fill + else: + return defaultdict(lambda: fill, {features.Mask: 0}) # type: ignore[arg-type, return-value] def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: @@ -220,7 +234,7 @@ class Pad(Transform): def __init__( self, padding: Union[int, Sequence[int]], - fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + fill: Union[FillType, Dict[Type, FillType]] = 0, padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", ) -> None: super().__init__() @@ -230,24 +244,25 @@ def __init__( _check_padding_mode_arg(padding_mode) self.padding = padding - self.fill = fill + self.fill = _setup_fill_arg(fill) self.padding_mode = padding_mode def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.pad(inpt, padding=self.padding, fill=self.fill, padding_mode=self.padding_mode) + fill = self.fill[type(inpt)] + return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) class RandomZoomOut(_RandomApplyTransform): def __init__( self, - fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + fill: Union[FillType, Dict[Type, FillType]] = 0, side_range: Sequence[float] = (1.0, 4.0), p: float = 0.5, ) -> None: super().__init__(p=p) _check_fill_arg(fill) - self.fill = fill + self.fill = _setup_fill_arg(fill) _check_sequence_input(side_range, "side_range", req_sizes=(2,)) @@ -256,7 +271,7 @@ def __init__( raise ValueError(f"Invalid canvas side range provided {side_range}.") def _get_params(self, sample: Any) -> Dict[str, Any]: - orig_c, orig_h, orig_w = query_chw(sample) + _, orig_h, orig_w = query_chw(sample) r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) canvas_width = int(orig_w * r) @@ -269,10 +284,11 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: bottom = canvas_height - (top + orig_h) padding = [left, top, right, bottom] - return dict(padding=padding, fill=self.fill) + return dict(padding=padding) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.pad(inpt, **params) + fill = self.fill[type(inpt)] + return F.pad(inpt, **params, fill=fill) class RandomRotation(Transform): diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 554236d27dc..a0ed43056ea 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -635,14 +635,19 @@ def _pad_with_vector_fill( return output -def pad_mask(mask: torch.Tensor, padding: Union[int, List[int]], padding_mode: str = "constant") -> torch.Tensor: +def pad_mask( + mask: torch.Tensor, + padding: Union[int, List[int]], + padding_mode: str = "constant", + fill: Optional[Union[int, float]] = 0, +) -> torch.Tensor: if mask.ndim < 3: mask = mask.unsqueeze(0) needs_squeeze = True else: needs_squeeze = False - output = pad_image_tensor(img=mask, padding=padding, fill=0, padding_mode=padding_mode) + output = pad_image_tensor(img=mask, padding=padding, fill=fill, padding_mode=padding_mode) if needs_squeeze: output = output.squeeze(0)