Skip to content

Commit 93e1cfb

Browse files
committed
Added RandomCrop transform and tests
1 parent df6918c commit 93e1cfb

File tree

4 files changed

+180
-13
lines changed

4 files changed

+180
-13
lines changed

test/test_prototype_transforms.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ class TestSmoke:
8282
transforms.RandomZoomOut(),
8383
transforms.RandomRotation(degrees=(-45, 45)),
8484
transforms.RandomAffine(degrees=(-45, 45)),
85+
transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True),
8586
)
8687
def test_common(self, transform, input):
8788
transform(input)
@@ -566,3 +567,80 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker
566567
params = transform._get_params(inpt)
567568

568569
fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center)
570+
571+
572+
class TestRandomCrop:
    """Unit tests for ``transforms.RandomCrop`` (prototype transforms)."""

    def test_assertions(self):
        # Invalid constructor arguments must be rejected eagerly, at __init__ time.
        with pytest.raises(ValueError, match="Please provide only two dimensions"):
            transforms.RandomCrop([10, 12, 14])

        with pytest.raises(TypeError, match="Got inappropriate padding arg"):
            transforms.RandomCrop([10, 12], padding="abc")

        with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"):
            transforms.RandomCrop([10, 12], padding=[-0.7, 0, 0.7])

        with pytest.raises(TypeError, match="Got inappropriate fill arg"):
            transforms.RandomCrop([10, 12], padding=1, fill="abc")

        with pytest.raises(ValueError, match="Padding mode should be either"):
            transforms.RandomCrop([10, 12], padding=1, padding_mode="abc")

    def test__get_params(self):
        # The sampled crop origin must leave room for the full output size,
        # and the returned height/width must equal the requested crop size.
        image = features.Image(torch.rand(1, 3, 32, 32))
        h, w = image.shape[-2:]

        transform = transforms.RandomCrop([10, 10])
        params = transform._get_params(image)

        assert 0 <= params["top"] <= h - transform.size[0] + 1
        assert 0 <= params["left"] <= w - transform.size[1] + 1
        assert params["height"] == 10
        assert params["width"] == 10

    @pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]])
    @pytest.mark.parametrize("pad_if_needed", [False, True])
    @pytest.mark.parametrize("fill", [False, True])
    @pytest.mark.parametrize("padding_mode", ["constant", "edge"])
    def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker):
        output_size = [10, 12]
        transform = transforms.RandomCrop(
            output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode
        )

        inpt = features.Image(torch.rand(1, 3, 32, 32))
        # Mock out the padded image that F.pad is expected to return, with the
        # image_size the real pad would produce for the given padding arg.
        expected = mocker.MagicMock(spec=features.Image)
        expected.num_channels = 3
        if isinstance(padding, int):
            expected.image_size = (inpt.image_size[0] + padding, inpt.image_size[1] + padding)
        elif isinstance(padding, list):
            # even indices pad height, odd indices pad width — TODO confirm against F.pad convention
            expected.image_size = (
                inpt.image_size[0] + sum(padding[0::2]),
                inpt.image_size[1] + sum(padding[1::2]),
            )
        else:
            expected.image_size = inpt.image_size
        _ = mocker.patch("torchvision.prototype.transforms.functional.pad", return_value=expected)
        fn_crop = mocker.patch("torchvision.prototype.transforms.functional.crop")

        # vfdev-5, Feature Request: let's store params as Transform attribute
        # This could be also helpful for users
        # Re-seeding reproduces the exact params the transform sampled internally.
        torch.manual_seed(12)
        _ = transform(inpt)
        torch.manual_seed(12)
        if padding is None and not pad_if_needed:
            # No padding at all: crop is taken directly from the input.
            params = transform._get_params(inpt)
            fn_crop.assert_called_once_with(
                inpt, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
            )
        elif not pad_if_needed:
            # Explicit padding only: crop is taken from the (mocked) padded image.
            params = transform._get_params(expected)
            fn_crop.assert_called_once_with(
                expected, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
            )
        elif padding is None:
            # vfdev-5: I do not know how to mock and test this case
            pass
        else:
            # vfdev-5: I do not know how to mock and test this case
            pass

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Resize,
1111
CenterCrop,
1212
RandomResizedCrop,
13+
RandomCrop,
1314
FiveCrop,
1415
TenCrop,
1516
BatchMultiCrop,

torchvision/prototype/transforms/_geometry.py

Lines changed: 89 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ def __init__(
3535
antialias: Optional[bool] = None,
3636
) -> None:
3737
super().__init__()
38-
self.size = [size] if isinstance(size, int) else list(size)
38+
39+
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
3940
self.interpolation = interpolation
4041
self.max_size = max_size
4142
self.antialias = antialias
@@ -80,7 +81,6 @@ def __init__(
8081
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
8182
warnings.warn("Scale and ratio should be of kind (min, max)")
8283

83-
self.size = size
8484
self.scale = scale
8585
self.ratio = ratio
8686
self.interpolation = interpolation
@@ -225,6 +225,19 @@ def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) ->
225225
raise TypeError("Got inappropriate fill arg")
226226

227227

