add automatic feature type dispatch to functional transforms #5323
Changes from 24 commits
Diff — prototype `Feature` base class: add a `__torch_function__` override.

@@ -1,7 +1,7 @@
-from typing import Any, cast, Dict, Set, TypeVar, Union, Optional, Type, Callable
+from typing import Any, cast, Dict, Set, TypeVar, Union, Optional, Type, Callable, Tuple, Sequence, Mapping

 import torch
-from torch._C import _TensorBase
+from torch._C import _TensorBase, DisableTorchFunction


 F = TypeVar("F", bound="Feature")

@@ -76,5 +76,45 @@ def new_like(
         _metadata.update(metadata)
         return cls(data, dtype=dtype or other.dtype, device=device or other.device, **_metadata)

+    @classmethod
+    def __torch_function__(
+        cls,
+        func: Callable[..., torch.Tensor],
+        types: Tuple[Type[torch.Tensor], ...],
+        args: Sequence[Any] = (),
+        kwargs: Optional[Mapping[str, Any]] = None,
+    ) -> torch.Tensor:
+        """For general information about how the __torch_function__ protocol works,
+        see https://pytorch.org/docs/stable/notes/extending.html#extending-torch
+
+        TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the
+        ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the
+        ``args`` and ``kwargs`` of the original call.
+
+        The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Feature`
+        use case, this has two downsides:
+
+        1. Since some :class:`Feature`'s require metadata to be constructed, the default wrapping, i.e.
+           ``return cls(func(*args, **kwargs))``, will fail for them.
+        2. For most operations, there is no way of knowing if the input type is still valid for the output.
+
+        For these reasons, the automatic output wrapping is turned off for most operators.
+
+        Exceptions to this are:
+
+        - :func:`torch.clone`
+        - :meth:`torch.Tensor.to`
+        """
+        kwargs = kwargs or dict()
+        with DisableTorchFunction():
+            output = func(*args, **kwargs)
+
+        if func is torch.Tensor.clone:
+            return cls.new_like(args[0], output)
+        elif func is torch.Tensor.to:
+            return cls.new_like(args[0], output, dtype=output.dtype, device=output.device)
+        else:
+            return output
+
     def __repr__(self) -> str:
         return cast(str, torch.Tensor.__repr__(self)).replace("tensor", type(self).__name__)

Review comments on this hunk (truncated in the page capture):
- "Not the right place to put the comment but Github won't let me comment on the right spot. I think …"
- "Although this does not need be supported in the first version, …"
- "It should be private for now."
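To make the intended behaviour of the override above concrete, here is a hedged usage sketch. It assumes the prototype `features.Image` type can be constructed directly from a tensor; its constructor is not part of this diff.

```python
import torch
from torchvision.prototype import features

# Assumption: the prototype Image feature can be built straight from a tensor.
image = features.Image(torch.rand(3, 32, 32))

# clone() and to() are the two whitelisted operators, so the output is re-wrapped
# into the input's feature type via new_like().
print(type(image.clone()))            # expected: features.Image
print(type(image.to(torch.float64)))  # expected: features.Image

# For every other operator the automatic wrapping is turned off, so the result
# is expected to come back as a plain torch.Tensor (see the docstring above).
print(type(image.flip(-1)))
```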
Diff — prototype `transforms` package `__init__`: expose the new `kernels` namespace.

@@ -1,4 +1,3 @@
-from . import functional
-from .functional import InterpolationMode  # usort: skip
-
+from . import kernels  # usort: skip
+from . import functional  # usort: skip
 from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
Diff — `transforms.functional` package `__init__`: export the new dispatchers instead of the per-type kernels.

@@ -1,27 +1,14 @@
-from ._augment import erase_image, mixup_image, mixup_one_hot_label, cutmix_image, cutmix_one_hot_label
+from ._augment import erase, mixup, cutmix
 from ._color import (
-    adjust_brightness_image,
-    adjust_contrast_image,
-    adjust_saturation_image,
-    adjust_sharpness_image,
-    posterize_image,
-    solarize_image,
-    autocontrast_image,
-    equalize_image,
-    invert_image,
+    adjust_brightness,
+    adjust_contrast,
+    adjust_saturation,
+    adjust_sharpness,
+    posterize,
+    solarize,
+    autocontrast,
+    equalize,
+    invert,
 )
-from ._geometry import (
-    horizontal_flip_bounding_box,
-    horizontal_flip_image,
-    resize_bounding_box,
-    resize_image,
-    resize_segmentation_mask,
-    center_crop_image,
-    resized_crop_image,
-    InterpolationMode,
-    affine_image,
-    rotate_image,
-)
-from ._meta_conversion import convert_color_space, convert_bounding_box_format
-from ._misc import normalize_image
-from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
+from ._geometry import horizontal_flip, resize, center_crop, resized_crop, affine, rotate
+from ._misc import normalize
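The upshot of these re-exports is a split into two namespaces: `kernels` keeps the per-type functions (e.g. `horizontal_flip_image`), while `functional` now exposes the type-dispatched entry points (e.g. `horizontal_flip`). A hedged usage sketch, assuming the geometry kernel keeps its old name under `kernels`, that `features.Image` can wrap a tensor directly, and that neither function needs extra parameters:

```python
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, kernels as K

tensor_image = torch.rand(3, 32, 32)
image = features.Image(tensor_image)  # assumption: direct construction from a tensor

# Kernels work on plain tensors of a known kind ...
flipped_tensor = K.horizontal_flip_image(tensor_image)

# ... while the dispatchers accept feature types and route to the matching kernel.
flipped_image = F.horizontal_flip(image)
```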
Diff — `functional._augment`: replace the per-type kernels with dispatchers.

@@ -1,52 +1,68 @@
-from typing import Tuple
+from typing import TypeVar

 import torch
-from torchvision.transforms import functional as _F
-
-erase_image = _F.erase
-
-
-def _mixup(input: torch.Tensor, batch_dim: int, lam: float, inplace: bool) -> torch.Tensor:
-    if not inplace:
-        input = input.clone()
-
-    input_rolled = input.roll(1, batch_dim)
-    return input.mul_(lam).add_(input_rolled.mul_(1 - lam))
-
-
-def mixup_image(image_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor:
-    if image_batch.ndim < 4:
-        raise ValueError("Need a batch of images")
-
-    return _mixup(image_batch, -4, lam, inplace)
-
-
-def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor:
-    if one_hot_label_batch.ndim < 2:
-        raise ValueError("Need a batch of one hot labels")
-
-    return _mixup(one_hot_label_batch, -2, lam, inplace)
-
-
-def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int], inplace: bool = False) -> torch.Tensor:
-    if image_batch.ndim < 4:
-        raise ValueError("Need a batch of images")
-
-    if not inplace:
-        image_batch = image_batch.clone()
-
-    x1, y1, x2, y2 = box
-    image_rolled = image_batch.roll(1, -4)
-
-    image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
-    return image_batch
-
-
-def cutmix_one_hot_label(
-    one_hot_label_batch: torch.Tensor, *, lam_adjusted: float, inplace: bool = False
-) -> torch.Tensor:
-    if one_hot_label_batch.ndim < 2:
-        raise ValueError("Need a batch of one hot labels")
-
-    return _mixup(one_hot_label_batch, -2, lam_adjusted, inplace)
+from torchvision.prototype import features
+from torchvision.prototype.transforms import kernels as K
+
+from ._utils import dispatch
+
+T = TypeVar("T", bound=features.Feature)
+
+
+@dispatch(
+    {
+        features.Image: K.erase_image,
+    },
+)
+def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool) -> T:
+    """ADDME"""
+    ...
+
+
+@dispatch(
+    {
+        features.Image: K.mixup_image,
+        features.OneHotLabel: K.mixup_one_hot_label,
+    },
+)
+def mixup(input: T, *, lam: float, inplace: bool) -> T:
+    """ADDME"""
+    ...
+
+
+@dispatch(
+    {
+        features.Image: K.cutmix_image,
+        features.OneHotLabel: K.cutmix_one_hot_label,
+    },
+)
+def cutmix(input: T, *, box: Tuple[int, int, int, int], lam_adjusted: float, inplace: bool) -> T:
+    """Perform the CutMix operation as introduced in the paper
+    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" <https://arxiv.org/abs/1905.04899>`_.
+
+    Dispatch to the corresponding kernels happens according to this table:
+
+    .. table::
+        :widths: 30 70
+
+        ==================================================== ================================================================
+        :class:`~torchvision.prototype.features.Image`       :func:`~torch.prototype.transforms.kernels.cutmix_image`
+        :class:`~torchvision.prototype.features.OneHotLabel` :func:`~torch.prototype.transforms.kernels.cutmix_one_hot_label`
+        ==================================================== ================================================================
+
+    Please refer to the kernel documentations for a detailed explanation of the functionality and parameters.
+
+    .. note::
+
+        The ``box`` parameter is only required for inputs of type
+
+        - :class:`~torchvision.prototype.features.Image`
+
+    .. note::
+
+        The ``lam_adjusted`` parameter is only required for inputs of type
+
+        - :class:`~torchvision.prototype.features.OneHotLabel`
+    """
+    ...
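The `dispatch` helper itself lives in `._utils` and is not shown in this diff. Purely as a mental model, a decorator of this kind could be sketched as below; the real implementation presumably also re-wraps the kernel output into the input's feature type and handles the per-type parameters called out in the `cutmix` docstring (``box`` only for `Image`, ``lam_adjusted`` only for `OneHotLabel`), which this sketch omits.

```python
# Hypothetical sketch only — not the torchvision implementation of `dispatch`.
import functools
from typing import Any, Callable, Dict, Type, TypeVar

T = TypeVar("T")


def dispatch(kernels: Dict[Type[Any], Callable[..., Any]]) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """Route a call to the kernel registered for the input's feature type."""

    def decorator(dispatcher: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(dispatcher)
        def wrapper(input: T, **kwargs: Any) -> T:
            for feature_type, kernel in kernels.items():
                if isinstance(input, feature_type):
                    # A real implementation would likely also re-wrap the plain
                    # tensor result into the input's feature type, e.g. via
                    # Feature.new_like, and drop kwargs the kernel does not take.
                    return kernel(input, **kwargs)
            raise TypeError(f"{dispatcher.__name__}() is not supported for inputs of type {type(input).__name__}")

        return wrapper

    return decorator
```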
Diff — `functional._color`: replace the per-type kernels with dispatchers.

@@ -1,20 +1,98 @@
-from torchvision.transforms import functional as _F
-
-adjust_brightness_image = _F.adjust_brightness
-
-adjust_saturation_image = _F.adjust_saturation
-
-adjust_contrast_image = _F.adjust_contrast
-
-adjust_sharpness_image = _F.adjust_sharpness
-
-posterize_image = _F.posterize
-
-solarize_image = _F.solarize
-
-autocontrast_image = _F.autocontrast
-
-equalize_image = _F.equalize
-
-invert_image = _F.invert
+from typing import TypeVar
+
+from torchvision.prototype import features
+from torchvision.prototype.transforms import kernels as K
+
+from ._utils import dispatch
+
+T = TypeVar("T", bound=features.Feature)
+
+
+@dispatch(
+    {
+        features.Image: K.adjust_brightness_image,
+    }
+)
+def adjust_brightness(input: T, *, brightness_factor: float) -> T:
+    """ADDME"""
+    ...
+
+
+@dispatch(
+    {
+        features.Image: K.adjust_saturation_image,
+    }
+)
+def adjust_saturation(input: T, *, saturation_factor: float) -> T:
+    """ADDME"""
+    ...
+
+
+@dispatch(
+    {
+        features.Image: K.adjust_contrast_image,
+    }
+)
+def adjust_contrast(input: T, *, contrast_factor: float) -> T:
+    """ADDME"""
+    ...
+
+
+@dispatch(
+    {
+        features.Image: K.adjust_sharpness_image,
+    }
+)
+def adjust_sharpness(input: T, *, sharpness_factor: float) -> T:
+    """ADDME"""
+    ...
+
+
+@dispatch(
+    {
+        features.Image: K.posterize_image,
+    }
+)
+def posterize(input: T, *, bits: int) -> T:
+    """ADDME"""
+    ...
+
+
+@dispatch(
+    {
+        features.Image: K.solarize_image,
+    }
+)
+def solarize(input: T, *, threshold: float) -> T:
+    """ADDME"""
+    ...
+
+
+@dispatch(
+    {
+        features.Image: K.autocontrast_image,
+    }
+)
+def autocontrast(input: T) -> T:
+    """ADDME"""
+    ...
+
+
+@dispatch(
+    {
+        features.Image: K.equalize_image,
+    }
+)
+def equalize(input: T) -> T:
+    """ADDME"""
+    ...
+
+
+@dispatch(
+    {
+        features.Image: K.invert_image,
+    }
+)
+def invert(input: T) -> T:
+    """ADDME"""
+    ...
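To round things off, a hedged example of calling one of these color dispatchers; constructing `features.Image` directly from a tensor is an assumption, since its constructor is not part of this diff.

```python
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

# Assumption: an Image feature can be built straight from a uint8 tensor.
img = features.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))

# The input is a features.Image, so this call is routed to kernels.posterize_image.
posterized = F.posterize(img, bits=4)
```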