Skip to content

[proto] Argument fill can accept dict of base types #6586

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,28 @@ def test__transform(self, padding, fill, padding_mode, mocker):

fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)

@pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
def test__transform_image_mask(self, fill, mocker):
    """Pad must resolve the fill per input type: a dict fill maps each feature
    type to its own value; a scalar fill applies to images while masks fall
    back to the default fill of 0.
    """
    pad_mock = mocker.patch("torchvision.prototype.transforms.functional.pad")

    transform = transforms.Pad(1, fill=fill, padding_mode="constant")
    image = features.Image(torch.rand(3, 32, 32))
    mask = features.Mask(torch.randint(0, 5, size=(32, 32)))

    transform([image, mask])

    if isinstance(fill, dict):
        # Per-type mapping: each feature gets the fill registered for its type.
        expected = [(image, fill[features.Image]), (mask, fill[features.Mask])]
    else:
        # Scalar fill: images use it directly, masks default to 0.
        expected = [(image, fill), (mask, 0)]

    pad_mock.assert_has_calls(
        [mocker.call(feature, padding=1, fill=value, padding_mode="constant") for feature, value in expected]
    )


class TestRandomZoomOut:
def test_assertions(self):
Expand All @@ -400,7 +422,6 @@ def test__get_params(self, fill, side_range, mocker):

params = transform._get_params(image)

assert params["fill"] == fill
assert len(params["padding"]) == 4
assert 0 <= params["padding"][0] <= (side_range[1] - 1) * w
assert 0 <= params["padding"][1] <= (side_range[1] - 1) * h
Expand All @@ -426,7 +447,34 @@ def test__transform(self, fill, side_range, mocker):
torch.rand(1) # random apply changes random state
params = transform._get_params(inpt)

fn.assert_called_once_with(inpt, **params)
fn.assert_called_once_with(inpt, **params, fill=fill)

@pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
def test__transform_image_mask(self, fill, mocker):
    """RandomZoomOut must forward a per-type fill to ``pad``: dict fills are
    looked up by feature type; a scalar fill is used for images while masks
    default to 0.
    """
    pad_mock = mocker.patch("torchvision.prototype.transforms.functional.pad")

    transform = transforms.RandomZoomOut(fill=fill, p=1.0)
    image = features.Image(torch.rand(3, 32, 32))
    mask = features.Mask(torch.randint(0, 5, size=(32, 32)))
    inpt = [image, mask]

    torch.manual_seed(12)
    transform(inpt)

    # Re-derive the padding the transform sampled by replaying the RNG,
    # including the extra rand(1) consumed by the random-apply coin flip.
    torch.manual_seed(12)
    torch.rand(1)
    params = transform._get_params(inpt)

    if isinstance(fill, dict):
        expected = [(image, fill[features.Image]), (mask, fill[features.Mask])]
    else:
        expected = [(image, fill), (mask, 0)]

    pad_mock.assert_has_calls([mocker.call(feature, **params, fill=value) for feature, value in expected])


class TestRandomRotation:
Expand Down
9 changes: 8 additions & 1 deletion torchvision/prototype/features/_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,14 @@ def pad(
if not isinstance(padding, int):
padding = list(padding)

output = self._F.pad_mask(self, padding, padding_mode=padding_mode)
if isinstance(fill, (int, float)) or fill is None:
if fill is None:
fill = 0
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
else:
# Let's raise an error for vector fill on masks
raise ValueError("Non-scalar fill value is not supported")

return Mask.new_like(self, output)

def rotate(
Expand Down
40 changes: 28 additions & 12 deletions torchvision/prototype/transforms/_geometry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import math
import numbers
import warnings
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union
from collections import defaultdict
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type, Union

import PIL.Image
import torch
Expand All @@ -16,6 +17,7 @@


DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
FillType = Union[int, float, Sequence[int], Sequence[float]]


class RandomHorizontalFlip(_RandomApplyTransform):
Expand Down Expand Up @@ -196,9 +198,21 @@ def forward(self, *inputs: Any) -> Any:
return super().forward(*inputs)


def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) -> None:
if not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")
def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
if isinstance(fill, dict):
for key, value in fill.items():
# Check key for type
_check_fill_arg(value)
else:
if not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")


def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
if isinstance(fill, dict):
return fill
else:
return defaultdict(lambda: fill, {features.Mask: 0}) # type: ignore[arg-type, return-value]


def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
Expand All @@ -220,7 +234,7 @@ class Pad(Transform):
def __init__(
self,
padding: Union[int, Sequence[int]],
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
Expand All @@ -230,24 +244,25 @@ def __init__(
_check_padding_mode_arg(padding_mode)

self.padding = padding
self.fill = fill
self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.pad(inpt, padding=self.padding, fill=self.fill, padding_mode=self.padding_mode)
fill = self.fill[type(inpt)]
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)


class RandomZoomOut(_RandomApplyTransform):
def __init__(
self,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
side_range: Sequence[float] = (1.0, 4.0),
p: float = 0.5,
) -> None:
super().__init__(p=p)

_check_fill_arg(fill)
self.fill = fill
self.fill = _setup_fill_arg(fill)

_check_sequence_input(side_range, "side_range", req_sizes=(2,))

Expand All @@ -256,7 +271,7 @@ def __init__(
raise ValueError(f"Invalid canvas side range provided {side_range}.")

def _get_params(self, sample: Any) -> Dict[str, Any]:
orig_c, orig_h, orig_w = query_chw(sample)
_, orig_h, orig_w = query_chw(sample)

r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
Expand All @@ -269,10 +284,11 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
bottom = canvas_height - (top + orig_h)
padding = [left, top, right, bottom]

return dict(padding=padding, fill=self.fill)
return dict(padding=padding)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.pad(inpt, **params)
fill = self.fill[type(inpt)]
return F.pad(inpt, **params, fill=fill)


class RandomRotation(Transform):
Expand Down
9 changes: 7 additions & 2 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,14 +635,19 @@ def _pad_with_vector_fill(
return output


def pad_mask(mask: torch.Tensor, padding: Union[int, List[int]], padding_mode: str = "constant") -> torch.Tensor:
def pad_mask(
mask: torch.Tensor,
padding: Union[int, List[int]],
padding_mode: str = "constant",
fill: Optional[Union[int, float]] = 0,
) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False

output = pad_image_tensor(img=mask, padding=padding, fill=0, padding_mode=padding_mode)
output = pad_image_tensor(img=mask, padding=padding, fill=fill, padding_mode=padding_mode)

if needs_squeeze:
output = output.squeeze(0)
Expand Down