
enable get_params alias for transforms v2 #7153


Merged
4 commits merged on Feb 1, 2023
33 changes: 33 additions & 0 deletions test/test_prototype_transforms_consistency.py
@@ -655,6 +655,39 @@ def test_call_consistency(config, args_kwargs):
)


@pytest.mark.parametrize(
    "config",
    [config for config in CONSISTENCY_CONFIGS if hasattr(config.legacy_cls, "get_params")],
    ids=lambda config: config.legacy_cls.__name__,
)
def test_get_params_alias(config):
    assert config.prototype_cls.get_params is config.legacy_cls.get_params


@pytest.mark.parametrize(
    ("transform_cls", "args_kwargs"),
    [
        (prototype_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
        (prototype_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))),
        (prototype_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
        (prototype_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])),
        (prototype_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
        (
            prototype_transforms.RandomAffine,
            ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]),
        ),
        (prototype_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))),
        (prototype_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
        (prototype_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])),
        (prototype_transforms.AutoAugment, ArgsKwargs(5)),
    ],
)
def test_get_params_jit(transform_cls, args_kwargs):
    args, kwargs = args_kwargs

    torch.jit.script(transform_cls.get_params)(*args, **kwargs)


@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
13 changes: 11 additions & 2 deletions torchvision/prototype/transforms/_transform.py
@@ -56,10 +56,19 @@ def extra_repr(self) -> str:

        return ", ".join(extra)

-    # This attribute should be set on all transforms that have a v1 equivalent. Doing so enables the v2 transformation
-    # to be scriptable. See `_extract_params_for_v1_transform()` and `__prepare_scriptable__` for details.
+    # This attribute should be set on all transforms that have a v1 equivalent. Doing so enables two things:
+    # 1. In case the v1 transform has a static `get_params` method, it will also be available under the same name on
+    #    the v2 transform. See `__init_subclass__` for details.
+    # 2. The v2 transform will be JIT scriptable. See `_extract_params_for_v1_transform` and `__prepare_scriptable__`
+    #    for details.
    _v1_transform_cls: Optional[Type[nn.Module]] = None

    def __init_subclass__(cls) -> None:
Member:

As an alternative to __init_subclass__() (whose existence and purpose I keep forgetting), would this work too?

@staticmethod
def get_params(cls):
    if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"):
        return cls._v1_transform_cls.get_params()
    else:
        raise AttributeError(
            f"cls {cls} has no get_params method. You probably don't need one anymore "
            "as the same RNG is applied to all images, bboxes and masks in the same transform call. "
            "If what you need is a way to transform different batches with the same RNG, "
            "please reach out at #1234567 (the feedback issue)."
        )

Collaborator Author:

Two problems here with JIT:

  1. get_params takes parameters in v1, e.g.

    @staticmethod
    def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:

    Since I'm not aware of an arbitrary parameter passthrough like (*args: Any, **kwargs: Any) that works with JIT, we would have to implement this manually on every class that needs it.

  2. You used the @staticmethod decorator, but take cls as the first input. Since you need cls, we would have to switch to @classmethod to make it work. That works in eager mode, but for JIT it means the whole class now needs to be scriptable. For example, manually defining get_params for RandomCrop as explained in 1. and trying to script it leads to

    RuntimeError: 
    'Tensor (inferred)' object has no attribute or method '_v1_transform_cls'.:
      File "/home/philip/git/pytorch/vision/torchvision/prototype/transforms/_geometry.py", line 437
        @classmethod
        def get_params(cls, img: torch.Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
            return cls._v1_transform_cls.get_params(img, output_size)
                   ~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    

    It seems JIT can't infer the type of cls and falls back to torch.Tensor. Annotating cls: Type[RandomCrop] yields

    RuntimeError: 
    Unknown type constructor Type:
    

I agree using __init_subclass__ is unconventional, but it seems like the cleanest solution here. Since we actually alias the function, we avoid all of the JIT craziness that we would otherwise have to deal with. If you can find a working solution, I'm happy to adopt it though.
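The aliasing mechanism described here can be sketched in plain Python. This is a minimal toy, not torchvision code: LegacyResize, Transform, and Resize are made-up stand-ins, and the sampling logic is simplified to a deterministic placeholder.

```python
class LegacyResize:
    """Stand-in for a v1 transform that exposes a static get_params."""

    @staticmethod
    def get_params(lo: int, hi: int) -> int:
        # Simplified, deterministic placeholder for v1-style parameter sampling.
        return (lo + hi) // 2


class Transform:
    """Stand-in for the v2 base class."""

    _v1_transform_cls = None

    def __init_subclass__(cls) -> None:
        # Runs once per subclass definition; `cls` is the subclass, e.g. Resize.
        # If the subclass points at a v1 class with `get_params`, alias it onto
        # the subclass. Because it is the *same* function object (no wrapper),
        # JIT scripting it later behaves exactly like scripting the v1 original.
        if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"):
            cls.get_params = cls._v1_transform_cls.get_params


class Resize(Transform):
    _v1_transform_cls = LegacyResize


# The v1 static method is now reachable on the v2 class, as the same object:
print(Resize.get_params is LegacyResize.get_params)  # True
print(Resize.get_params(10, 20))  # 15
```

Since `__init_subclass__` fires at class-definition time, subclasses opt in simply by setting `_v1_transform_cls`, with no per-class boilerplate.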

Member:

Thanks for the details. My last try would be to declare get_params as a @property but... let me guess... JIT doesn't support it?

Collaborator Author (@pmeier, Feb 1, 2023):

You guessed right. But even if it did, it wouldn't work for us here. @property needs an instance.
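The point that a @property needs an instance can be shown with a small self-contained example (toy class names, not torchvision's):

```python
class Legacy:
    """Stand-in for a v1 transform with a static get_params."""

    @staticmethod
    def get_params() -> int:
        return 42


class Transform:
    """Hypothetical v2 class trying to expose get_params as a property."""

    @property
    def get_params(self):
        # A property descriptor only fires through an *instance* lookup.
        return Legacy.get_params


# Class-level access yields the property descriptor itself, not the aliased
# function, so Transform.get_params(...) cannot work:
print(type(Transform.get_params))  # <class 'property'>

# Only instance access goes through the property protocol:
print(Transform().get_params())  # 42
```

That is why the property approach cannot provide the class-level `Resize.get_params(...)` call pattern that v1 users rely on.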

        # Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance.
        # This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`.
        if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"):
            cls.get_params = cls._v1_transform_cls.get_params  # type: ignore[attr-defined]

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
        # v2 transform instance. It does two things: