diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
new file mode 100644
index 00000000000..53776e1a8a4
--- /dev/null
+++ b/test/test_prototype_transforms_functional.py
@@ -0,0 +1,197 @@
+import functools
+import itertools
+
+import pytest
+import torch.testing
+import torchvision.prototype.transforms.functional as F
+from torch import jit
+from torchvision.prototype import features
+
+make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
+
+
+def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32):
+    size = size or torch.randint(16, 33, (2,)).tolist()
+
+    if isinstance(color_space, str):
+        color_space = features.ColorSpace[color_space]
+    num_channels = {
+        features.ColorSpace.GRAYSCALE: 1,
+        features.ColorSpace.RGB: 3,
+    }[color_space]
+
+    shape = (*extra_dims, num_channels, *size)
+    if dtype.is_floating_point:
+        data = torch.rand(shape, dtype=dtype)
+    else:
+        data = torch.randint(0, torch.iinfo(dtype).max, shape, dtype=dtype)
+    return features.Image(data, color_space=color_space)
+
+
+make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAYSCALE)
+make_rgb_image = functools.partial(make_image, color_space=features.ColorSpace.RGB)
+
+
+def make_images(
+    sizes=((16, 16), (7, 33), (31, 9)),
+    color_spaces=(features.ColorSpace.GRAYSCALE, features.ColorSpace.RGB),
+    dtypes=(torch.float32, torch.uint8),
+    extra_dims=((4,), (2, 3)),
+):
+    for size, color_space, dtype in itertools.product(sizes, color_spaces, dtypes):
+        yield make_image(size, color_space=color_space, dtype=dtype)
+
+    for color_space, extra_dims_ in itertools.product(color_spaces, extra_dims):
+        yield make_image(color_space=color_space, extra_dims=extra_dims_)
+
+
+def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
+    low, high = torch.broadcast_tensors(
+        *[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))]
+    )
+    return torch.stack(
+        [
+            torch.randint(low_scalar, high_scalar, (), **kwargs)
+            for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist())
+        ]
+    ).reshape(low.shape)
+
+
+def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64):
+    if isinstance(format, str):
+        format = features.BoundingBoxFormat[format]
+
+    height, width = image_size
+
+    if format == features.BoundingBoxFormat.XYXY:
+        x1 = torch.randint(0, width // 2, extra_dims)
+        y1 = torch.randint(0, height // 2, extra_dims)
+        x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1
+        y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1
+        parts = (x1, y1, x2, y2)
+    elif format == features.BoundingBoxFormat.XYWH:
+        x = torch.randint(0, width // 2, extra_dims)
+        y = torch.randint(0, height // 2, extra_dims)
+        w = randint_with_tensor_bounds(1, width - x)
+        h = randint_with_tensor_bounds(1, height - y)
+        parts = (x, y, w, h)
+    elif format == features.BoundingBoxFormat.CXCYWH:
+        cx = torch.randint(1, width - 1, ())
+        cy = torch.randint(1, height - 1, ())
+        w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
+        h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
+        parts = (cx, cy, w, h)
+    else:
+        raise ValueError(f"Format {format} is not supported")
+
+    return features.BoundingBox(torch.stack(parts, dim=-1).to(dtype), format=format, image_size=image_size)
+
+
+make_xyxy_bounding_box = functools.partial(make_bounding_box, format=features.BoundingBoxFormat.XYXY)
+
+
+def make_bounding_boxes(
+    formats=(features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH),
+    image_sizes=((32, 32),),
+    dtypes=(torch.int64, torch.float32),
+    extra_dims=((4,), (2, 3)),
+):
+    for format, image_size, dtype in itertools.product(formats, image_sizes, dtypes):
+        yield make_bounding_box(format=format, image_size=image_size, dtype=dtype)
+
+    for format, extra_dims_ in itertools.product(formats, extra_dims):
+        yield make_bounding_box(format=format, extra_dims=extra_dims_)
+
+
+class SampleInput:
+    def __init__(self, *args, **kwargs):
+        self.args = args
+        self.kwargs = kwargs
+
+
+class KernelInfo:
+    def __init__(self, name, *, sample_inputs_fn):
+        self.name = name
+        self.kernel = getattr(F, name)
+        self._sample_inputs_fn = sample_inputs_fn
+
+    def sample_inputs(self):
+        yield from self._sample_inputs_fn()
+
+    def __call__(self, *args, **kwargs):
+        if len(args) == 1 and not kwargs and isinstance(args[0], SampleInput):
+            sample_input = args[0]
+            return self.kernel(*sample_input.args, **sample_input.kwargs)
+
+        return self.kernel(*args, **kwargs)
+
+
+KERNEL_INFOS = []
+
+
+def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn):
+    KERNEL_INFOS.append(KernelInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn))
+    return sample_inputs_fn
+
+
+@register_kernel_info_from_sample_inputs_fn
+def horizontal_flip_image():
+    for image in make_images():
+        yield SampleInput(image)
+
+
+@register_kernel_info_from_sample_inputs_fn
+def horizontal_flip_bounding_box():
+    for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]):
+        yield SampleInput(bounding_box, image_size=bounding_box.image_size)
+
+
+@register_kernel_info_from_sample_inputs_fn
+def resize_image():
+    for image, interpolation in itertools.product(
+        make_images(),
+        [
+            F.InterpolationMode.BILINEAR,
+            F.InterpolationMode.NEAREST,
+        ],
+    ):
+        height, width = image.shape[-2:]
+        for size in [
+            (height, width),
+            (int(height * 0.75), int(width * 1.25)),
+        ]:
+            yield SampleInput(image, size=size, interpolation=interpolation)
+
+
+@register_kernel_info_from_sample_inputs_fn
+def resize_bounding_box():
+    for bounding_box in make_bounding_boxes():
+        height, width = bounding_box.image_size
+        for new_image_size in [
+            (height, width),
+            (int(height * 0.75), int(width * 1.25)),
+        ]:
+            yield SampleInput(bounding_box, old_image_size=bounding_box.image_size, new_image_size=new_image_size)
+
+
+class TestKernelsCommon:
+    @pytest.mark.parametrize("kernel_info", KERNEL_INFOS, ids=lambda kernel_info: kernel_info.name)
+    def test_scriptable(self, kernel_info):
+        jit.script(kernel_info.kernel)
+
+    @pytest.mark.parametrize(
+        ("kernel_info", "sample_input"),
+        [
+            pytest.param(kernel_info, sample_input, id=f"{kernel_info.name}-{idx}")
+            for kernel_info in KERNEL_INFOS
+            for idx, sample_input in enumerate(kernel_info.sample_inputs())
+        ],
+    )
+    def test_eager_vs_scripted(self, kernel_info, sample_input):
+        eager = kernel_info(sample_input)
+        scripted = jit.script(kernel_info.kernel)(*sample_input.args, **sample_input.kwargs)
+
+        torch.testing.assert_close(eager, scripted)
diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py
index 2d0685c2088..6c5dac72d53 100644
--- a/torchvision/prototype/features/_bounding_box.py
+++ b/torchvision/prototype/features/_bounding_box.py
@@ -36,3 +36,14 @@ def __new__(
         bounding_box._metadata.update(dict(format=format, image_size=image_size))
 
         return bounding_box
+
+    def to_format(self, format: Union[str, BoundingBoxFormat]) -> "BoundingBox":
+        # import at runtime to avoid cyclic imports
+        from torchvision.prototype.transforms.functional import convert_bounding_box_format
+
+        if isinstance(format, str):
+            format = BoundingBoxFormat[format]
+
+        return BoundingBox.new_like(
+            self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
+        )
diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py
index 9160b5e36e1..338b2d2230d 100644
--- a/torchvision/prototype/features/_encoded.py
+++ b/torchvision/prototype/features/_encoded.py
@@ -7,6 +7,7 @@
 from torchvision.prototype.utils._internal import fromfile, ReadOnlyTensorBuffer
 
 from ._feature import Feature
+from ._image import Image
 
 D = TypeVar("D", bound="EncodedData")
 
@@ -37,6 +38,12 @@ def image_size(self) -> Tuple[int, int]:
 
         return self._image_size
 
+    def decode(self) -> Image:
+        # import at runtime to avoid cyclic imports
+        from torchvision.prototype.transforms.functional import decode_image_with_pil
+
+        return Image(decode_image_with_pil(self))
+
 
 class EncodedVideo(EncodedData):
     pass
diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py
index 93a9b517235..a07da277314 100644
--- a/torchvision/prototype/features/_image.py
+++ b/torchvision/prototype/features/_image.py
@@ -4,8 +4,10 @@
 import torch
 from torchvision.prototype.utils._internal import StrEnum
 from torchvision.transforms.functional import to_pil_image
+from torchvision.utils import draw_bounding_boxes
 from torchvision.utils import make_grid
 
+from ._bounding_box import BoundingBox
 from ._feature import Feature
 
 
@@ -76,3 +78,6 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace:
 
     def show(self) -> None:
         to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show()
+
+    def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> "Image":
+        return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs))
diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index 56cca7b0402..963bdebc7ed 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -1,6 +1,8 @@
+from . import functional
+from .functional import InterpolationMode  # usort: skip
+
 from ._transform import Transform
 from ._container import Compose, RandomApply, RandomChoice, RandomOrder  # usort: skip
