From 40557eb6479773ca7e3e82199120a6f79c1b4ba9 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 10 Oct 2022 09:24:53 +0200
Subject: [PATCH 1/2] enable arbitrary batch size for all prototype kernels

---
 test/prototype_transforms_dispatcher_infos.py |  9 --
 test/prototype_transforms_kernel_infos.py     | 11 ---
 .../transforms/functional/_geometry.py        | 84 ++++++++-----------
 .../prototype/transforms/functional/_misc.py  | 38 ++++-----
 4 files changed, 51 insertions(+), 91 deletions(-)

diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py
index be8bd3002c1..de933c7e3fa 100644
--- a/test/prototype_transforms_dispatcher_infos.py
+++ b/test/prototype_transforms_dispatcher_infos.py
@@ -138,12 +138,6 @@ def xfail_all_tests(*, reason, condition):
     ]
 
 
-xfails_degenerate_or_multi_batch_dims = xfail_all_tests(
-    reason="See https://github.com/pytorch/vision/issues/6670 for details.",
-    condition=lambda args_kwargs: len(args_kwargs.args[0].shape) > 4 or not all(args_kwargs.args[0].shape[:-3]),
-)
-
-
 DISPATCHER_INFOS = [
     DispatcherInfo(
         F.horizontal_flip,
@@ -260,7 +254,6 @@ def xfail_all_tests(*, reason, condition):
         pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
         test_marks=[
             xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
-            *xfails_degenerate_or_multi_batch_dims,
         ],
     ),
     DispatcherInfo(
@@ -271,7 +264,6 @@ def xfail_all_tests(*, reason, condition):
         F.elastic,
         {
             features.Image: F.elastic_image_tensor,
             features.BoundingBox: F.elastic_bounding_box,
             features.Mask: F.elastic_mask,
         },
         pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
-        test_marks=xfails_degenerate_or_multi_batch_dims,
     ),
     DispatcherInfo(
         F.center_crop,
@@ -294,7 +286,6 @@ def xfail_all_tests(*, reason, condition):
         test_marks=[
             xfail_jit_python_scalar_arg("kernel_size"),
             xfail_jit_python_scalar_arg("sigma"),
-            *xfails_degenerate_or_multi_batch_dims,
         ],
     ),
     DispatcherInfo(
diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py
index d90d3bf68be..9ebfc7a00d2 100644
--- a/test/prototype_transforms_kernel_infos.py
+++ b/test/prototype_transforms_kernel_infos.py
@@ -156,12 +156,6 @@ def xfail_all_tests(*, reason, condition):
     ]
 
 
-xfails_image_degenerate_or_multi_batch_dims = xfail_all_tests(
-    reason="See https://github.com/pytorch/vision/issues/6670 for details.",
-    condition=lambda args_kwargs: len(args_kwargs.args[0].shape) > 4 or not all(args_kwargs.args[0].shape[:-3]),
-)
-
-
 KERNEL_INFOS = []
 
 
@@ -1156,7 +1150,6 @@ def sample_inputs_perspective_video():
         reference_fn=pil_reference_wrapper(F.perspective_image_pil),
         reference_inputs_fn=reference_inputs_perspective_image_tensor,
         closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-        test_marks=xfails_image_degenerate_or_multi_batch_dims,
     ),
     KernelInfo(
         F.perspective_bounding_box,
@@ -1168,7 +1161,6 @@ def sample_inputs_perspective_video():
         reference_fn=pil_reference_wrapper(F.perspective_image_pil),
         reference_inputs_fn=reference_inputs_perspective_mask,
         closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-        test_marks=xfails_image_degenerate_or_multi_batch_dims,
     ),
     KernelInfo(
         F.perspective_video,
@@ -1239,7 +1231,6 @@ def sample_inputs_elastic_video():
         reference_fn=pil_reference_wrapper(F.elastic_image_pil),
         reference_inputs_fn=reference_inputs_elastic_image_tensor,
         closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-        test_marks=xfails_image_degenerate_or_multi_batch_dims,
     ),
     KernelInfo(
         F.elastic_bounding_box,
@@ -1251,7 +1242,6 @@ def sample_inputs_elastic_video():
         reference_fn=pil_reference_wrapper(F.elastic_image_pil),
         reference_inputs_fn=reference_inputs_elastic_mask,
         closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-        test_marks=xfails_image_degenerate_or_multi_batch_dims,
     ),
     KernelInfo(
         F.elastic_video,
@@ -1379,7 +1369,6 @@ def sample_inputs_gaussian_blur_video():
         test_marks=[
             xfail_jit_python_scalar_arg("kernel_size"),
             xfail_jit_python_scalar_arg("sigma"),
-            *xfails_image_degenerate_or_multi_batch_dims,
         ],
     ),
     KernelInfo(
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 670b2cb87b8..55d24c56a44 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -882,7 +882,23 @@ def perspective_image_tensor(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
-    return _FT.perspective(image, perspective_coeffs, interpolation=interpolation.value, fill=fill)
+    if image.numel() == 0:
+        return image
+
+    shape = image.shape
+
+    if image.ndim > 4:
+        image = image.view((-1,) + shape[-3:])
+        needs_unsquash = True
+    else:
+        needs_unsquash = False
+
+    output = _FT.perspective(image, perspective_coeffs, interpolation=interpolation.value, fill=fill)
+
+    if needs_unsquash:
+        output = output.view(shape)
+
+    return output
 
 
 @torch.jit.unused
@@ -1007,20 +1023,27 @@ def perspective_video(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
-    # TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
-    # https://github.com/pytorch/vision/issues/6670 is resolved.
-    if video.numel() == 0:
-        return video
+    return perspective_image_tensor(video, perspective_coeffs, interpolation=interpolation, fill=fill)
 
-    shape = video.shape
 
-    if video.ndim > 4:
-        video = video.view((-1,) + shape[-3:])
+def elastic_image_tensor(
+    image: torch.Tensor,
+    displacement: torch.Tensor,
+    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    fill: features.FillTypeJIT = None,
+) -> torch.Tensor:
+    if image.numel() == 0:
+        return image
+
+    shape = image.shape
+
+    if image.ndim > 4:
+        image = image.view((-1,) + shape[-3:])
         needs_unsquash = True
     else:
         needs_unsquash = False
 
-    output = perspective_image_tensor(video, perspective_coeffs, interpolation=interpolation, fill=fill)
+    output = _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill)
 
     if needs_unsquash:
         output = output.view(shape)
@@ -1028,29 +1051,6 @@ def perspective_video(
     return output
 
 
-def perspective(
-    inpt: features.InputTypeJIT,
-    perspective_coeffs: List[float],
-    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-    fill: features.FillTypeJIT = None,
-) -> features.InputTypeJIT:
-    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
-        return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
-    elif isinstance(inpt, features._Feature):
-        return inpt.perspective(perspective_coeffs, interpolation=interpolation, fill=fill)
-    else:
-        return perspective_image_pil(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
-
-
-def elastic_image_tensor(
-    image: torch.Tensor,
-    displacement: torch.Tensor,
-    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-    fill: features.FillTypeJIT = None,
-) -> torch.Tensor:
-    return _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill)
-
-
 @torch.jit.unused
 def elastic_image_pil(
     image: PIL.Image.Image,
@@ -1128,25 +1128,7 @@ def elastic_video(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
-    # TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
-    # https://github.com/pytorch/vision/issues/6670 is resolved.
-    if video.numel() == 0:
-        return video
-
-    shape = video.shape
-
-    if video.ndim > 4:
-        video = video.view((-1,) + shape[-3:])
-        needs_unsquash = True
-    else:
-        needs_unsquash = False
-
-    output = elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
-
-    if needs_unsquash:
-        output = output.view(shape)
-
-    return output
+    return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
 
 
 def elastic(
diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py
index 7b3773e63a1..79a358b4ed5 100644
--- a/torchvision/prototype/transforms/functional/_misc.py
+++ b/torchvision/prototype/transforms/functional/_misc.py
@@ -56,7 +56,23 @@ def gaussian_blur_image_tensor(
         if s <= 0.0:
             raise ValueError(f"sigma should have positive values. Got {sigma}")
 
-    return _FT.gaussian_blur(image, kernel_size, sigma)
+    if image.numel() == 0:
+        return image
+
+    shape = image.shape
+
+    if image.ndim > 4:
+        image = image.view((-1,) + shape[-3:])
+        needs_unsquash = True
+    else:
+        needs_unsquash = False
+
+    output = _FT.gaussian_blur(image, kernel_size, sigma)
+
+    if needs_unsquash:
+        output = output.view(shape)
+
+    return output
 
 
 @torch.jit.unused
@@ -71,25 +87,7 @@ def gaussian_blur_image_pil(
 def gaussian_blur_video(
     video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> torch.Tensor:
-    # TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
-    # https://github.com/pytorch/vision/issues/6670 is resolved.
-    if video.numel() == 0:
-        return video
-
-    shape = video.shape
-
-    if video.ndim > 4:
-        video = video.view((-1,) + shape[-3:])
-        needs_unsquash = True
-    else:
-        needs_unsquash = False
-
-    output = gaussian_blur_image_tensor(video, kernel_size, sigma)
-
-    if needs_unsquash:
-        output = output.view(shape)
-
-    return output
+    return gaussian_blur_image_tensor(video, kernel_size, sigma)
 
 
 def gaussian_blur(

From 501bbf5579b78c40ed52916ce4fc61518b41126c Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Mon, 10 Oct 2022 10:40:04 +0200
Subject: [PATCH 2/2] put back perspective dispatcher

---
 torchvision/prototype/transforms/functional/_geometry.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 55d24c56a44..2c064245e8a 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -1026,6 +1026,20 @@ def perspective_video(
     return perspective_image_tensor(video, perspective_coeffs, interpolation=interpolation, fill=fill)
 
 
+def perspective(
+    inpt: features.InputTypeJIT,
+    perspective_coeffs: List[float],
+    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    fill: features.FillTypeJIT = None,
+) -> features.InputTypeJIT:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
+    elif isinstance(inpt, features._Feature):
+        return inpt.perspective(perspective_coeffs, interpolation=interpolation, fill=fill)
+    else:
+        return perspective_image_pil(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
+
+
 def elastic_image_tensor(
     image: torch.Tensor,
     displacement: torch.Tensor,
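
The substance of patch 1 is one repeated pattern: each *_image_tensor kernel now
accepts inputs with arbitrary leading batch dimensions (including degenerate ones of
size 0) by squashing them into a single batch dimension, running the existing 4D-only
kernel, and restoring the original shape afterwards, which lets the *_video kernels
simply delegate. A minimal, self-contained sketch of that squash/unsquash pattern
follows; the helper name apply_on_squashed_batch and the flip(-1) stand-in kernel are
illustrative only, not part of the patch or the torchvision API:

import torch

def apply_on_squashed_batch(kernel, image: torch.Tensor) -> torch.Tensor:
    # Degenerate inputs, e.g. a batch dimension of size 0, have nothing to
    # transform, so they are returned as-is.
    if image.numel() == 0:
        return image

    shape = image.shape

    # Squash all leading batch dimensions of, say, a (B1, B2, C, H, W) input
    # into a single one, since the wrapped kernel only handles (N, C, H, W) ...
    if image.ndim > 4:
        image = image.view((-1,) + shape[-3:])
        needs_unsquash = True
    else:
        needs_unsquash = False

    output = kernel(image)

    # ... and restore the original batch dimensions afterwards.
    if needs_unsquash:
        output = output.view(shape)

    return output

# Usage: a 5D batch of videos passes through a kernel that only knows 4D input.
images = torch.rand(2, 8, 3, 32, 32)
assert apply_on_squashed_batch(lambda img: img.flip(-1), images).shape == images.shape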