readd functional transforms API to prototype #5295
@@ -0,0 +1,197 @@

import functools
import itertools

import pytest
import torch.testing
import torchvision.prototype.transforms.functional as F
from torch import jit
from torchvision.prototype import features

make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32):
    size = size or torch.randint(16, 33, (2,)).tolist()

    if isinstance(color_space, str):
        color_space = features.ColorSpace[color_space]
    num_channels = {
        features.ColorSpace.GRAYSCALE: 1,
        features.ColorSpace.RGB: 3,
    }[color_space]

    shape = (*extra_dims, num_channels, *size)
    if dtype.is_floating_point:
        data = torch.rand(shape, dtype=dtype)
    else:
        data = torch.randint(0, torch.iinfo(dtype).max, shape, dtype=dtype)
    return features.Image(data, color_space=color_space)


make_grayscale_image = functools.partial(make_image, color_space=features.ColorSpace.GRAYSCALE)
make_rgb_image = functools.partial(make_image, color_space=features.ColorSpace.RGB)
def make_images(
    sizes=((16, 16), (7, 33), (31, 9)),
    color_spaces=(features.ColorSpace.GRAYSCALE, features.ColorSpace.RGB),
    dtypes=(torch.float32, torch.uint8),
    extra_dims=((4,), (2, 3)),
):
    for size, color_space, dtype in itertools.product(sizes, color_spaces, dtypes):
        # pass the dtype through; otherwise the dtypes parameter has no effect
        yield make_image(size, color_space=color_space, dtype=dtype)

    for color_space, extra_dims_ in itertools.product(color_spaces, extra_dims):
        yield make_image(color_space=color_space, extra_dims=extra_dims_)
def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
    # like torch.randint, but the low/high bounds may be tensors that are broadcast together
    low, high = torch.broadcast_tensors(
        *[torch.as_tensor(arg) for arg in ((0, arg1) if arg2 is None else (arg1, arg2))]
    )
    return torch.stack(
        [
            torch.randint(low_scalar, high_scalar, (), **kwargs)
            for low_scalar, high_scalar in zip(low.flatten().tolist(), high.flatten().tolist())
        ]
    ).reshape(low.shape)
def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch.int64):
    if isinstance(format, str):
        format = features.BoundingBoxFormat[format]

    height, width = image_size

    if format == features.BoundingBoxFormat.XYXY:
        x1 = torch.randint(0, width // 2, extra_dims)
        y1 = torch.randint(0, height // 2, extra_dims)
        x2 = randint_with_tensor_bounds(x1 + 1, width - x1) + x1
        y2 = randint_with_tensor_bounds(y1 + 1, height - y1) + y1
        parts = (x1, y1, x2, y2)
    elif format == features.BoundingBoxFormat.XYWH:
        x = torch.randint(0, width // 2, extra_dims)
        y = torch.randint(0, height // 2, extra_dims)
        w = randint_with_tensor_bounds(1, width - x)
        h = randint_with_tensor_bounds(1, height - y)
        parts = (x, y, w, h)
    elif format == features.BoundingBoxFormat.CXCYWH:
        cx = torch.randint(1, width - 1, ())
        cy = torch.randint(1, height - 1, ())
        w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
        # bound the height by the vertical extent around cy, not the horizontal one
        h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
        parts = (cx, cy, w, h)
    else:  # format == features.BoundingBoxFormat._SENTINEL
        raise ValueError(f"Format {format} is not supported")

    return features.BoundingBox(torch.stack(parts, dim=-1).to(dtype), format=format, image_size=image_size)


make_xyxy_bounding_box = functools.partial(make_bounding_box, format=features.BoundingBoxFormat.XYXY)
def make_bounding_boxes(
    formats=(features.BoundingBoxFormat.XYXY, features.BoundingBoxFormat.XYWH, features.BoundingBoxFormat.CXCYWH),
    image_sizes=((32, 32),),
    dtypes=(torch.int64, torch.float32),
    extra_dims=((4,), (2, 3)),
):
    for format, image_size, dtype in itertools.product(formats, image_sizes, dtypes):
        yield make_bounding_box(format=format, image_size=image_size, dtype=dtype)

    for format, extra_dims_ in itertools.product(formats, extra_dims):
        yield make_bounding_box(format=format, extra_dims=extra_dims_)
class SampleInput:
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


class KernelInfo:
    def __init__(self, name, *, sample_inputs_fn):
        self.name = name
        self.kernel = getattr(F, name)
        self._sample_inputs_fn = sample_inputs_fn

    def sample_inputs(self):
        yield from self._sample_inputs_fn()

    def __call__(self, *args, **kwargs):
        # allow invoking the kernel directly with a single SampleInput
        if len(args) == 1 and not kwargs and isinstance(args[0], SampleInput):
            sample_input = args[0]
            return self.kernel(*sample_input.args, **sample_input.kwargs)

        return self.kernel(*args, **kwargs)


KERNEL_INFOS = []
def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn):
    KERNEL_INFOS.append(KernelInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn))
    return sample_inputs_fn


@register_kernel_info_from_sample_inputs_fn
def horizontal_flip_image():
    for image in make_images():
        yield SampleInput(image)


@register_kernel_info_from_sample_inputs_fn
def horizontal_flip_bounding_box():
    for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]):
        yield SampleInput(bounding_box, image_size=bounding_box.image_size)


@register_kernel_info_from_sample_inputs_fn
def resize_image():
    for image, interpolation in itertools.product(
        make_images(),
        [
            F.InterpolationMode.BILINEAR,
            F.InterpolationMode.NEAREST,
        ],
    ):
        height, width = image.shape[-2:]
        for size in [
            (height, width),
            (int(height * 0.75), int(width * 1.25)),
        ]:
            yield SampleInput(image, size=size, interpolation=interpolation)


@register_kernel_info_from_sample_inputs_fn
def resize_bounding_box():
    for bounding_box in make_bounding_boxes():
        height, width = bounding_box.image_size
        for new_image_size in [
            (height, width),
            (int(height * 0.75), int(width * 1.25)),
        ]:
            yield SampleInput(bounding_box, old_image_size=bounding_box.image_size, new_image_size=new_image_size)
class TestKernelsCommon:
    @pytest.mark.parametrize("kernel_info", KERNEL_INFOS, ids=lambda kernel_info: kernel_info.name)
    def test_scriptable(self, kernel_info):
        jit.script(kernel_info.kernel)

    @pytest.mark.parametrize(
        ("kernel_info", "sample_input"),
        [
            pytest.param(kernel_info, sample_input, id=f"{kernel_info.name}-{idx}")
            for kernel_info in KERNEL_INFOS
            for idx, sample_input in enumerate(kernel_info.sample_inputs())
        ],
    )
    def test_eager_vs_scripted(self, kernel_info, sample_input):
        eager = kernel_info(sample_input)
        scripted = jit.script(kernel_info.kernel)(*sample_input.args, **sample_input.kwargs)

        torch.testing.assert_close(eager, scripted)

Review comment on lines +189 to +190: Could this use two separate parametrizations instead of the 2 manual loops? Also, is there a stark difference between passing …

Reply: We can't use two separate parametrizations, since the inner loop depends on the outer. Let's look into this more after #5295 (comment).
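To make the reply concrete, here is a minimal sketch (illustrative only, not part of the PR) of why stacked parametrizations cannot express this: stacked @pytest.mark.parametrize decorators cross independent value lists, but the sample inputs here are generated per kernel, so the pairs have to be flattened into a single list up front, exactly as the comprehension above does.

# Illustrative: what two stacked parametrizations would require.
#
# @pytest.mark.parametrize("kernel_info", KERNEL_INFOS)
# @pytest.mark.parametrize("sample_input", ???)  # would need kernel_info, which isn't in scope
#
# Hence the single parametrization over pre-flattened pairs:
pairs = [
    (kernel_info, sample_input)
    for kernel_info in KERNEL_INFOS
    for sample_input in kernel_info.sample_inputs()
]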
@@ -36,3 +36,14 @@ def __new__(

        bounding_box._metadata.update(dict(format=format, image_size=image_size))

        return bounding_box

    def to_format(self, format: Union[str, BoundingBoxFormat]) -> "BoundingBox":
        # import at runtime to avoid cyclic imports
        from torchvision.prototype.transforms.functional import convert_bounding_box_format

        if isinstance(format, str):
            format = BoundingBoxFormat[format]

        return BoundingBox.new_like(
            self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
        )

Review comment on lines +41 to +42: Maybe this is a sign to redesign the structure and put this method here or in …

Reply: It was a design choice to have all transforming functions in torchvision.prototype.transforms.functional. If we relax the design choice that all transforming functions need to be present in …
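A minimal usage sketch of the new method (the values are hypothetical, and it assumes BoundingBox accepts array-like data as in the prototype API above):

from torchvision.prototype import features

box = features.BoundingBox([10, 10, 20, 30], format="XYXY", image_size=(32, 32))
box_xywh = box.to_format("XYWH")  # a str is looked up in BoundingBoxFormat
box_cxcywh = box.to_format(features.BoundingBoxFormat.CXCYWH)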
@@ -1,6 +1,8 @@

from . import functional
from .functional import InterpolationMode  # usort: skip

from ._transform import Transform
from ._container import Compose, RandomApply, RandomChoice, RandomOrder  # usort: skip

from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop
from ._misc import Identity, Normalize
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
@@ -0,0 +1,27 @@

from ._augment import erase_image, mixup_image, mixup_one_hot_label, cutmix_image, cutmix_one_hot_label
from ._color import (
    adjust_brightness_image,
    adjust_contrast_image,
    adjust_saturation_image,
    adjust_sharpness_image,
    posterize_image,
    solarize_image,
    autocontrast_image,
    equalize_image,
    invert_image,
)
from ._geometry import (
    horizontal_flip_bounding_box,
    horizontal_flip_image,
    resize_bounding_box,
    resize_image,
    resize_segmentation_mask,
    center_crop_image,
    resized_crop_image,
    InterpolationMode,
    affine_image,
    rotate_image,
)
from ._meta_conversion import convert_color_space, convert_bounding_box_format
from ._misc import normalize_image
from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
@@ -0,0 +1,40 @@

from typing import Tuple

import torch
from torchvision.transforms import functional as _F


erase_image = _F.erase


def _mixup(input: torch.Tensor, batch_dim: int, lam: float, inplace: bool) -> torch.Tensor:
    if not inplace:
        input = input.clone()

    # blend each sample with its neighbor along the batch dim: lam * x + (1 - lam) * roll(x)
    input_rolled = input.roll(1, batch_dim)
    return input.mul_(lam).add_(input_rolled.mul_(1 - lam))


def mixup_image(image_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor:
    return _mixup(image_batch, -4, lam, inplace)


def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor:
    return _mixup(one_hot_label_batch, -2, lam, inplace)


def cutmix_image(image: torch.Tensor, *, box: Tuple[int, int, int, int], inplace: bool = False) -> torch.Tensor:
    if not inplace:
        image = image.clone()

    # paste the box region from the batch-rolled images into the originals
    x1, y1, x2, y2 = box
    image_rolled = image.roll(1, -4)

    image[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
    return image


def cutmix_one_hot_label(
    one_hot_label_batch: torch.Tensor, *, lam_adjusted: float, inplace: bool = False
) -> torch.Tensor:
    return mixup_one_hot_label(one_hot_label_batch, lam=lam_adjusted, inplace=inplace)
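A short usage sketch of the mixup kernels above (the shapes and the lam value are assumptions for illustration, not from the PR):

import torch
from torchvision.prototype.transforms.functional import mixup_image, mixup_one_hot_label

image_batch = torch.rand(8, 3, 224, 224)          # (N, C, H, W); mixup rolls dim -4
one_hot = torch.eye(10)[torch.randint(10, (8,))]  # (N, num_classes); mixup rolls dim -2

lam = 0.7  # mixing coefficient, typically drawn from a Beta distribution
mixed_images = mixup_image(image_batch, lam=lam)
mixed_labels = mixup_one_hot_label(one_hot, lam=lam)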
@@ -0,0 +1,20 @@

from torchvision.transforms import functional as _F


adjust_brightness_image = _F.adjust_brightness

adjust_saturation_image = _F.adjust_saturation

adjust_contrast_image = _F.adjust_contrast

adjust_sharpness_image = _F.adjust_sharpness

posterize_image = _F.posterize

solarize_image = _F.solarize

autocontrast_image = _F.autocontrast

equalize_image = _F.equalize

invert_image = _F.invert
Review comment (on the SampleInput class in the test file): Could you explain why we need this class?

Reply: The other option would be for a simple inputs function to return ((...), dict(...)) to bundle args and kwargs. IMO this is not as convenient as having a single structure to hold everything. To me, SampleInput(*args, **kwargs) is more readable than ((...), dict(...)), given that it resembles the actual call signature.

Reply: Plus, I was looking into lazily loading the samples. This is not implemented yet, so we are generating all samples at test collection time. This can become an issue real quick if we go further along this automated-tests direction.
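For illustration, the two options discussed side by side, using the SampleInput class defined in the test file above (image and size are placeholder values, not from the PR):

import torch

image = torch.rand(3, 16, 16)  # placeholder input

# Option discussed in review: a plain tuple/dict bundle
bundled_args, bundled_kwargs = (image,), dict(size=(8, 8))

# Option used in the PR: a single structure that mirrors the call site
sample = SampleInput(image, size=(8, 8))

# Either way a kernel is invoked as kernel(*args, **kwargs);
# with SampleInput that is kernel(*sample.args, **sample.kwargs).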