Simplify dimension checks on functional_tensor.py #3159
FYI, I'm not sure decorators are supported in TorchScript.
The following works:

```python
import typing

import torch

print(torch.__version__)


def check_tensor(func: typing.Callable) -> typing.Callable:
    # The wrapper spells out the exact (annotated) signature of the wrapped function.
    def wrapper(t: torch.Tensor, p: int = 5):
        if not isinstance(t, torch.Tensor):
            raise TypeError()
        print("Inside decorator")
        return func(t, p)
    return wrapper


@check_tensor
def foo(t: torch.Tensor, p: int = 5) -> torch.Tensor:
    for i in range(10):
        if i < p:
            continue
        t += i
    return t


print(foo(torch.tensor([0.0])))
sfoo = torch.jit.script(foo)
print(sfoo(torch.tensor([0.0])))
```

Output:

```
1.8.0.dev20201211
Inside decorator
tensor([35.])
Inside decorator
tensor([35.])
```
Great!
Is anybody working on this feature?
@avijit9 Not at the moment. If you are interested in sending a PR, that would be awesome!
How should we deal with multiple arguments? For example, the crop function takes multiple arguments. But if we use a decorator with
The decorator I implemented is shown below:
The unit test is as follows:
This is a known issue: pytorch/pytorch#29637. cc @datumbox
Minimal code to reproduce this:

```python
import typing

import torch

print(torch.__version__)


def check_tensor(func: typing.Callable) -> typing.Callable:
    # Generic wrapper: forwards any positional/keyword arguments.
    def wrapper(*args, **kwargs):
        if not isinstance(args[0], torch.Tensor):
            raise TypeError()
        print("Inside decorator")
        return func(*args, **kwargs)
    return wrapper


@check_tensor
def foo(t: torch.Tensor, p: int = 5) -> torch.Tensor:
    for i in range(10):
        if i < p:
            continue
        t += i
    return t


print(foo(torch.tensor([0.0])))  # works in eager mode
sfoo = torch.jit.script(foo)  # fails: TorchScript cannot compile *args/**kwargs
print(sfoo(torch.tensor([0.0])))
```
@avijit9 Thank you for the detailed analysis! @vfdev-5 Do you have any suggestions for working around this limitation? I suppose it's possible to implement a few versions of the decorator, since most of the methods in functional_tensor.py take a limited number of arguments, but I think that would lead to messier code, which defeats the purpose of your original recommendation. Thoughts?
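For comparison, here is a minimal sketch of one such arity-specific decorator, modelled on the working example earlier in this thread. The decorator name `check_image_tensor_4i` is hypothetical, and the check `img.dim() >= 2` is an assumption standing in for `_is_tensor_a_torch_image`. It only covers functions shaped like `crop(Tensor, int, int, int, int)`; every other signature would need its own copy, which is the messiness described above.

```python
import typing

import torch
from torch import Tensor

# Signature shared by crop-like operators: (img, top, left, height, width).
CropFn = typing.Callable[[Tensor, int, int, int, int], Tensor]


def check_image_tensor_4i(func: CropFn) -> CropFn:
    # Hypothetical decorator: the wrapper repeats the exact annotated signature,
    # because TorchScript cannot compile a generic *args/**kwargs wrapper.
    def wrapper(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
        # Assumed check: an image tensor has at least 2 dims (..., H, W).
        if img.dim() < 2:
            raise TypeError("tensor is not a torch image.")
        return func(img, top, left, height, width)

    return wrapper


@check_image_tensor_4i
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    return img[..., top:top + height, left:left + width]


# Scripting compiles the wrapper and, through its closure, the original crop.
scripted_crop = torch.jit.script(crop)
print(scripted_crop(torch.zeros(3, 10, 10), 2, 2, 4, 5).shape)  # torch.Size([3, 4, 5])
```

A function with a different signature (say, one extra float argument) would need a second decorator with a second hand-written wrapper.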
@datumbox yes, it seems like this is a blocker for decorator usage. On the other hand, the goal of this issue is to refactor the current code base a bit by removing explicit type checking like:

```python
if not _is_tensor_a_torch_image(img):
    raise TypeError('tensor is not a torch image.')
```

Using a decorator, we remove 2 lines and add 1. I propose to just create a simple helper method and reuse it, with a similar gain in lines:

```python
from torch import Tensor

# _is_tensor_a_torch_image is the existing private helper in functional_tensor.py.


def _assert_image_tensor(img: Tensor) -> None:
    if not _is_tensor_a_torch_image(img):
        raise TypeError("tensor is not a torch image.")


def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """PRIVATE METHOD. Crop the given Image Tensor.
    """
    _assert_image_tensor(img)
    return img[..., top:top + height, left:left + width]
```

What do you think?
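To check that the helper approach stays TorchScript-compatible, here is a self-contained sketch. The body of `_is_tensor_a_torch_image` below (`x.dim() >= 2`) is an assumption made so the snippet runs on its own; the helper and `crop` follow the proposal above. Because `_assert_image_tensor` is an ordinary function call, `torch.jit.script` compiles it recursively and no variadic wrapper is involved.

```python
import torch
from torch import Tensor


def _is_tensor_a_torch_image(x: Tensor) -> bool:
    # Assumed body for this sketch: an image tensor has at least 2 dims (..., H, W).
    return x.dim() >= 2


def _assert_image_tensor(img: Tensor) -> None:
    if not _is_tensor_a_torch_image(img):
        raise TypeError("tensor is not a torch image.")


def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """PRIVATE METHOD. Crop the given Image Tensor."""
    _assert_image_tensor(img)
    return img[..., top:top + height, left:left + width]


# The helper is compiled along with crop; no per-signature decorator is needed.
scripted_crop = torch.jit.script(crop)
out = scripted_crop(torch.arange(16.0).reshape(1, 4, 4), 1, 1, 2, 2)
print(out.shape)  # torch.Size([1, 2, 2])

# In eager mode, a 1-D tensor is rejected with a TypeError.
try:
    crop(torch.zeros(4), 0, 0, 1, 1)
except TypeError as e:
    print(e)  # tensor is not a torch image.
```

Each operator still pays one line for the check, but the helper works for any signature.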
@datumbox Sure! I'll send a PR soon. |
* added the helper method for dimension checks
* unit tests for dimension check function in functional_tensor
* code formatting and typing
* moved torch image check after tensor check
* unit test cases for test_assert_image_tensor added and refactored
* separate unit test case file deleted
* assert_image_tensor added to newly created 6 methods
* test cases added for new 6 methods
* removed wrongly pasted posterize method and added solarize method for testing

Co-authored-by: Vasilis Vryniotis <[email protected]>
Summary: same change list as the commit above.

Reviewed By: fmassa

Differential Revision: D25679214

fbshipit-source-id: 60ca5c1e6a653195a3dd07755b7ac7fa6d4eaf4b

Co-authored-by: Vasilis Vryniotis <[email protected]>
🚀 Feature
The functional_tensor.py file uses the private method _is_tensor_a_torch_image in every public operator to check the input's dimensions:
vision/torchvision/transforms/functional_tensor.py
Lines 10 to 11 in dab4757
Examples:
vision/torchvision/transforms/functional_tensor.py
Lines 146 to 147 in dab4757
vision/torchvision/transforms/functional_tensor.py
Lines 166 to 167 in dab4757
This check is repetitive and reduces code readability. We should fix this by using decorators. See #3123 (comment) for details.
assertions.
cc @vfdev-5