
Simplify dimension checks on functional_tensor.py #3159


Closed
datumbox opened this issue Dec 11, 2020 · 11 comments · Fixed by #3171

Comments

@datumbox
Contributor

datumbox commented Dec 11, 2020

🚀 Feature

The functional_tensor.py file makes use of the private method _is_tensor_a_torch_image in every public operator to check its dimensions:

def _is_tensor_a_torch_image(x: Tensor) -> bool:
    return x.ndim >= 2

Examples (the same check is repeated in every public operator):

if not _is_tensor_a_torch_image(img):
    raise TypeError('tensor is not a torch image.')
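To make the rule concrete, here is a small sketch of what the `x.ndim >= 2` check accepts and rejects. It assumes only that `torch` is installed; the function body is copied from above and the sample shapes are arbitrary.

```python
import torch
from torch import Tensor


def _is_tensor_a_torch_image(x: Tensor) -> bool:
    # Same rule as functional_tensor.py: at least two dims, i.e. (..., H, W)
    return x.ndim >= 2


# Map each sample shape to whether it passes the check
results = {
    shape: _is_tensor_a_torch_image(torch.zeros(shape))
    for shape in [(10,), (4, 5), (3, 4, 5), (2, 3, 4, 5)]
}
print(results)
```

Note that the check only looks at the number of dimensions: HW, CHW and batched BCHW tensors all pass, while a 1-D tensor is rejected.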

This check is repetitive and hurts code readability. We should replace it with a decorator. See #3123 (comment) for details.
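A minimal eager-mode sketch of the decorator idea. The name `check_image_tensor` is hypothetical, and `crop` copies the slicing used by the `crop` operator discussed later in this thread. The wrapper forwards `*args, **kwargs`, which is exactly what TorchScript turns out to reject further down, so this only illustrates the intent for non-scripted calls.

```python
import functools
from typing import Any, Callable

import torch
from torch import Tensor


def check_image_tensor(func: Callable[..., Tensor]) -> Callable[..., Tensor]:
    # Hypothetical decorator: validate the first positional argument, then forward everything
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Tensor:
        img = args[0]
        if not isinstance(img, Tensor) or img.ndim < 2:
            raise TypeError("tensor is not a torch image.")
        return func(*args, **kwargs)

    return wrapper


@check_image_tensor
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    return img[..., top:top + height, left:left + width]


out = crop(torch.arange(16.0).reshape(1, 4, 4), 1, 1, 2, 2)
print(out.shape)  # torch.Size([1, 2, 2])

try:
    crop(torch.zeros(10), 0, 0, 1, 1)  # 1-D input is rejected before cropping
except TypeError as e:
    print(e)  # tensor is not a torch image.
```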

cc @vfdev-5

@fmassa
Member

fmassa commented Dec 11, 2020

FYI I'm not sure decorators are supported in torchscript

@vfdev-5
Collaborator

vfdev-5 commented Dec 11, 2020

FYI I'm not sure decorators are supported in torchscript

The following works

import torch
print(torch.__version__)
import typing


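def check_tensor(func: typing.Callable) -> typing.Callable:
    # The wrapper repeats foo's exact, typed signature (no *args/**kwargs),
    # so torch.jit.script can compile it together with the wrapped function.
    def wrapper(t: torch.Tensor, p: int = 5):
        if not isinstance(t, torch.Tensor):
            raise TypeError()
        print("Inside decorator")
        return func(t, p)

    return wrapper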


@check_tensor
def foo(t: torch.Tensor, p: int = 5) -> torch.Tensor:
    for i in range(10):
        if i < p:
            continue
        t += i
    return t


print(foo(torch.tensor([0.0])))
sfoo = torch.jit.script(foo)
print(sfoo(torch.tensor([0.0])))

> 1.8.0.dev20201211
> Inside decorator
> tensor([35.])
> Inside decorator
> tensor([35.])

@fmassa
Member

fmassa commented Dec 11, 2020

Great!

@avijit9
Contributor

avijit9 commented Dec 12, 2020

Is anybody working on this feature?

@datumbox
Contributor Author

@avijit9 Not at the moment. If you are interested in sending a PR, that would be awesome!

@avijit9
Contributor

avijit9 commented Dec 12, 2020

How should we deal with multiple arguments? For example, the crop function takes several arguments, but if I use a decorator with *args, **kwargs, I get the following error:

FAILED test/test_decorator.py::Tester::test_crop - torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:

The decorator I implemented is shown below -

def check_tensor_torch_image(f: typing.Callable) -> typing.Callable:
    def wrapper(*args, **kwargs):
        if args[0].ndim < 2:
            raise TypeError("Tensor is not a torch image.")
        return f(*args, **kwargs)
    return wrapper

The unit test is as follows -

def test_crop(self):
    scripted_fn = torch.jit.script(F_t.crop)
    shape = (10,)  # 1-D tensor: not a valid image
    tensor = torch.rand(*shape, dtype=torch.float, device=self.device)
    with self.assertRaises(Exception) as context:
        # crop also needs top, left, height and width
        scripted_fn(tensor, 0, 0, 1, 1)
    self.assertTrue('Tensor is not a torch image.' in str(context.exception))

This is a known issue - pytorch/pytorch#29637

cc - @datumbox

@avijit9
Contributor

avijit9 commented Dec 14, 2020

Minimal code to reproduce this -

import torch
print(torch.__version__)
import typing


def check_tensor(func: typing.Callable) -> typing.Callable:
    
    def wrapper(*args, **kwargs):
        if not isinstance(args[0], torch.Tensor):
            raise TypeError()
        print("Inside decorator")
        return func(*args, **kwargs)
    
    return wrapper


@check_tensor
def foo(t: torch.Tensor, p: int = 5) -> torch.Tensor:
    for i in range(10):
        if i < p:
            continue
        t += i
    return t


print(foo(torch.tensor([0.0])))  # eager call works
sfoo = torch.jit.script(foo)  # raises NotSupportedError: wrapper uses *args/**kwargs
print(sfoo(torch.tensor([0.0])))

@datumbox
Contributor Author

@avijit9 Thank you for the detailed analysis!

@vfdev-5 Do you have any suggestions for working around this limitation? I suppose it's possible to implement a few versions of the decorator, since most of the methods in functional_tensor.py have a limited number of arguments, but I think that would lead to messier code, which defeats the purpose of your original recommendation. Thoughts?
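The "one decorator per signature" option could look roughly like this. `check_image_tensor_crop` is a hypothetical name, written only for `crop`'s parameter list. It follows the same pattern as vfdev-5's working example above: the wrapper declares explicit, typed parameters, so TorchScript can compile it. The cost is one such decorator per distinct signature, which is the messiness being weighed here.

```python
from typing import Callable

import torch
from torch import Tensor


def check_image_tensor_crop(
    func: Callable[[Tensor, int, int, int, int], Tensor]
) -> Callable[[Tensor, int, int, int, int], Tensor]:
    # Hypothetical signature-specific decorator: the wrapper spells out crop's exact
    # parameters (no *args/**kwargs), which is what lets torch.jit.script compile it.
    def wrapper(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
        if img.ndim < 2:
            raise TypeError("tensor is not a torch image.")
        return func(img, top, left, height, width)

    return wrapper


@check_image_tensor_crop
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    return img[..., top:top + height, left:left + width]


scripted_crop = torch.jit.script(crop)
print(scripted_crop(torch.ones(3, 8, 8), 2, 2, 4, 4).shape)  # torch.Size([3, 4, 4])
```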

@vfdev-5
Collaborator

vfdev-5 commented Dec 14, 2020

@datumbox yes, it seems to be a blocker for using decorators. On the other hand, the goal of this issue is to refactor the current code base a bit by removing explicit type checks like

if not _is_tensor_a_torch_image(img):
    raise TypeError('tensor is not a torch image.')

Using a decorator, we remove 2 lines and add 1. I propose to just create simple helper methods and reuse them, which gives a similar reduction in line count:

def _assert_image_tensor(img):
    if not _is_tensor_a_torch_image(img):
        raise TypeError("tensor is not a torch image.")    

def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """PRIVATE METHOD. Crop the given Image Tensor.
    """
    _assert_image_tensor(img)

    return img[..., top:top + height, left:left + width]

What do you think ?
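A self-contained version of the helper-function proposal, handy for checking that it survives scripting. It assumes only `torch`; the bodies of `_is_tensor_a_torch_image`, `_assert_image_tensor`, and `crop` follow the snippets in this thread, with type annotations on the helper added here for clarity. Because the check is an ordinary function call rather than a wrapper, TorchScript compiles it like any other helper.

```python
import torch
from torch import Tensor


def _is_tensor_a_torch_image(x: Tensor) -> bool:
    return x.ndim >= 2


def _assert_image_tensor(img: Tensor) -> None:
    # Single place that owns the dimension check and the error message
    if not _is_tensor_a_torch_image(img):
        raise TypeError("tensor is not a torch image.")


def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    _assert_image_tensor(img)
    return img[..., top:top + height, left:left + width]


scripted_crop = torch.jit.script(crop)

# Eager and scripted versions agree on a valid (C, H, W) image
img = torch.arange(48.0).reshape(3, 4, 4)
assert torch.equal(crop(img, 1, 1, 2, 2), scripted_crop(img, 1, 1, 2, 2))

# Both reject a 1-D tensor. TorchScript may surface the error under a different
# exception class, so only the message text is collected here.
messages = []
for fn in (crop, scripted_crop):
    try:
        fn(torch.zeros(10), 0, 0, 1, 1)
    except Exception as e:
        messages.append(str(e))
```

The trade-off versus a decorator is one explicit call per operator, but no signature has to be duplicated.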

@datumbox
Contributor Author

@vfdev-5 agreed, that's still an improvement.

@avijit9 would you still be interested in sending a PR based on the above proposal?

@avijit9
Contributor

avijit9 commented Dec 14, 2020

@datumbox Sure! I'll send a PR soon.

datumbox changed the title from "Use decorators to simplify dimension checks on functional_tensor.py" to "Simplify dimension checks on functional_tensor.py" on Dec 14, 2020
datumbox added a commit that referenced this issue Dec 15, 2020
* added the helper method for dimension checks

* unit tests for dimension check function in functional_tensor

* code formatting and typing

* moved torch image check after tensor check

* unit testcases for test_assert_image_tensor added and refactored

* separate unit testcase file deleted

* assert_image_tensor added to newly created 6 methods

* test cases added for new 6 methods

* removed wrongly pasted posterize method and added solarize method for testing

Co-authored-by: Vasilis Vryniotis <[email protected]>
facebook-github-bot pushed a commit that referenced this issue Dec 23, 2020
Reviewed By: fmassa

Differential Revision: D25679214

fbshipit-source-id: 60ca5c1e6a653195a3dd07755b7ac7fa6d4eaf4b

Co-authored-by: Vasilis Vryniotis <[email protected]>

4 participants