[POC] Potential ways for making Transforms V2 classes JIT-scriptable #6711
Comments
@datumbox I think this wrapper approach is quite similar to how …
Depending on how long we want to support scriptability, this might be pretty easy. Assuming that we only support it until v2 is stable, and deprecate and remove it together with v1 afterwards, we can simply do:

```python
class MyTransformV2(transforms.Transform):
    def __init__(self, foo, bar):
        super().__init__()
        self.foo = foo
        self.bar = bar

    ...

    def __prepare_scriptable__(self):
        # This hook is called early by `torch.jit.script`. See
        # https://github.com/pytorch/pytorch/blob/a6ac922eabee8fce7a48dedac81e82ac8cfe9a45/torch/jit/_script.py#L1284-L1288
        # https://github.com/pytorch/pytorch/blob/a6ac922eabee8fce7a48dedac81e82ac8cfe9a45/torch/jit/_script.py#L982
        # If this method exists, its return value is used over the original object for scripting.
        return MyTransformV1(self.foo, self.bar)
```

If we want to support scriptability for longer, we can still use the hook, but we need to return something custom. One option here is of course to copy-paste the important v1 code in there. I'm currently exploring using this hook together with the AST rewriting proposed in the top comment to automate that.
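As a torch-free illustration of what this hook buys us, here is a minimal sketch of the dispatch that `torch.jit.script` performs. The `script` function below is a hypothetical stand-in that only models the hook lookup and compiles nothing; the class bodies are placeholders:

```python
def script(obj):
    # Hypothetical stand-in for `torch.jit.script`: it only models the
    # `__prepare_scriptable__` lookup and performs no compilation.
    if hasattr(obj, "__prepare_scriptable__"):
        obj = obj.__prepare_scriptable__()
    return obj

class MyTransformV1:
    # Stands in for the already-scriptable v1 transform.
    def __init__(self, foo, bar):
        self.foo, self.bar = foo, bar

class MyTransformV2:
    def __init__(self, foo, bar):
        self.foo, self.bar = foo, bar

    def __prepare_scriptable__(self):
        # Hand the compiler the v1 equivalent instead of `self`.
        return MyTransformV1(self.foo, self.bar)

scripted = script(MyTransformV2(1, 2))
print(type(scripted).__name__)  # MyTransformV1
```

The key point is that the v2 object never has to be scriptable itself; it only has to know how to produce a scriptable substitute on demand.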
Thanks Philip, that looks very promising and could potentially remove a major roadblock toward migration to v2. As discussed offline, if the …
The name has the same energy as the …
Closed in #7135.
🚀 The feature
Note: This is an exploratory proof-of-concept to discuss potential workarounds for offering limited support of JIT in our Transforms V2 Classes. I am NOT advocating for following this approach. I'm hoping we can kick off the discussion for other alternative and simpler approaches.
Currently the Transforms V2 classes are not JIT-scriptable. This breaks BC and will make the rollout of the new API harder. Here are some of the choices that are incompatible with JIT:

- `for ... else`
- …

Points 3 & 4 could be addressed by (painful) refactoring; nevertheless, points 1 & 2 are our main blockers.
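To make the `for ... else` point concrete, here is a small sketch (hypothetical helper functions, not torchvision code) of the idiom next to a break-flag equivalent that TorchScript can handle:

```python
def first_even(xs):
    # Uses `for ... else`: the else branch runs only if the loop
    # completes without hitting `break`. TorchScript rejects this idiom.
    for x in xs:
        if x % 2 == 0:
            found = x
            break
    else:
        found = -1
    return found

def first_even_scriptable(xs):
    # Equivalent break-flag form: assign the default up front,
    # then overwrite it and break on the first hit.
    found = -1
    for x in xs:
        if x % 2 == 0:
            found = x
            break
    return found

print(first_even([1, 3, 4]), first_even_scriptable([1, 3, 4]))  # 4 4
print(first_even([1, 3, 5]), first_even_scriptable([1, 3, 5]))  # -1 -1
```

The two functions are behaviorally identical; the second merely avoids the unsupported construct.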
To ensure our users can still do inference using JIT, we offer presets/transforms attached to each model's weights. Those will remain JIT-scriptable. In addition, we applied a workaround (#6553) to keep the `F` dispatcher JIT-scriptable for plain `Tensor`s. Hopefully these mitigations will help most users migrate more easily to the new API.

But what if they don't? Many downstream users might want to continue relying on transforms such as `Resize`, `CenterCrop`, `Pad`, etc. for inference. In that case, one option could be to offer JIT-scriptable alternatives that work only for pure tensors. Another alternative is to write a utility that can modify the existing implementations on-the-fly to update key functions and make them JIT-scriptable.

Motivation, pitch
This is a proof-of-concept of how such a utility can work. It only supports a handful of transforms (due to points 3 & 4 from above) but it can be extended to support more.
There are 2 approaches showcased below:

1. Use `ast` to replace on-the-fly the problematic idioms in the Transform classes. Since JIT also uses `ast` internally, we need to make the updated code available to JIT during scripting.
2. Replace `forward()` to remove the packing/unpacking of an arbitrary number of inputs. We also hardcode plain tensors as the only accepted input type.

The above works on our latest main without modifications.
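As a rough, stdlib-only sketch of the first approach (this is not the actual POC code, and every name below is made up for illustration): a `NodeTransformer` that desugars `for ... else` — one of the idioms listed above — into a break-flag form before the source reaches the JIT. A real utility would pull the source via `inspect.getsource`; here it is passed in as a string:

```python
import ast
import textwrap

class ForElseRewriter(ast.NodeTransformer):
    """Desugars `for ... else` (not scriptable) into a break-flag pattern."""

    def __init__(self):
        self._counter = 0

    def visit_For(self, node):
        self.generic_visit(node)  # rewrite nested loops first
        if not node.orelse:
            return node
        flag = f"_broke_{self._counter}"
        self._counter += 1
        init = ast.parse(f"{flag} = False").body[0]
        new_for = ast.For(
            target=node.target, iter=node.iter,
            body=_flag_breaks(node.body, flag), orelse=[],
        )
        # Run the old `else` body only if no `break` fired.
        guard = ast.parse(f"if not {flag}:\n    pass").body[0]
        guard.body = node.orelse
        return [init, new_for, guard]

def _flag_breaks(stmts, flag):
    # Insert `<flag> = True` before each `break` belonging to this loop.
    out = []
    for stmt in stmts:
        if isinstance(stmt, ast.Break):
            out.append(ast.parse(f"{flag} = True").body[0])
        elif isinstance(stmt, ast.If):
            stmt.body = _flag_breaks(stmt.body, flag)
            stmt.orelse = _flag_breaks(stmt.orelse, flag)
        # `break`s inside nested For/While belong to those loops: left alone.
        out.append(stmt)
    return out

def rewrite(src, name):
    """Recompile the function `name` from `src` with `for ... else` desugared."""
    tree = ast.parse(textwrap.dedent(src))
    tree = ast.fix_missing_locations(ForElseRewriter().visit(tree))
    ns = {}
    exec(compile(tree, "<rewritten>", "exec"), ns)
    return ns[name]

SRC = """
def first_even(xs):
    for x in xs:
        if x % 2 == 0:
            found = x
            break
    else:
        found = -1
    return found
"""

plain = rewrite(SRC, "first_even")
print(plain([1, 3, 4]), plain([1, 3, 5]))  # 4 -1
```

The sketch ignores `break`s hidden inside `try`/`with` blocks and other corner cases; the real utility would also need to surface the rewritten source to the JIT, as noted in point 1 above.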
This approach can currently only support a handful of simple Transforms: those that don't require overwriting `forward()` and that contain most of their logic inside their `_get_params()` and `_transform()` methods. Many such simple transforms are still not supported because they inherit from `_RandomApplyTransform`, which does the random call in its `forward()` (this could be refactored to move to `_get_params()`). The rest of the existing inference transforms can be supported by addressing points 3 & 4 from above.

The above approach is very over-engineered, brittle and opaque, because it tries to fix the JIT-scriptability issues without any modifications to the code-base for the selected example. If we accept minor refactoring on the existing classes, we can remove the
`ast` logic. We could also avoid defining a default JIT-compatible `forward()` by explicitly defining such a method on the original class when available. Here is one potential simplified version that would require changes on our current API.

Alternatives
There are several other alternatives we could follow. One of them could be to offer JIT-scriptable versions for a limited number of Transforms that are commonly used during inference. Another one could be to make some of our transforms FX-traceable instead of JIT-scriptable. Though not all classes can become traceable (because their behaviour branches based on the input), making them compatible where possible will future-proof us for PyTorch 2.
Additional context
No response
cc @vfdev-5 @bjuncek @pmeier