-
Notifications
You must be signed in to change notification settings - Fork 7.1k
make transforms v2 get_params a staticmethod #7177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Philip, LGTM but to be on the safe side, can we also check that calling instance.get_param()
is also jit-scriptable below in test_get_params_jit()
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Philip, approving to unblock, although I'm a bit confused about the usage and need for ArgsKwargs
id=transform_cls.__name__, | ||
) | ||
for transform_cls, get_params_args_kwargs in [ | ||
(prototype_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to use the ArgsKwargs class at all here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Define "need". You can also replace it with a sequence and a mapping, but I feel this is more convenient, since it closely matches how you would actually call the method. Plus, this is consistent with this module as a whole, which uses ArgsKwargs
every time we need to define call arguments.
def test_get_params_jit(config, get_params_args_kwargs): | ||
get_params_args, get_params_kwargs = get_params_args_kwargs | ||
|
||
torch.jit.script(config.prototype_cls.get_params)(*get_params_args, **get_params_kwargs) | ||
|
||
if not config.args_kwargs: | ||
return | ||
args, kwargs = config.args_kwargs[0] | ||
transform = config.prototype_cls(*args, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we using get_params_args_kwargs
for one of them and config.args_kwargs[0]
for the other?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
get_params_args_kwargs
is coming from the parametrization and only includes a single ArgsKwargs
to test calling get_params
. config.args_kwargs
is a list of ArgsKwargs
that is used to instantiate the transform and is coming from the CONSISTENCY_CONFIGS
. Since get_params
is a @staticmethod
, we don't need to care about the actual parameters the transform is instantiated with. Thus, we just take the first available to get an arbitrary instance, and call get_params
on that.
Co-authored-by: Nicolas Hug <[email protected]>
Summary: Co-authored-by: Nicolas Hug <[email protected]> Reviewed By: vmoens Differential Revision: D44416264 fbshipit-source-id: b27c2f61cefb598b1b0fe19102cd842a88a9023c
Addresses #7159 (comment). I was "obsessed" with making
get_params
available on the class in #7153 and completely forgot to check whether this works on instances as well.cc @vfdev-5 @bjuncek