Skip to content

make transforms v2 get_params a staticmethod #7177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 47 additions & 26 deletions test/test_prototype_transforms_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,37 +649,58 @@ def test_call_consistency(config, args_kwargs):
)


@pytest.mark.parametrize(
"config",
[config for config in CONSISTENCY_CONFIGS if hasattr(config.legacy_cls, "get_params")],
ids=lambda config: config.legacy_cls.__name__,
get_params_parametrization = pytest.mark.parametrize(
("config", "get_params_args_kwargs"),
[
pytest.param(
next(config for config in CONSISTENCY_CONFIGS if config.prototype_cls is transform_cls),
get_params_args_kwargs,
id=transform_cls.__name__,
)
for transform_cls, get_params_args_kwargs in [
(prototype_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to use the ArgsKwargs class at all here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Define "need". You can also replace it with a sequence and a mapping, but I feel this is more convenient, since it closely matches how you would actually call the method. Plus, this is consistent with this module as a whole, which uses ArgsKwargs every time we need to define call arguments.

(prototype_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))),
(prototype_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
(prototype_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])),
(prototype_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
(
prototype_transforms.RandomAffine,
ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]),
),
(prototype_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))),
(prototype_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
(prototype_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])),
(prototype_transforms.AutoAugment, ArgsKwargs(5)),
]
],
)
def test_get_params_alias(config):


@get_paramsl_parametrization
def test_get_params_alias(config, get_params_args_kwargs):
assert config.prototype_cls.get_params is config.legacy_cls.get_params

if not config.args_kwargs:
return
args, kwargs = config.args_kwargs[0]
legacy_transform = config.legacy_cls(*args, **kwargs)
prototype_transform = config.prototype_cls(*args, **kwargs)

@pytest.mark.parametrize(
("transform_cls", "args_kwargs"),
[
(prototype_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
(prototype_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))),
(prototype_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
(prototype_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])),
(prototype_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
(
prototype_transforms.RandomAffine,
ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]),
),
(prototype_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))),
(prototype_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
(prototype_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])),
(prototype_transforms.AutoAugment, ArgsKwargs(5)),
],
)
def test_get_params_jit(transform_cls, args_kwargs):
args, kwargs = args_kwargs
assert prototype_transform.get_params is legacy_transform.get_params


@get_paramsl_parametrization
def test_get_params_jit(config, get_params_args_kwargs):
get_params_args, get_params_kwargs = get_params_args_kwargs

torch.jit.script(config.prototype_cls.get_params)(*get_params_args, **get_params_kwargs)

if not config.args_kwargs:
return
args, kwargs = config.args_kwargs[0]
transform = config.prototype_cls(*args, **kwargs)
Comment on lines +693 to +701
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we using get_params_args_kwargs for one of them and config.args_kwargs[0] for the other?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_params_args_kwargs is coming from the parametrization and only includes a single ArgsKwargs to test calling get_params. config.args_kwargs is a list of ArgsKwargs that is used to instantiate the transform and is coming from the CONSISTENCY_CONFIGS. Since get_params is a @staticmethod, we don't need to care about the actual parameters the transform is instantiated with. Thus, we just take the first available to get an arbitrary instance, and call get_params on that.


torch.jit.script(transform_cls.get_params)(*args, **kwargs)
torch.jit.script(transform.get_params)(*get_params_args, **get_params_kwargs)


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init_subclass__(cls) -> None:
# Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance.
# This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`.
if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"):
cls.get_params = cls._v1_transform_cls.get_params # type: ignore[attr-defined]
cls.get_params = staticmethod(cls._v1_transform_cls.get_params) # type: ignore[attr-defined]

def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
Expand Down