Transforms without dispatcher #5421

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged: 36 commits from transforms-without-dispatcher into main on Feb 25, 2022
Commits (36)
5e52ed2  add prototype transforms that don't need dispatchers (pmeier, Feb 14, 2022)
225ed47  Merge branch 'main' into transforms-without-dispatcher (pmeier, Feb 14, 2022)
fe82e94  cleanup (pmeier, Feb 14, 2022)
36f3e0d  remove legacy_transform decorator (pmeier, Feb 15, 2022)
757fbed  remove legacy classes (pmeier, Feb 15, 2022)
0ca3800  Merge branch 'main' into transforms-without-dispatcher (pmeier, Feb 15, 2022)
dc61271  remove explicit param passing (pmeier, Feb 15, 2022)
c7c4608  streamline extra_repr (pmeier, Feb 15, 2022)
13d49cb  remove obsolete ._supports() method (pmeier, Feb 15, 2022)
4771e25  cleanup (pmeier, Feb 15, 2022)
fb2077f  Merge branch 'main' into transforms-without-dispatcher (pmeier, Feb 16, 2022)
c393a43  remove Query (pmeier, Feb 16, 2022)
e7502ed  cleanup (pmeier, Feb 16, 2022)
fd752a6  fix tests (pmeier, Feb 16, 2022)
ea71c2c  Merge branch 'main' into transforms-without-dispatcher (pmeier, Feb 17, 2022)
283c474  kernels -> functional (pmeier, Feb 21, 2022)
b3c0452  move image size and num channels extraction to functional (pmeier, Feb 21, 2022)
c129dea  extend legacy function to extract image size and num channels (pmeier, Feb 21, 2022)
9b18c28  implement dispatching for auto augment (pmeier, Feb 22, 2022)
3348f89  fix auto augment dispatch (pmeier, Feb 22, 2022)
25e2ec0  Merge branch 'main' into transforms-without-dispatcher (pmeier, Feb 22, 2022)
90f9fa7  revert some naming changes (pmeier, Feb 22, 2022)
ddf28d2  remove ability to pass params to autoaugment (pmeier, Feb 22, 2022)
68bbb2b  fix legacy image size extraction (pmeier, Feb 22, 2022)
aa9f912  Merge branch 'main' into transforms-without-dispatcher (pmeier, Feb 23, 2022)
1587588  align prototype.transforms.functional with transforms.functional (pmeier, Feb 24, 2022)
41be83c  Merge branch 'transforms-without-dispatcher' of https://github.com/pm… (pmeier, Feb 24, 2022)
9fc2693  Merge branch 'main' into transforms-without-dispatcher (pmeier, Feb 24, 2022)
ab79215  Merge branch 'main' into transforms-without-dispatcher (pmeier, Feb 24, 2022)
7826ab3  cleanup (pmeier, Feb 25, 2022)
ced8bcf  fix image size and channels extraction (pmeier, Feb 25, 2022)
0017807  fix affine and rotate (pmeier, Feb 25, 2022)
23955b6  Merge branch 'transforms-without-dispatcher' of https://github.com/pm… (pmeier, Feb 25, 2022)
ed32288  Merge branch 'main' into transforms-without-dispatcher (pmeier, Feb 25, 2022)
71e4c56  revert image size to (width, height) (pmeier, Feb 25, 2022)
0943de0  Minor corrections (datumbox, Feb 25, 2022)
63 changes: 15 additions & 48 deletions test/test_prototype_transforms.py
@@ -1,9 +1,8 @@
 import itertools

-import PIL.Image
 import pytest
 import torch
-from test_prototype_transforms_kernels import make_images, make_bounding_boxes, make_one_hot_labels
+from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels
 from torchvision.prototype import transforms, features
 from torchvision.transforms.functional import to_pil_image

@@ -25,15 +24,6 @@ def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
         yield bounding_box.data


-INPUT_CREATIONS_FNS = {
-    features.Image: make_images,
-    features.BoundingBox: make_bounding_boxes,
-    features.OneHotLabel: make_one_hot_labels,
-    torch.Tensor: make_vanilla_tensor_images,
-    PIL.Image.Image: make_pil_images,
-}
-
-
 def parametrize(transforms_with_inputs):
     return pytest.mark.parametrize(
         ("transform", "input"),
@@ -52,15 +42,21 @@ def parametrize(transforms_with_inputs):
 def parametrize_from_transforms(*transforms):
     transforms_with_inputs = []
     for transform in transforms:
-        dispatcher = transform._DISPATCHER
-        if dispatcher is None:
-            continue
-
-        for type_ in dispatcher._kernels:
+        for creation_fn in [
+            make_images,
+            make_bounding_boxes,
+            make_one_hot_labels,
+            make_vanilla_tensor_images,
+            make_pil_images,
+        ]:
+            inputs = list(creation_fn())
             try:
-                inputs = INPUT_CREATIONS_FNS[type_]()
-            except KeyError:
+                output = transform(inputs[0])
+            except Exception:
                 continue
+            else:
+                if output is inputs[0]:
+                    continue

             transforms_with_inputs.append((transform, inputs))

@@ -69,7 +65,7 @@ def parametrize_from_transforms(*transforms):

 class TestSmoke:
     @parametrize_from_transforms(
-        transforms.RandomErasing(),
+        transforms.RandomErasing(p=1.0),
         transforms.HorizontalFlip(),
         transforms.Resize([16, 16]),
         transforms.CenterCrop([16, 16]),
@@ -141,35 +137,6 @@ def test_auto_augment(self, transform, input):
     def test_normalize(self, transform, input):
         transform(input)

-    @parametrize(
-        [
-            (
-                transforms.ConvertColorSpace("grayscale"),
-                itertools.chain(
-                    make_images(),
-                    make_vanilla_tensor_images(color_spaces=["rgb"]),
-                    make_pil_images(color_spaces=["rgb"]),
-                ),
-            )
-        ]
-    )
-    def test_convert_bounding_color_space(self, transform, input):
-        transform(input)
-
-    @parametrize(
-        [
-            (
-                transforms.ConvertBoundingBoxFormat("xyxy", old_format="xywh"),
-                itertools.chain(
-                    make_bounding_boxes(),
-                    make_vanilla_tensor_bounding_boxes(formats=["xywh"]),
-                ),
-            )
-        ]
-    )
-    def test_convert_bounding_box_format(self, transform, input):
-        transform(input)
-
     @parametrize(
         [
             (
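Note on the rewrite above: parametrize_from_transforms no longer reads a transform's dispatcher to discover supported types. Instead it probes the transform with one sample from each input family and keeps only the families the transform actually acts on; an output that is the input object itself signals a pass-through. A minimal sketch of that probing convention, reusing the make_* helpers imported at the top of the file:

    from torchvision.prototype import transforms

    transform = transforms.HorizontalFlip()
    for creation_fn in [make_images, make_bounding_boxes, make_one_hot_labels]:
        inputs = list(creation_fn())
        try:
            output = transform(inputs[0])
        except Exception:
            continue  # the transform rejects this input family outright
        if output is inputs[0]:
            continue  # pass-through: unsupported, silently returned as-is
        print(f"{creation_fn.__name__}: supported")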
test/test_prototype_transforms_kernels.py → test/test_prototype_transforms_functional.py
@@ -3,7 +3,7 @@

 import pytest
 import torch.testing
-import torchvision.prototype.transforms.kernels as K
+import torchvision.prototype.transforms.functional as F
 from torch import jit
 from torch.nn.functional import one_hot
 from torchvision.prototype import features
@@ -134,10 +134,10 @@ def __init__(self, *args, **kwargs):
         self.kwargs = kwargs


-class KernelInfo:
+class FunctionalInfo:
     def __init__(self, name, *, sample_inputs_fn):
         self.name = name
-        self.kernel = getattr(K, name)
+        self.functional = getattr(F, name)
         self._sample_inputs_fn = sample_inputs_fn

     def sample_inputs(self):
@@ -146,21 +146,21 @@ def sample_inputs(self):
     def __call__(self, *args, **kwargs):
         if len(args) == 1 and not kwargs and isinstance(args[0], SampleInput):
             sample_input = args[0]
-            return self.kernel(*sample_input.args, **sample_input.kwargs)
+            return self.functional(*sample_input.args, **sample_input.kwargs)

-        return self.kernel(*args, **kwargs)
+        return self.functional(*args, **kwargs)


-KERNEL_INFOS = []
+FUNCTIONAL_INFOS = []


 def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn):
-    KERNEL_INFOS.append(KernelInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn))
+    FUNCTIONAL_INFOS.append(FunctionalInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn))
     return sample_inputs_fn


 @register_kernel_info_from_sample_inputs_fn
-def horizontal_flip_image():
+def horizontal_flip_image_tensor():
     for image in make_images():
         yield SampleInput(image)

@@ -172,12 +172,12 @@ def horizontal_flip_bounding_box():


 @register_kernel_info_from_sample_inputs_fn
-def resize_image():
+def resize_image_tensor():
     for image, interpolation in itertools.product(
         make_images(),
         [
-            K.InterpolationMode.BILINEAR,
-            K.InterpolationMode.NEAREST,
+            F.InterpolationMode.BILINEAR,
+            F.InterpolationMode.NEAREST,
         ],
     ):
         height, width = image.shape[-2:]
@@ -200,20 +200,20 @@ def resize_bounding_box():


 class TestKernelsCommon:
-    @pytest.mark.parametrize("kernel_info", KERNEL_INFOS, ids=lambda kernel_info: kernel_info.name)
-    def test_scriptable(self, kernel_info):
-        jit.script(kernel_info.kernel)
+    @pytest.mark.parametrize("functional_info", FUNCTIONAL_INFOS, ids=lambda functional_info: functional_info.name)
+    def test_scriptable(self, functional_info):
+        jit.script(functional_info.functional)

     @pytest.mark.parametrize(
-        ("kernel_info", "sample_input"),
+        ("functional_info", "sample_input"),
         [
-            pytest.param(kernel_info, sample_input, id=f"{kernel_info.name}-{idx}")
-            for kernel_info in KERNEL_INFOS
-            for idx, sample_input in enumerate(kernel_info.sample_inputs())
+            pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}")
+            for functional_info in FUNCTIONAL_INFOS
+            for idx, sample_input in enumerate(functional_info.sample_inputs())
         ],
     )
-    def test_eager_vs_scripted(self, kernel_info, sample_input):
-        eager = kernel_info(sample_input)
-        scripted = jit.script(kernel_info.kernel)(*sample_input.args, **sample_input.kwargs)
+    def test_eager_vs_scripted(self, functional_info, sample_input):
+        eager = functional_info(sample_input)
+        scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs)

         torch.testing.assert_close(eager, scripted)
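With the registry above, covering a new functional in these common tests only takes a sample-inputs generator whose name matches the functional exactly; test_scriptable and test_eager_vs_scripted then pick it up automatically. A hypothetical registration for illustration (vertical_flip_image_tensor is an assumed name, not part of this PR):

    @register_kernel_info_from_sample_inputs_fn
    def vertical_flip_image_tensor():  # __name__ must resolve via getattr(F, ...), assumed to exist
        for image in make_images():
            yield SampleInput(image)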
2 changes: 1 addition & 1 deletion torchvision/prototype/features/_bounding_box.py
@@ -41,7 +41,7 @@ def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
         # promote this out of the prototype state

         # import at runtime to avoid cyclic imports
-        from torchvision.prototype.transforms.kernels import convert_bounding_box_format
+        from torchvision.prototype.transforms.functional import convert_bounding_box_format

         if isinstance(format, str):
             format = BoundingBoxFormat[format]
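BoundingBox.to_format is the thin convenience wrapper over this functional. A hedged usage sketch; the keyword names old_format/new_format are assumed from the ConvertBoundingBoxFormat transform shown earlier, not confirmed by this diff:

    import torch
    from torchvision.prototype.features import BoundingBoxFormat
    from torchvision.prototype.transforms.functional import convert_bounding_box_format

    # one box in xywh layout, converted to xyxy (signature assumed)
    xywh = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
    xyxy = convert_bounding_box_format(xywh, old_format=BoundingBoxFormat.XYWH, new_format=BoundingBoxFormat.XYXY)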
2 changes: 1 addition & 1 deletion torchvision/prototype/features/_encoded.py
@@ -43,7 +43,7 @@ def decode(self) -> Image:
         # promote this out of the prototype state

         # import at runtime to avoid cyclic imports
-        from torchvision.prototype.transforms.kernels import decode_image_with_pil
+        from torchvision.prototype.transforms.functional import decode_image_with_pil

         return Image(decode_image_with_pil(self))

7 changes: 4 additions & 3 deletions torchvision/prototype/transforms/__init__.py
@@ -1,13 +1,14 @@
-from torchvision.transforms import AutoAugmentPolicy, InterpolationMode  # usort: skip
-from . import kernels  # usort: skip
+from torchvision.transforms import InterpolationMode, AutoAugmentPolicy  # usort: skip
+
+from . import functional  # usort: skip

 from ._transform import Transform  # usort: skip

 from ._augment import RandomErasing, RandomMixup, RandomCutmix
 from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment
 from ._container import Compose, RandomApply, RandomChoice, RandomOrder
 from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
-from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertColorSpace
+from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
 from ._misc import Identity, Normalize, ToDtype, Lambda
 from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
 from ._type_conversion import DecodeImage, LabelToOneHot
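After this rename the low-level helpers live under torchvision.prototype.transforms.functional, with per-type names such as horizontal_flip_image_tensor (see the tests above). A quick sketch of the new import path, noting this is prototype API and subject to change:

    import torch
    from torchvision.prototype.transforms import functional as F

    image = torch.rand(3, 32, 32)
    flipped = F.horizontal_flip_image_tensor(image)

    # per the final commit, get_image_size returns (width, height)
    width, height = F.get_image_size(image)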
53 changes: 39 additions & 14 deletions torchvision/prototype/transforms/_augment.py
@@ -3,7 +3,6 @@
 import warnings
 from typing import Any, Dict, Tuple

-import PIL.Image
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F
@@ -12,9 +11,6 @@


 class RandomErasing(Transform):
-    _DISPATCHER = F.erase
-    _FAIL_TYPES = {PIL.Image.Image, features.BoundingBox, features.SegmentationMask}
-
     def __init__(
         self,
         p: float = 0.5,
@@ -45,8 +41,8 @@ def __init__(

     def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
-        img_h, img_w = F.get_image_size(image)
         img_c = F.get_image_num_channels(image)
+        img_w, img_h = F.get_image_size(image)

         if isinstance(self.value, (int, float)):
             value = [self.value]
@@ -93,16 +89,24 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(zip("ijhwv", (i, j, h, w, v)))

     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if torch.rand(1) >= self.p:
+        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
+            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
+        elif isinstance(input, features.Image):
+            output = F.erase_image_tensor(input, **params)
+            return features.Image.new_like(input, output)
+        elif isinstance(input, torch.Tensor):
+            return F.erase_image_tensor(input, **params)
+        else:
             return input

-        return super()._transform(input, params)
+    def forward(self, *inputs: Any) -> Any:
+        if torch.rand(1) >= self.p:
+            return inputs if len(inputs) > 1 else inputs[0]
+
+        return super().forward(*inputs)


 class RandomMixup(Transform):
-    _DISPATCHER = F.mixup
-    _FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
-
     def __init__(self, *, alpha: float) -> None:
         super().__init__()
         self.alpha = alpha
@@ -111,11 +115,20 @@ def __init__(self, *, alpha: float) -> None:
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(lam=float(self._dist.sample(())))

+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
+            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
+        elif isinstance(input, features.Image):
+            output = F.mixup_image_tensor(input, **params)
+            return features.Image.new_like(input, output)
+        elif isinstance(input, features.OneHotLabel):
+            output = F.mixup_one_hot_label(input, **params)
+            return features.OneHotLabel.new_like(input, output)
+        else:
+            return input

-class RandomCutmix(Transform):
-    _DISPATCHER = F.cutmix
-    _FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}

+class RandomCutmix(Transform):
     def __init__(self, *, alpha: float) -> None:
         super().__init__()
         self.alpha = alpha
@@ -125,7 +138,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         lam = float(self._dist.sample(()))

         image = query_image(sample)
-        H, W = F.get_image_size(image)
+        W, H = F.get_image_size(image)

         r_x = torch.randint(W, ())
         r_y = torch.randint(H, ())
@@ -143,3 +156,15 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

         return dict(box=box, lam_adjusted=lam_adjusted)
+
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
+            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
+        elif isinstance(input, features.Image):
+            output = F.cutmix_image_tensor(input, box=params["box"])
+            return features.Image.new_like(input, output)
+        elif isinstance(input, features.OneHotLabel):
+            output = F.cutmix_one_hot_label(input, lam_adjusted=params["lam_adjusted"])
+            return features.OneHotLabel.new_like(input, output)
+        else:
+            return input
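With _DISPATCHER and _FAIL_TYPES gone, each transform now spells out its type handling directly in _transform: raise for explicitly unsupported feature types, re-wrap feature outputs with new_like around the matching *_image_tensor functional, and pass everything else through untouched. A sketch of the pattern for a hypothetical transform (invert_image_tensor is an assumed functional, not introduced by this PR):

    from typing import Any, Dict

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import Transform, functional as F


    class RandomInvert(Transform):
        # hypothetical transform, shown only to illustrate the new dispatch pattern
        def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
            if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
                raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
            elif isinstance(input, features.Image):
                output = F.invert_image_tensor(input)  # assumed name
                return features.Image.new_like(input, output)
            elif isinstance(input, torch.Tensor):
                return F.invert_image_tensor(input)
            else:
                return input  # pass-through signals "unsupported but tolerated"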