From d6ba54323148b1b0412c5620273b4c84458a5b13 Mon Sep 17 00:00:00 2001
From: Federico Pozzi
Date: Wed, 16 Mar 2022 23:11:03 +0100
Subject: [PATCH] refactor: port RandomVerticalFlip to prototype API (#5524)

---
 test/test_prototype_transforms.py             | 53 +++++++++++++++++++
 torchvision/prototype/transforms/__init__.py  |  1 +
 torchvision/prototype/transforms/_geometry.py | 30 +++++++++++
 .../transforms/functional/__init__.py         |  2 +
 .../transforms/functional/_geometry.py        | 20 +++++++
 5 files changed, 106 insertions(+)

diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index b6085bb1c71..dc3de480d1f 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -243,3 +243,56 @@ def test_features_bounding_box(self, p):
         assert_equal(expected, actual)
         assert actual.format == expected.format
         assert actual.image_size == expected.image_size
+
+
+@pytest.mark.parametrize("p", [0.0, 1.0])
+class TestRandomVerticalFlip:
+    def input_expected_image_tensor(self, p, dtype=torch.float32):
+        input = torch.tensor([[[1, 1], [0, 0]], [[1, 1], [0, 0]]], dtype=dtype)
+        expected = torch.tensor([[[0, 0], [1, 1]], [[0, 0], [1, 1]]], dtype=dtype)
+
+        return input, (expected if p == 1 else input)
+
+    def test_simple_tensor(self, p):
+        input, expected = self.input_expected_image_tensor(p)
+        transform = transforms.RandomVerticalFlip(p=p)
+
+        actual = transform(input)
+
+        assert_equal(expected, actual)
+
+    def test_pil_image(self, p):
+        input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
+        transform = transforms.RandomVerticalFlip(p=p)
+
+        actual = transform(to_pil_image(input))
+
+        assert_equal(expected, pil_to_tensor(actual))
+
+    def test_features_image(self, p):
+        input, expected = self.input_expected_image_tensor(p)
+        transform = transforms.RandomVerticalFlip(p=p)
+
+        actual = transform(features.Image(input))
+
+        assert_equal(features.Image(expected), actual)
+
+    def test_features_segmentation_mask(self, p):
+        input, expected = self.input_expected_image_tensor(p)
+        transform = transforms.RandomVerticalFlip(p=p)
+
+        actual = transform(features.SegmentationMask(input))
+
+        assert_equal(features.SegmentationMask(expected), actual)
+
+    def test_features_bounding_box(self, p):
+        input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
+        transform = transforms.RandomVerticalFlip(p=p)
+
+        actual = transform(input)
+
+        expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input
+        expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
+        assert_equal(expected, actual)
+        assert actual.format == expected.format
+        assert actual.image_size == expected.image_size
diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index 1cb53f2a885..81e914e8383 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -15,6 +15,7 @@
     TenCrop,
     BatchMultiCrop,
     RandomHorizontalFlip,
+    RandomVerticalFlip,
     Pad,
     RandomZoomOut,
 )
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 061f12cd446..ec19c7afc72 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -45,6 +45,36 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
             return input
 
 
+class RandomVerticalFlip(Transform):
+    def __init__(self, p: float = 0.5) -> None:
+        super().__init__()
+        self.p = p
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if torch.rand(1) >= self.p:
+            return sample
+
+        return super().forward(sample)
+
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if isinstance(input, features.Image):
+            output = F.vertical_flip_image_tensor(input)
+            return features.Image.new_like(input, output)
+        elif isinstance(input, features.SegmentationMask):
+            output = F.vertical_flip_segmentation_mask(input)
+            return features.SegmentationMask.new_like(input, output)
+        elif isinstance(input, features.BoundingBox):
+            output = F.vertical_flip_bounding_box(input, format=input.format, image_size=input.image_size)
+            return features.BoundingBox.new_like(input, output)
+        elif isinstance(input, PIL.Image.Image):
+            return F.vertical_flip_image_pil(input)
+        elif is_simple_tensor(input):
+            return F.vertical_flip_image_tensor(input)
+        else:
+            return input
+
+
 class Resize(Transform):
     def __init__(
         self,
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index 6a317e87182..469768ba9c2 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -63,6 +63,8 @@
     perspective_image_pil,
     vertical_flip_image_tensor,
     vertical_flip_image_pil,
+    vertical_flip_bounding_box,
+    vertical_flip_segmentation_mask,
     five_crop_image_tensor,
     five_crop_image_pil,
     ten_crop_image_tensor,
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index b35db660438..8718c381525 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -81,6 +81,26 @@ def resize_bounding_box(bounding_box: torch.Tensor, size: List[int], image_size:
 vertical_flip_image_pil = _FP.vflip
 
 
+def vertical_flip_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor:
+    return vertical_flip_image_tensor(segmentation_mask)
+
+
+def vertical_flip_bounding_box(
+    bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
+) -> torch.Tensor:
+    shape = bounding_box.shape
+
+    bounding_box = convert_bounding_box_format(
+        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    ).view(-1, 4)
+
+    bounding_box[:, [1, 3]] = image_size[0] - bounding_box[:, [3, 1]]
+
+    return convert_bounding_box_format(
+        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+    ).view(shape)
+
+
 def _affine_parse_args(
     angle: float,
     translate: List[float],
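
Reviewer note: a minimal usage sketch of the transform added by this patch, distilled
from the tests above. It assumes the prototype layout shown in the diff
(torchvision.prototype.transforms and torchvision.prototype.features); the prototype
API is unstable, so these import paths may change.

import torch
from torchvision.prototype import features, transforms

# p=1.0 makes the flip deterministic, mirroring the parametrized tests above.
transform = transforms.RandomVerticalFlip(p=1.0)

# Plain tensor image (C, H, W): rows are reversed top to bottom.
image = torch.tensor([[[1, 1], [0, 0]]], dtype=torch.float32)
assert torch.equal(transform(image), torch.tensor([[[0, 0], [1, 1]]], dtype=torch.float32))

# Bounding box: dispatch goes through vertical_flip_bounding_box, so only the
# y-coordinates change while the box width and height are preserved.
bbox = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
assert transform(bbox).tolist() == [0, 5, 5, 10]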
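
Reviewer note on vertical_flip_bounding_box: for XYXY boxes the kernel maps
(y1, y2) to (H - y2, H - y1); swapping the columns on the right-hand side is what
keeps y1 <= y2 after the flip. Below is a standalone sketch of the same arithmetic
in plain torch; the helper name vflip_xyxy is hypothetical, for illustration only.

import torch

def vflip_xyxy(boxes: torch.Tensor, image_height: int) -> torch.Tensor:
    # boxes has shape (N, 4) in XYXY order; clone so the caller's tensor is
    # left untouched.
    boxes = boxes.clone()
    # new y1 = H - old y2, new y2 = H - old y1: note the swapped columns.
    boxes[:, [1, 3]] = image_height - boxes[:, [3, 1]]
    return boxes

# Matches the expected value in test_features_bounding_box: a 5x5 box at the
# top of a 10-pixel-tall image moves to the bottom.
print(vflip_xyxy(torch.tensor([[0, 0, 5, 5]]), image_height=10))
# tensor([[ 0,  5,  5, 10]])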