228+
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
229+
if not isinstance(padding, (numbers.Number, tuple, list)):
230+
raise TypeError("Got inappropriate padding arg")
231+
232+
if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
233+
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
234+
235+
236+
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
237+
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
238+
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
239+
240+
228241
class Pad(Transform):
229242
def __init__(
230243
self,
@@ -233,18 +246,10 @@ def __init__(
233246
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
234247
) -> None:
235248
super().__init__()
236-
if not isinstance(padding, (numbers.Number, tuple, list)):
237-
raise TypeError("Got inappropriate padding arg")
238-
239-
if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
240-
raise ValueError(
241-
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
242-
)
243249

250+
_check_padding_arg(padding)
244251
_check_fill_arg(fill)
245-
246-
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
247-
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
252+
_check_padding_mode_arg(padding_mode)
248253

249254
self.padding = padding
250255
self.fill = fill
@@ -416,3 +421,75 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
416421
fill=self.fill,
417422
center=self.center,
418423
)
424+
425+
426+
class RandomCrop(Transform):
    """Crop a randomly located region of size ``size`` out of every input in the sample.

    Optionally pads every input first (``padding``) and/or pads just enough on the
    fly to make the requested crop possible (``pad_if_needed``).

    Args:
        size: Expected output size ``(h, w)`` of the crop; an int yields a square crop.
        padding: Optional padding applied to every input before cropping.
        pad_if_needed: If True, pad inputs smaller than ``size`` so the crop fits.
        fill: Fill value used for constant padding.
        padding_mode: One of ``"constant"``, ``"edge"``, ``"reflect"``, ``"symmetric"``.
    """

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        padding: Optional[Union[int, Sequence[int]]] = None,
        pad_if_needed: bool = False,
        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
        padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
    ) -> None:
        super().__init__()

        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

        if padding is not None:
            _check_padding_arg(padding)

        # fill and padding_mode are only used when some padding actually happens,
        # so they are only validated in that case.
        if (padding is not None) or pad_if_needed:
            _check_padding_mode_arg(padding_mode)
            _check_fill_arg(fill)

        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    def _get_params(self, sample: Any) -> Dict[str, Any]:
        """Sample a random top-left corner for a crop of ``self.size``.

        Raises:
            ValueError: If the (possibly already padded) image is too small for the crop.
        """
        image = query_image(sample)
        _, height, width = get_image_dimensions(image)
        output_height, output_width = self.size

        if height + 1 < output_height or width + 1 < output_width:
            # fixed typo in the message: "larger then" -> "larger than"
            raise ValueError(
                f"Required crop size {(output_height, output_width)} is larger than input image size {(height, width)}"
            )

        if width == output_width and height == output_height:
            # Exact fit: the only valid crop is the whole image.
            return dict(top=0, left=0, height=height, width=width)

        top = torch.randint(0, height - output_height + 1, size=(1,)).item()
        left = torch.randint(0, width - output_width + 1, size=(1,)).item()
        return dict(top=top, left=left, height=output_height, width=output_width)

    def _forward(self, flat_inputs: List[Any]) -> List[Any]:
        """Pad (if configured) then crop every flattened input with shared params."""
        if self.padding is not None:
            flat_inputs = [F.pad(flat_input, self.padding, self.fill, self.padding_mode) for flat_input in flat_inputs]

        # Dimensions are re-queried after the explicit padding above.
        image = query_image(flat_inputs)
        _, height, width = get_image_dimensions(image)

        # pad the width if needed
        if self.pad_if_needed and width < self.size[1]:
            padding = [self.size[1] - width, 0]
            flat_inputs = [F.pad(flat_input, padding, self.fill, self.padding_mode) for flat_input in flat_inputs]
        # pad the height if needed
        if self.pad_if_needed and height < self.size[0]:
            padding = [0, self.size[0] - height]
            flat_inputs = [F.pad(flat_input, padding, self.fill, self.padding_mode) for flat_input in flat_inputs]

        # One set of crop params is sampled and applied to every input so that
        # e.g. an image and its mask stay aligned.
        params = self._get_params(flat_inputs)

        return [F.crop(flat_input, **params) for flat_input in flat_inputs]

    def forward(self, *inputs: Any) -> Any:
        from torch.utils._pytree import tree_flatten, tree_unflatten

        sample = inputs if len(inputs) > 1 else inputs[0]

        flat_inputs, spec = tree_flatten(sample)
        out_flat_inputs = self._forward(flat_inputs)
        return tree_unflatten(out_flat_inputs, spec)

torchvision/prototype/transforms/_utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,28 @@
22

33
import PIL.Image
44
import torch
5+
from torch.utils._pytree import tree_flatten
56
from torchvision.prototype import features
67
from torchvision.prototype.utils._internal import query_recursively
78

89
from .functional._meta import get_dimensions_image_tensor, get_dimensions_image_pil
910

1011

1112
def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
    """Return the first image-like leaf (plain tensor, PIL image, or Image feature) in ``sample``."""
    leaves, _ = tree_flatten(sample)
    for leaf in leaves:
        if type(leaf) == torch.Tensor or isinstance(leaf, (PIL.Image.Image, features.Image)):
            return leaf

    raise TypeError("No image was found in the sample")
19+
20+
21+
# vfdev-5: let's use tree_flatten instead of query_recursively and internal fn to make the code simpler
22+
def query_image_(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
1223
def fn(
1324
id: Tuple[Any, ...], input: Any
1425
) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
15-
if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image):
26+
if type(input) == torch.Tensor or isinstance(input, (PIL.Image.Image, features.Image)):
1627
return id, input
1728

1829
return None

0 commit comments

Comments
 (0)