
Proto transform cleanup #6408


Merged
merged 14 commits on Aug 16, 2022
4 changes: 2 additions & 2 deletions test/test_prototype_transforms.py
@@ -200,7 +200,7 @@ def test_random_resized_crop(self, transform, input):
@parametrize(
[
(
transforms.ConvertImageColorSpace(color_space=new_color_space, old_color_space=old_color_space),
transforms.ConvertColorSpace(color_space=new_color_space, old_color_space=old_color_space),
itertools.chain.from_iterable(
[
fn(color_spaces=[old_color_space])
@@ -223,7 +223,7 @@ def test_random_resized_crop(self, transform, input):
)
]
)
def test_convert_image_color_space(self, transform, input):
def test_convert_color_space(self, transform, input):
transform(input)


8 changes: 2 additions & 6 deletions torchvision/prototype/features/_bounding_box.py
@@ -60,17 +60,13 @@ def new_like(
)

def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state

# import at runtime to avoid cyclic imports
from torchvision.prototype.transforms.functional import convert_bounding_box_format
from torchvision.prototype.transforms import functional as _F

if isinstance(format, str):
format = BoundingBoxFormat.from_str(format.upper())

return BoundingBox.new_like(
self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
self, _F.convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
)

def horizontal_flip(self) -> BoundingBox:
14 changes: 14 additions & 0 deletions torchvision/prototype/features/_image.py
@@ -99,6 +99,20 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace:
else:
return ColorSpace.OTHER

def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Image:
from torchvision.prototype.transforms import functional as _F

if isinstance(color_space, str):
color_space = ColorSpace.from_str(color_space.upper())

return Image.new_like(
self,
_F.convert_color_space_image_tensor(
self, old_color_space=self.color_space, new_color_space=color_space, copy=copy
),
color_space=color_space,
)

def show(self) -> None:
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
# promote this out of the prototype state
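
A minimal usage sketch of the new to_color_space method (only the method itself appears in this diff; the Image constructor call below is an assumption about the prototype API):

    import torch
    from torchvision.prototype import features

    img = features.Image(torch.rand(3, 16, 16), color_space=features.ColorSpace.RGB)
    # Strings are upper-cased and parsed via ColorSpace.from_str, so "gray" works too.
    gray = img.to_color_space("gray")
    assert gray.color_space == features.ColorSpace.GRAY
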
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
@@ -33,7 +33,7 @@
ScaleJitter,
TenCrop,
)
from ._meta import ConvertBoundingBoxFormat, ConvertImageColorSpace, ConvertImageDtype
from ._meta import ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
from ._misc import GaussianBlur, Identity, Lambda, Normalize, ToDtype
from ._type_conversion import DecodeImage, LabelToOneHot, ToImagePIL, ToImageTensor

79 changes: 41 additions & 38 deletions torchvision/prototype/transforms/_color.py
@@ -84,30 +84,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return output


class _RandomChannelShuffle(Transform):
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
num_channels, _, _ = get_image_dimensions(image)
return dict(permutation=torch.randperm(num_channels))

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
return inpt

image = inpt
if isinstance(inpt, PIL.Image.Image):
image = _F.pil_to_tensor(image)

output = image[..., params["permutation"], :, :]

if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER)
elif isinstance(inpt, PIL.Image.Image):
output = _F.to_pil_image(output)

return output


class RandomPhotometricDistort(Transform):
def __init__(
self,
@@ -118,35 +94,62 @@ def __init__(
p: float = 0.5,
):
super().__init__()
self._brightness = ColorJitter(brightness=brightness)
self._contrast = ColorJitter(contrast=contrast)
self._hue = ColorJitter(hue=hue)
self._saturation = ColorJitter(saturation=saturation)
self._channel_shuffle = _RandomChannelShuffle()
self.brightness = brightness
self.contrast = contrast
self.hue = hue
self.saturation = saturation
self.p = p

def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
num_channels, _, _ = get_image_dimensions(image)
return dict(
zip(
["brightness", "contrast1", "saturation", "hue", "contrast2", "channel_shuffle"],
["brightness", "contrast1", "saturation", "hue", "contrast2"],
torch.rand(6) < self.p,
),
contrast_before=torch.rand(()) < 0.5,
channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None,
)

def _permute_channels(self, inpt: Any, *, permutation: torch.Tensor) -> Any:
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
return inpt

image = inpt
if isinstance(inpt, PIL.Image.Image):
image = _F.pil_to_tensor(image)

output = image[..., permutation, :, :]

if isinstance(inpt, features.Image):
output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER)
elif isinstance(inpt, PIL.Image.Image):
output = _F.to_pil_image(output)

return output

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["brightness"]:
inpt = self._brightness(inpt)
inpt = F.adjust_brightness(
inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
)
if params["contrast1"] and params["contrast_before"]:
inpt = self._contrast(inpt)
if params["saturation"]:
inpt = self._saturation(inpt)
inpt = F.adjust_contrast(
inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1])
)
if params["saturation"]:
inpt = self._saturation(inpt)
inpt = F.adjust_saturation(
inpt, saturation_factor=ColorJitter._generate_value(self.saturation[0], self.saturation[1])
)
if params["hue"]:
inpt = F.adjust_hue(inpt, hue_factor=ColorJitter._generate_value(self.hue[0], self.hue[1]))
if params["contrast2"] and not params["contrast_before"]:
inpt = self._contrast(inpt)
if params["channel_shuffle"]:
inpt = self._channel_shuffle(inpt)
inpt = F.adjust_contrast(
inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1])
)
if params["channel_permutation"]:
inpt = self._permute_channels(inpt, permutation=params["channel_permutation"])
return inpt


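The refactor above folds the private ColorJitter sub-transforms and the _RandomChannelShuffle helper into direct functional calls plus a channel_permutation parameter, so all randomness is sampled once per call in _get_params. A minimal usage sketch, assuming the default factor ranges elided from the signature above:

    import torch
    from torchvision.prototype import transforms

    distort = transforms.RandomPhotometricDistort(p=0.5)
    img = torch.rand(3, 224, 224)  # features.Image and PIL images are handled as well
    out = distort(img)
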
13 changes: 5 additions & 8 deletions torchvision/prototype/transforms/_deprecated.py
@@ -4,13 +4,13 @@
import numpy as np
import PIL.Image
import torch
import torchvision.prototype.transforms.functional as F
from torchvision.prototype import features
from torchvision.prototype.features import ColorSpace
from torchvision.prototype.transforms import Transform
from torchvision.transforms import functional as _F
from typing_extensions import Literal

from ._meta import ConvertImageColorSpace
from ._transform import _RandomApplyTransform
from ._utils import is_simple_tensor

@@ -90,13 +90,11 @@ def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:

super().__init__()
self.num_output_channels = num_output_channels
self._rgb_to_gray = ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)
self._gray_to_rgb = ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
output = self._rgb_to_gray(inpt)
output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
if self.num_output_channels == 3:
output = self._gray_to_rgb(output)
output = F.convert_color_space(output, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)
return output


@@ -115,8 +113,7 @@ def __init__(self, p: float = 0.1) -> None:
)

super().__init__(p=p)
self._rgb_to_gray = ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)
self._gray_to_rgb = ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._gray_to_rgb(self._rgb_to_gray(inpt))
output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
return F.convert_color_space(output, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)
43 changes: 15 additions & 28 deletions torchvision/prototype/transforms/_geometry.py
@@ -1,4 +1,3 @@
import collections.abc
import math
import numbers
import warnings
@@ -180,9 +179,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
output = F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip)
return MultiCropResult(features.Image.new_like(inpt, o) for o in output)
elif is_simple_tensor(inpt):
return MultiCropResult(F.ten_crop_image_tensor(inpt, self.size))
return MultiCropResult(F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip))
elif isinstance(inpt, PIL.Image.Image):
return MultiCropResult(F.ten_crop_image_pil(inpt, self.size))
return MultiCropResult(F.ten_crop_image_pil(inpt, self.size, vertical_flip=self.vertical_flip))
else:
return inpt

@@ -194,31 +193,19 @@ def forward(self, *inputs: Any) -> Any:


class BatchMultiCrop(Transform):
def forward(self, *inputs: Any) -> Any:
# This is basically the functionality of `torchvision.prototype.utils._internal.apply_recursively` with one
# significant difference:
# Since we need multiple images to batch them together, we need to explicitly exclude `MultiCropResult` from
# the sequence case.
def apply_recursively(obj: Any) -> Any:
if isinstance(obj, MultiCropResult):
crops = obj
if isinstance(obj[0], PIL.Image.Image):
crops = [pil_to_tensor(crop) for crop in crops] # type: ignore[assignment]

batch = torch.stack(crops)

if isinstance(obj[0], features.Image):
batch = features.Image.new_like(obj[0], batch)

return batch
elif isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
return [apply_recursively(item) for item in obj]
elif isinstance(obj, collections.abc.Mapping):
return {key: apply_recursively(item) for key, item in obj.items()}
else:
return obj

return apply_recursively(inputs if len(inputs) > 1 else inputs[0])
_transformed_types = (MultiCropResult,)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
crops = inpt
if isinstance(inpt[0], PIL.Image.Image):
crops = [pil_to_tensor(crop) for crop in crops]

batch = torch.stack(crops)

if isinstance(inpt[0], features.Image):
batch = features.Image.new_like(inpt[0], batch)

return batch


def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) -> None:
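
BatchMultiCrop now declares MultiCropResult as its only transformed type and lets the base Transform handle the traversal, instead of reimplementing apply_recursively. A sketch of the intended pairing with TenCrop, assuming both transforms are exported from the prototype namespace:

    import torch
    from torchvision.prototype import transforms

    img = torch.rand(3, 256, 256)
    crops = transforms.TenCrop(224)(img)        # a MultiCropResult holding 10 crops
    batch = transforms.BatchMultiCrop()(crops)  # stacked tensor of shape (10, 3, 224, 224)
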
31 changes: 11 additions & 20 deletions torchvision/prototype/transforms/_meta.py
@@ -1,6 +1,7 @@
from typing import Any, Dict, Optional, Union

import PIL.Image

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
@@ -39,11 +40,15 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt


class ConvertImageColorSpace(Transform):
class ConvertColorSpace(Transform):
# F.convert_color_space does NOT handle `_Feature`'s in general
_transformed_types = (torch.Tensor, features.Image, PIL.Image.Image)

def __init__(
self,
color_space: Union[str, features.ColorSpace],
old_color_space: Optional[Union[str, features.ColorSpace]] = None,
copy: bool = True,
) -> None:
super().__init__()

@@ -55,23 +60,9 @@ def __init__(
old_color_space = features.ColorSpace.from_str(old_color_space)
self.old_color_space = old_color_space

self.copy = copy

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features.Image):
output = F.convert_image_color_space_tensor(
inpt, old_color_space=inpt.color_space, new_color_space=self.color_space
)
return features.Image.new_like(inpt, output, color_space=self.color_space)
elif is_simple_tensor(inpt):
if self.old_color_space is None:
raise RuntimeError(
f"In order to convert simple tensor images, `{type(self).__name__}(...)` "
f"needs to be constructed with the `old_color_space=...` parameter."
)

return F.convert_image_color_space_tensor(
inpt, old_color_space=self.old_color_space, new_color_space=self.color_space
)
elif isinstance(inpt, PIL.Image.Image):
return F.convert_image_color_space_pil(inpt, color_space=self.color_space)
else:
return inpt
return F.convert_color_space(
inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy
)
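
After the rewrite, ConvertColorSpace is a thin wrapper over F.convert_color_space. A minimal sketch; simple tensors carry no color space metadata, so old_color_space must be supplied for them:

    import torch
    from torchvision.prototype import features, transforms

    convert = transforms.ConvertColorSpace(features.ColorSpace.RGB, old_color_space=features.ColorSpace.GRAY)
    rgb = convert(torch.rand(1, 32, 32))  # 1-channel simple tensor converted to 3-channel RGB
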
5 changes: 3 additions & 2 deletions torchvision/prototype/transforms/functional/__init__.py
@@ -1,8 +1,9 @@
from torchvision.transforms import InterpolationMode # usort: skip
from ._meta import (
convert_bounding_box_format,
convert_image_color_space_tensor,
convert_image_color_space_pil,
convert_color_space_image_tensor,
convert_color_space_image_pil,
convert_color_space,
) # usort: skip

from ._augment import erase_image_pil, erase_image_tensor
26 changes: 22 additions & 4 deletions torchvision/prototype/transforms/functional/_meta.py
@@ -1,8 +1,8 @@
from typing import Optional, Tuple
from typing import Any, Optional, Tuple

import PIL.Image
import torch
from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
from torchvision.prototype.features import BoundingBoxFormat, ColorSpace, Image
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

get_dimensions_image_tensor = _FT.get_dimensions
@@ -91,7 +91,7 @@ def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
_rgb_to_gray = _FT.rgb_to_grayscale


def convert_image_color_space_tensor(
def convert_color_space_image_tensor(
image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace, copy: bool = True
) -> torch.Tensor:
if new_color_space == old_color_space:
@@ -141,7 +141,7 @@ def convert_image_color_space_tensor(
}


def convert_image_color_space_pil(
def convert_color_space_image_pil(
image: PIL.Image.Image, color_space: ColorSpace, copy: bool = True
) -> PIL.Image.Image:
old_mode = image.mode
@@ -154,3 +154,21 @@ def convert_image_color_space_pil(
return image

return image.convert(new_mode)


def convert_color_space(
inpt: Any, *, color_space: ColorSpace, old_color_space: Optional[ColorSpace] = None, copy: bool = True
) -> Any:
if isinstance(inpt, Image):
return inpt.to_color_space(color_space, copy=copy)
elif isinstance(inpt, PIL.Image.Image):
return convert_color_space_image_pil(inpt, color_space, copy=copy)
else:
if old_color_space is None:
raise RuntimeError(
"In order to convert the color space of simple tensor images, "
"the `old_color_space=...` parameter needs to be passed."
)
return convert_color_space_image_tensor(
inpt, old_color_space=old_color_space, new_color_space=color_space, copy=copy
)
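
The new convert_color_space entry point dispatches on the input type: features.Image goes through Image.to_color_space (keeping its color_space metadata in sync), PIL images go through the mode-based PIL kernel, and everything else is treated as a simple tensor image. A minimal sketch of the three paths (the features.Image constructor call is an assumption):

    import PIL.Image
    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    t = torch.rand(3, 8, 8)
    # Simple tensor: old_color_space is mandatory.
    gray = F.convert_color_space(t, color_space=features.ColorSpace.GRAY, old_color_space=features.ColorSpace.RGB)

    # features.Image: the old color space is read from the input itself.
    img = features.Image(t, color_space=features.ColorSpace.RGB)
    gray_img = F.convert_color_space(img, color_space=features.ColorSpace.GRAY)

    # PIL image: converted via PIL modes.
    pil = PIL.Image.new("RGB", (8, 8))
    gray_pil = F.convert_color_space(pil, color_space=features.ColorSpace.GRAY)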