-
 from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop
 from ._misc import Identity, Normalize
 from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
new file mode 100644
index 00000000000..087f2fb2ac0
--- /dev/null
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -0,0 +1,27 @@
+from ._augment import erase_image, mixup_image, mixup_one_hot_label, cutmix_image, cutmix_one_hot_label
+from ._color import (
+    adjust_brightness_image,
+    adjust_contrast_image,
+    adjust_saturation_image,
+    adjust_sharpness_image,
+    posterize_image,
+    solarize_image,
+    autocontrast_image,
+    equalize_image,
+    invert_image,
+)
+from ._geometry import (
+    horizontal_flip_bounding_box,
+    horizontal_flip_image,
+    resize_bounding_box,
+    resize_image,
+    resize_segmentation_mask,
+    center_crop_image,
+    resized_crop_image,
+    InterpolationMode,
+    affine_image,
+    rotate_image,
+)
+from ._meta_conversion import convert_color_space, convert_bounding_box_format
+from ._misc import normalize_image
+from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py
new file mode 100644
index 00000000000..814c34e5b00
--- /dev/null
+++ b/torchvision/prototype/transforms/functional/_augment.py
@@ -0,0 +1,40 @@
+from typing import Tuple
+
+import torch
+from torchvision.transforms import functional as _F
+
+
+erase_image = _F.erase
+
+
+def _mixup(input: torch.Tensor, batch_dim: int, lam: float, inplace: bool) -> torch.Tensor:
+    if not inplace:
+        input = input.clone()
+
+    input_rolled = input.roll(1, batch_dim)
+    return input.mul_(lam).add_(input_rolled.mul_(1 - lam))
+
+
+def mixup_image(image_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor:
+    return _mixup(image_batch, -4, lam, inplace)
+
+
+def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor:
+    return _mixup(one_hot_label_batch, -2, lam, inplace)
+
+
+def cutmix_image(image: torch.Tensor, *, box: Tuple[int, int, int, int], inplace: bool = False) -> torch.Tensor:
+    if not inplace:
+        image = image.clone()
+
+    x1, y1, x2, y2 = box
+    image_rolled = image.roll(1, -4)
+
+    image[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
+    return image
+
+
+def cutmix_one_hot_label(
+    one_hot_label_batch: torch.Tensor, *, lam_adjusted: float, inplace: bool = False
+) -> torch.Tensor:
+    return mixup_one_hot_label(one_hot_label_batch, lam=lam_adjusted, inplace=inplace)
diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py
new file mode 100644
index 00000000000..f2529166d9a
--- /dev/null
+++ b/torchvision/prototype/transforms/functional/_color.py
@@ -0,0 +1,20 @@
+from torchvision.transforms import functional as _F
+
+
+adjust_brightness_image = _F.adjust_brightness
+
+adjust_saturation_image = _F.adjust_saturation
+
+adjust_contrast_image = _F.adjust_contrast
+
+adjust_sharpness_image = _F.adjust_sharpness
+
+posterize_image = _F.posterize
+
+solarize_image = _F.solarize
+
+autocontrast_image = _F.autocontrast
+
+equalize_image = _F.equalize
+
+invert_image = _F.invert
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
new file mode 100644
index 00000000000..c8142742fa8
--- /dev/null
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -0,0 +1,87 @@
+from typing import Tuple, List, Optional
+
+import torch
+from torchvision.prototype.features import BoundingBoxFormat
+from torchvision.transforms import (  # noqa: F401
+    functional as _F,
+    InterpolationMode,
+)
+
+from ._meta_conversion import convert_bounding_box_format
+
+
+def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor:
+    return image.flip((-1,))
+
+
+def horizontal_flip_bounding_box(bounding_box: torch.Tensor, *, image_size: Tuple[int, int]) -> torch.Tensor:
+    x, y, w, h = convert_bounding_box_format(
+        bounding_box,
+        old_format=BoundingBoxFormat.XYXY,
+        new_format=BoundingBoxFormat.XYWH,
+    ).unbind(-1)
+    x = image_size[1] - (x + w)
+    return convert_bounding_box_format(
+        torch.stack((x, y, w, h), dim=-1),
+        old_format=BoundingBoxFormat.XYWH,
+        new_format=BoundingBoxFormat.XYXY,
+    )
+
+
+_resize_image = _F.resize
+
+
+def resize_image(
+    image: torch.Tensor,
+    size: List[int],
+    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    max_size: Optional[int] = None,
+    antialias: Optional[bool] = None,
+) -> torch.Tensor:
+    new_height, new_width = size
+    num_channels, old_height, old_width = image.shape[-3:]
+    batch_shape = image.shape[:-3]
+    return _resize_image(
+        image.reshape((-1, num_channels, old_height, old_width)),
+        size=size,
+        interpolation=interpolation,
+        max_size=max_size,
+        antialias=antialias,
+    ).reshape(batch_shape + (num_channels, new_height, new_width))
+
+
+def resize_segmentation_mask(
+    segmentation_mask: torch.Tensor,
+    size: List[int],
+    interpolation: InterpolationMode = InterpolationMode.NEAREST,
+    max_size: Optional[int] = None,
+    antialias: Optional[bool] = None,
+) -> torch.Tensor:
+    return resize_image(
+        segmentation_mask, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias
+    )
+
+
+# TODO: handle max_size
+def resize_bounding_box(
+    bounding_box: torch.Tensor,
+    *,
+    old_image_size: List[int],
+    new_image_size: List[int],
+) -> torch.Tensor:
+    old_height, old_width = old_image_size
+    new_height, new_width = new_image_size
+    return (
+        bounding_box.view(-1, 2, 2)
+        .mul(torch.tensor([new_width / old_width, new_height / old_height]))
+        .view(bounding_box.shape)
+    )
+
+
+center_crop_image = _F.center_crop
+
+resized_crop_image = _F.resized_crop
+
+affine_image = _F.affine
+
+rotate_image = _F.rotate
diff --git a/torchvision/prototype/transforms/functional/_meta_conversion.py b/torchvision/prototype/transforms/functional/_meta_conversion.py
new file mode 100644
index 00000000000..484066a39ee
--- /dev/null
+++ b/torchvision/prototype/transforms/functional/_meta_conversion.py
@@ -0,0 +1,69 @@
+import torch
+from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
+from torchvision.transforms.functional_tensor import rgb_to_grayscale as _rgb_to_grayscale
+
+
+def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
+    xyxy = xywh.clone()
+    xyxy[..., 2:] += xyxy[..., :2]
+    return xyxy
+
+
+def _xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
+    xywh = xyxy.clone()
+    xywh[..., 2:] -= xywh[..., :2]
+    return xywh
+
+
+def _cxcywh_to_xyxy(cxcywh: torch.Tensor) -> torch.Tensor:
+    cx, cy, w, h = torch.unbind(cxcywh, dim=-1)
+    x1 = cx - 0.5 * w
+    y1 = cy - 0.5 * h
+    x2 = cx + 0.5 * w
+    y2 = cy + 0.5 * h
+    return torch.stack((x1, y1, x2, y2), dim=-1)
+
+
+def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor:
+    x1, y1, x2, y2 = torch.unbind(xyxy, dim=-1)
+    cx = (x1 + x2) / 2
+    cy = (y1 + y2) / 2
+    w = x2 - x1
+    h = y2 - y1
+    return torch.stack((cx, cy, w, h), dim=-1)
+
+
+def convert_bounding_box_format(
+    bounding_box: torch.Tensor, *, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat
+) -> torch.Tensor:
+    if new_format == old_format:
+        return bounding_box
+
+    if old_format == BoundingBoxFormat.XYWH:
+        bounding_box = _xywh_to_xyxy(bounding_box)
+    elif old_format == BoundingBoxFormat.CXCYWH:
+        bounding_box = _cxcywh_to_xyxy(bounding_box)
+
+    if new_format == BoundingBoxFormat.XYWH:
+        bounding_box = _xyxy_to_xywh(bounding_box)
+    elif new_format == BoundingBoxFormat.CXCYWH:
+        bounding_box = _xyxy_to_cxcywh(bounding_box)
+
+    return bounding_box
+
+
+def _grayscale_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
+    # repeat the single channel three times along the channel dimension
+    return grayscale.repeat_interleave(3, dim=-3)
+
+
+def convert_color_space(image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace) -> torch.Tensor:
+    if new_color_space == old_color_space:
+        return image
+
+    if old_color_space == ColorSpace.GRAYSCALE:
+        image = _grayscale_to_rgb(image)
+
+    if new_color_space == ColorSpace.GRAYSCALE:
+        image = _rgb_to_grayscale(image)
+
+    return image
diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py
new file mode 100644
index 00000000000..de148ab194a
--- /dev/null
+++ b/torchvision/prototype/transforms/functional/_misc.py
@@ -0,0 +1,4 @@
+from torchvision.transforms import functional as _F
+
+
+normalize_image = _F.normalize
diff --git a/torchvision/prototype/transforms/functional/_type_conversion.py b/torchvision/prototype/transforms/functional/_type_conversion.py
new file mode 100644
index 00000000000..ed355ab5eae
--- /dev/null
+++ b/torchvision/prototype/transforms/functional/_type_conversion.py
@@ -0,0 +1,25 @@
+import unittest.mock
+from typing import Dict, Any, Tuple
+
+import numpy as np
+import PIL.Image
+import torch
+from torch.nn.functional import one_hot
+from torchvision.io.video import read_video
+from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer
+
+
+def decode_image_with_pil(encoded_image: torch.Tensor) -> torch.Tensor:
+    image = torch.as_tensor(np.array(PIL.Image.open(ReadOnlyTensorBuffer(encoded_image)), copy=True))
+    if image.ndim == 2:
+        image = image.unsqueeze(2)
+    return image.permute(2, 0, 1)
+
+
+def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
+    with unittest.mock.patch("torchvision.io.video.os.path.exists", return_value=True):
+        return read_video(ReadOnlyTensorBuffer(encoded_video))  # type: ignore[arg-type]
+
+
+def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor:
+    return one_hot(label, num_classes=num_categories)