-
Notifications
You must be signed in to change notification settings - Fork 7.1k
make transforms v2 get_params a staticmethod #7177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a1d7c28
20878fa
f6021ed
20117d3
63c9908
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -649,37 +649,58 @@ def test_call_consistency(config, args_kwargs): | |
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"config", | ||
[config for config in CONSISTENCY_CONFIGS if hasattr(config.legacy_cls, "get_params")], | ||
ids=lambda config: config.legacy_cls.__name__, | ||
get_params_parametrization = pytest.mark.parametrize( | ||
("config", "get_params_args_kwargs"), | ||
[ | ||
pytest.param( | ||
next(config for config in CONSISTENCY_CONFIGS if config.prototype_cls is transform_cls), | ||
get_params_args_kwargs, | ||
id=transform_cls.__name__, | ||
) | ||
for transform_cls, get_params_args_kwargs in [ | ||
(prototype_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])), | ||
(prototype_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))), | ||
(prototype_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)), | ||
(prototype_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])), | ||
(prototype_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)), | ||
( | ||
prototype_transforms.RandomAffine, | ||
ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]), | ||
), | ||
(prototype_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))), | ||
(prototype_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)), | ||
(prototype_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])), | ||
(prototype_transforms.AutoAugment, ArgsKwargs(5)), | ||
] | ||
], | ||
) | ||
def test_get_params_alias(config): | ||
|
||
|
||
@get_paramsl_parametrization | ||
def test_get_params_alias(config, get_params_args_kwargs): | ||
assert config.prototype_cls.get_params is config.legacy_cls.get_params | ||
|
||
if not config.args_kwargs: | ||
return | ||
NicolasHug marked this conversation as resolved.
Show resolved
Hide resolved
|
||
args, kwargs = config.args_kwargs[0] | ||
legacy_transform = config.legacy_cls(*args, **kwargs) | ||
prototype_transform = config.prototype_cls(*args, **kwargs) | ||
|
||
@pytest.mark.parametrize( | ||
("transform_cls", "args_kwargs"), | ||
[ | ||
(prototype_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])), | ||
(prototype_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))), | ||
(prototype_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)), | ||
(prototype_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])), | ||
(prototype_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)), | ||
( | ||
prototype_transforms.RandomAffine, | ||
ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]), | ||
), | ||
(prototype_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))), | ||
(prototype_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)), | ||
(prototype_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])), | ||
(prototype_transforms.AutoAugment, ArgsKwargs(5)), | ||
], | ||
) | ||
def test_get_params_jit(transform_cls, args_kwargs): | ||
args, kwargs = args_kwargs | ||
assert prototype_transform.get_params is legacy_transform.get_params | ||
|
||
|
||
@get_paramsl_parametrization | ||
def test_get_params_jit(config, get_params_args_kwargs): | ||
get_params_args, get_params_kwargs = get_params_args_kwargs | ||
|
||
torch.jit.script(config.prototype_cls.get_params)(*get_params_args, **get_params_kwargs) | ||
|
||
if not config.args_kwargs: | ||
return | ||
args, kwargs = config.args_kwargs[0] | ||
transform = config.prototype_cls(*args, **kwargs) | ||
Comment on lines
+693
to
+701
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are we using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
torch.jit.script(transform_cls.get_params)(*args, **kwargs) | ||
torch.jit.script(transform.get_params)(*get_params_args, **get_params_kwargs) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to use the ArgsKwargs class at all here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Define "need". You can also replace it with a sequence and a mapping, but I feel this is more convenient, since it closely matches how you would actually call the method. Plus, this is consistent with this module as a whole, which uses
ArgsKwargs
every time we need to define call arguments.