Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
3f6982e  add auto dispatch (Jan 25, 2022)
587687e  fix missing arguments error message (Jan 31, 2022)
3a4e53d  remove pil kernel for erase (Feb 1, 2022)
35845b5  automate feature specific parameter detection (Feb 1, 2022)
7778782  fix typos (Feb 1, 2022)
019a0b6  cleanup dispatcher call (Feb 3, 2022)
4cb2350  remove __torch_function__ from transform dispatch (Feb 4, 2022)
1d9a827  Merge branch 'revamp-prototype-features-transforms' into transforms/d… (Feb 7, 2022)
158a216  remove auto-generation (Feb 7, 2022)
3ceb056  revert unrelated changes (Feb 7, 2022)
b3cbfca  remove implements decorator (Feb 9, 2022)
2a8345a  change register parameter order (Feb 9, 2022)
9518cfb  change order of transforms for readability (Feb 9, 2022)
05c0aef  Merge branch 'revamp-prototype-features-transforms' (Feb 9, 2022)
a035286  add documentation for __torch_function__ (Feb 9, 2022)
772d651  Merge branch 'revamp-prototype-features-transforms' into transforms/d… (Feb 9, 2022)
2d10741  fix mypy (Feb 9, 2022)
f3d6522  inline check for support (Feb 9, 2022)
b8cda56  refactor kernel registering process (Feb 9, 2022)
9a45eb0  refactor dispatch to be a regular decorator (Feb 9, 2022)
71af6f8  split kernels and dispatchers (Feb 9, 2022)
4c13812  remove sentinels (Feb 9, 2022)
f5df194  replace pass with ... (Feb 9, 2022)
9014e20  appease mypy (Feb 9, 2022)
0238184  make single kernel dispatchers more concise (Feb 10, 2022)
22f4d29  make dispatcher signatures more generic (Feb 10, 2022)
1cd2166  make kernel checking more strict (Feb 10, 2022)
cca5040  revert doc changes (Feb 10, 2022)
020dcfb  Merge branch 'revamp-prototype-features-transforms' (Feb 10, 2022)
4216d91  address Franciscos comments (Feb 10, 2022)
ecd1425  remove inplace (Feb 10, 2022)
8771f40  rename kernel test module (Feb 10, 2022)
0de4ba7  fix inplace (Feb 10, 2022)
6ef6bf1  remove special casing for pil and vanilla tensors (Feb 10, 2022)
886552c  address comments (Feb 10, 2022)
c7785b0  update docs (Feb 10, 2022)
4 changes: 4 additions & 0 deletions docs/source/transforms.rst

@@ -3,6 +3,10 @@
 Transforming and augmenting images
 ==================================

+.. currentmodule:: torchvision.prototype.transforms.functional
+
+.. autofunction:: cutmix
+
 .. currentmodule:: torchvision.transforms

 Transforms are common image transformations available in the
2 changes: 1 addition & 1 deletion torchvision/prototype/features/_bounding_box.py

@@ -40,7 +40,7 @@ def __new__(

     def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
         # import at runtime to avoid cyclic imports
-        from torchvision.prototype.transforms.functional import convert_bounding_box_format
+        from torchvision.prototype.transforms.kernels import convert_bounding_box_format

         if isinstance(format, str):
             format = BoundingBoxFormat[format]
2 changes: 1 addition & 1 deletion torchvision/prototype/features/_encoded.py

@@ -40,7 +40,7 @@ def image_size(self) -> Tuple[int, int]:

     def decode(self) -> Image:
         # import at runtime to avoid cyclic imports
-        from torchvision.prototype.transforms.functional import decode_image_with_pil
+        from torchvision.prototype.transforms.kernels import decode_image_with_pil

         return Image(decode_image_with_pil(self))
44 changes: 42 additions & 2 deletions torchvision/prototype/features/_feature.py

@@ -1,7 +1,7 @@
-from typing import Any, cast, Dict, Set, TypeVar, Union, Optional, Type, Callable
+from typing import Any, cast, Dict, Set, TypeVar, Union, Optional, Type, Callable, Tuple, Sequence, Mapping

 import torch
-from torch._C import _TensorBase
+from torch._C import _TensorBase, DisableTorchFunction


 F = TypeVar("F", bound="Feature")

@@ -76,5 +76,45 @@ def new_like(
         _metadata.update(metadata)
         return cls(data, dtype=dtype or other.dtype, device=device or other.device, **_metadata)

+    @classmethod
+    def __torch_function__(
+        cls,
+        func: Callable[..., torch.Tensor],
+        types: Tuple[Type[torch.Tensor], ...],
+        args: Sequence[Any] = (),
+        kwargs: Optional[Mapping[str, Any]] = None,
+    ) -> torch.Tensor:
+        """For general information about how the __torch_function__ protocol works,
+        see https://pytorch.org/docs/stable/notes/extending.html#extending-torch
+
+        TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the
+        ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
+        ``args`` and ``kwargs`` of the original call.
+
+        The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Feature`
+        use case, this has two downsides:
+
+        1. Since some :class:`Feature`'s require metadata to be constructed, the default wrapping, i.e.
+           ``return cls(func(*args, **kwargs))``, will fail for them.
+        2. For most operations, there is no way of knowing if the input type is still valid for the output.
+
+        For these reasons, the automatic output wrapping is turned off for most operators.
+
+        Exceptions to this are:
+
+        - :func:`torch.clone`
+        - :meth:`torch.Tensor.to`
+        """
+        kwargs = kwargs or dict()
+        with DisableTorchFunction():
+            output = func(*args, **kwargs)
+
+        if func is torch.Tensor.clone:
+            return cls.new_like(args[0], output)
+        elif func is torch.Tensor.to:
+            return cls.new_like(args[0], output, dtype=output.dtype, device=output.device)
+        else:
+            return output
Review discussion on this file:

Contributor: Not the right place to put the comment, but GitHub won't let me comment on the right spot. I think Feature is exposed publicly in the __init__.py file of the area. Given that it is an internal class (unlike Image, BoundingBox, etc.), I think it is worth keeping it private.

Author (pmeier): Although this does not need to be supported in the first version, Feature should not be an internal class. We want users to be able to create their own custom features if it is useful for their use case.

Contributor: It should be private for now.

     def __repr__(self) -> str:
         return cast(str, torch.Tensor.__repr__(self)).replace("tensor", type(self).__name__)
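
To make the documented behavior concrete, here is a small usage sketch. This is an illustration assuming the prototype features API from this PR; the expected types in the comments follow the docstring above.

    import torch
    from torchvision.prototype import features

    image = features.Image(torch.rand(3, 32, 32))

    # Most operators unwrap to a plain tensor, since the Image metadata
    # may no longer be valid for the output.
    flipped = image.flip(-1)
    type(flipped)  # torch.Tensor

    # clone() and to() are the two documented exceptions: their outputs are
    # re-wrapped via new_like() and stay Images.
    type(image.clone())            # features.Image
    type(image.to(torch.float64))  # features.Image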
5 changes: 2 additions & 3 deletions torchvision/prototype/transforms/__init__.py

@@ -1,4 +1,3 @@
-from . import functional
-from .functional import InterpolationMode # usort: skip
-
+from . import kernels # usort: skip
+from . import functional # usort: skip
 from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
37 changes: 12 additions & 25 deletions torchvision/prototype/transforms/functional/__init__.py

@@ -1,27 +1,14 @@
-from ._augment import erase_image, mixup_image, mixup_one_hot_label, cutmix_image, cutmix_one_hot_label
+from ._augment import erase, mixup, cutmix
 from ._color import (
-    adjust_brightness_image,
-    adjust_contrast_image,
-    adjust_saturation_image,
-    adjust_sharpness_image,
-    posterize_image,
-    solarize_image,
-    autocontrast_image,
-    equalize_image,
-    invert_image,
+    adjust_brightness,
+    adjust_contrast,
+    adjust_saturation,
+    adjust_sharpness,
+    posterize,
+    solarize,
+    autocontrast,
+    equalize,
+    invert,
 )
-from ._geometry import (
-    horizontal_flip_bounding_box,
-    horizontal_flip_image,
-    resize_bounding_box,
-    resize_image,
-    resize_segmentation_mask,
-    center_crop_image,
-    resized_crop_image,
-    InterpolationMode,
-    affine_image,
-    rotate_image,
-)
-from ._meta_conversion import convert_color_space, convert_bounding_box_format
-from ._misc import normalize_image
-from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
+from ._geometry import horizontal_flip, resize, center_crop, resized_crop, affine, rotate
+from ._misc import normalize
78 changes: 47 additions & 31 deletions torchvision/prototype/transforms/functional/_augment.py

@@ -1,52 +1,68 @@
-from typing import Tuple
+from typing import TypeVar

 import torch
-from torchvision.transforms import functional as _F
+from torchvision.prototype import features
+from torchvision.prototype.transforms import kernels as K

-erase_image = _F.erase
+from ._utils import dispatch

+T = TypeVar("T", bound=features.Feature)

-def _mixup(input: torch.Tensor, batch_dim: int, lam: float, inplace: bool) -> torch.Tensor:
-    if not inplace:
-        input = input.clone()
-
-    input_rolled = input.roll(1, batch_dim)
-    return input.mul_(lam).add_(input_rolled.mul_(1 - lam))
+@dispatch(
+    {
+        features.Image: K.erase_image,
+    },
+)
+def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool) -> T:
+    """ADDME"""
+    ...

-def mixup_image(image_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor:
-    if image_batch.ndim < 4:
-        raise ValueError("Need a batch of images")
-
-    return _mixup(image_batch, -4, lam, inplace)
+@dispatch(
+    {
+        features.Image: K.mixup_image,
+        features.OneHotLabel: K.mixup_one_hot_label,
+    },
+)
+def mixup(input: T, *, lam: float, inplace: bool) -> T:
+    """ADDME"""
+    ...

-def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor:
-    if one_hot_label_batch.ndim < 2:
-        raise ValueError("Need a batch of one hot labels")
-
-    return _mixup(one_hot_label_batch, -2, lam, inplace)
-
-
-def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int], inplace: bool = False) -> torch.Tensor:
-    if image_batch.ndim < 4:
-        raise ValueError("Need a batch of images")
-
-    if not inplace:
-        image_batch = image_batch.clone()
-
-    x1, y1, x2, y2 = box
-    image_rolled = image_batch.roll(1, -4)
-
-    image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
-    return image_batch
-
-
-def cutmix_one_hot_label(
-    one_hot_label_batch: torch.Tensor, *, lam_adjusted: float, inplace: bool = False
-) -> torch.Tensor:
-    if one_hot_label_batch.ndim < 2:
-        raise ValueError("Need a batch of one hot labels")
-
-    return _mixup(one_hot_label_batch, -2, lam_adjusted, inplace)
+@dispatch(
+    {
+        features.Image: K.cutmix_image,
+        features.OneHotLabel: K.cutmix_one_hot_label,
+    },
+)
+def cutmix(input: T, *, box: Tuple[int, int, int, int], lam_adjusted: float, inplace: bool) -> T:
+    """Perform the CutMix operation as introduced in the paper
+    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" <https://arxiv.org/abs/1905.04899>`_.
+
+    Dispatch to the corresponding kernels happens according to this table:
+
+    .. table::
+       :widths: 30 70
+
+       ====================================================  ================================================================
+       :class:`~torchvision.prototype.features.Image`        :func:`~torch.prototype.transforms.kernels.cutmix_image`
+       :class:`~torchvision.prototype.features.OneHotLabel`  :func:`~torch.prototype.transforms.kernels.cutmix_one_hot_label`
+       ====================================================  ================================================================
+
+    Please refer to the kernel documentations for a detailed explanation of the functionality and parameters.
+
+    .. note::
+
+        The ``box`` parameter is only required for inputs of type
+
+        - :class:`~torchvision.prototype.features.Image`
+
+    .. note::
+
+        The ``lam_adjusted`` parameter is only required for inputs of type
+
+        - :class:`~torchvision.prototype.features.OneHotLabel`
+    """
+    ...
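
The dispatch decorator itself comes from ._utils, which is not part of this diff. As a rough sketch of the mechanism the dispatchers above rely on, here is a hypothetical reimplementation, not the PR's actual code; it omits the feature-specific parameter filtering implied by the box / lam_adjusted notes in the cutmix docstring.

    import functools
    from typing import Any, Callable, Dict, Type, TypeVar

    import torch
    from torchvision.prototype import features

    T = TypeVar("T", bound=features.Feature)


    def dispatch(kernels: Dict[Type[features.Feature], Callable[..., torch.Tensor]]) -> Callable:
        """Route calls on the decorated stub to the kernel registered for the input's type."""

        def decorator(stub: Callable[..., T]) -> Callable[..., T]:
            @functools.wraps(stub)
            def wrapper(input: T, **kwargs: Any) -> T:
                kernel = kernels.get(type(input))
                if kernel is None:
                    raise TypeError(f"{stub.__name__}() does not support inputs of type {type(input)}")
                # Kernels operate on plain tensors; re-wrap their output in the
                # input's feature type, carrying over its metadata (cf. Feature.new_like).
                output = kernel(input, **kwargs)
                return type(input).new_like(input, output)

            return wrapper

        return decorator

With this shape, a call like erase(image, i=0, j=0, h=8, w=8, v=torch.zeros(()), inplace=False) would look up K.erase_image for a features.Image input and re-wrap the resulting tensor as an Image.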
98 changes: 88 additions & 10 deletions torchvision/prototype/transforms/functional/_color.py

@@ -1,20 +1,98 @@
-from torchvision.transforms import functional as _F
+from typing import TypeVar

+from torchvision.prototype import features
+from torchvision.prototype.transforms import kernels as K

-adjust_brightness_image = _F.adjust_brightness
+from ._utils import dispatch

-adjust_saturation_image = _F.adjust_saturation
+T = TypeVar("T", bound=features.Feature)

-adjust_contrast_image = _F.adjust_contrast
-
-adjust_sharpness_image = _F.adjust_sharpness
+
+@dispatch(
+    {
+        features.Image: K.adjust_brightness_image,
+    }
+)
+def adjust_brightness(input: T, *, brightness_factor: float) -> T:
+    """ADDME"""
+    ...

-posterize_image = _F.posterize
-
-solarize_image = _F.solarize
+
+@dispatch(
+    {
+        features.Image: K.adjust_saturation_image,
+    }
+)
+def adjust_saturation(input: T, *, saturation_factor: float) -> T:
+    """ADDME"""
+    ...

-autocontrast_image = _F.autocontrast
-
-equalize_image = _F.equalize
+
+@dispatch(
+    {
+        features.Image: K.adjust_contrast_image,
+    }
+)
+def adjust_contrast(input: T, *, contrast_factor: float) -> T:
+    """ADDME"""
+    ...

-invert_image = _F.invert
+
+@dispatch(
+    {
+        features.Image: K.adjust_sharpness_image,
+    }
+)
+def adjust_sharpness(input: T, *, sharpness_factor: float) -> T:
+    """ADDME"""
+    ...
+
+
+@dispatch(
+    {
+        features.Image: K.posterize_image,
+    }
+)
+def posterize(input: T, *, bits: int) -> T:
+    """ADDME"""
+    ...
+
+
+@dispatch(
+    {
+        features.Image: K.solarize_image,
+    }
+)
+def solarize(input: T, *, threshold: float) -> T:
+    """ADDME"""
+    ...
+
+
+@dispatch(
+    {
+        features.Image: K.autocontrast_image,
+    }
+)
+def autocontrast(input: T) -> T:
+    """ADDME"""
+    ...
+
+
+@dispatch(
+    {
+        features.Image: K.equalize_image,
+    }
+)
+def equalize(input: T) -> T:
+    """ADDME"""
+    ...
+
+
+@dispatch(
+    {
+        features.Image: K.invert_image,
+    }
+)
+def invert(input: T) -> T:
+    """ADDME"""
+    ...
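
Together with the functional/__init__.py change above, the resulting layout is: torchvision.prototype.transforms.kernels holds the low-level, type-specific functions, while torchvision.prototype.transforms.functional holds the dispatchers. A brief usage sketch under that assumption (the Image constructor call is illustrative):

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F, kernels as K

    image = features.Image(torch.rand(3, 32, 32))

    # Dispatcher: selects the kernel based on the feature type and re-wraps the output.
    brightened = F.adjust_brightness(image, brightness_factor=1.5)

    # Kernel: operates on the underlying image tensor directly, no dispatch involved.
    brightened_tensor = K.adjust_brightness_image(image, brightness_factor=1.5)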