RFC: Make prototype F jit-scriptable again? #6553
Comments
I would love the expert opinion from @gmagogsfm, @suo or @eellison, who know JIT inside-out, to ensure this approach will work and won't lead to some weird behaviour.
Other than this being dead ugly, I don't have any objections to it. Can we at least wrap `isinstance(x, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(x, features._Feature))` into a reusable function, or does that not work with scripting? Our old transforms also check for tracing. Do we need to include this here as well?
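A minimal sketch of the kind of reusable helper being suggested here, using a stand-in `_Feature` class and a made-up name `is_simple_tensor` so the snippet is self-contained; whether JIT actually accepts factoring the check out like this is what the follow-up comments below report on:

```python
import torch


class _Feature(torch.Tensor):
    # stand-in for features._Feature, just to keep this sketch self-contained
    pass


def is_simple_tensor(inpt) -> bool:
    # True when the input should take the plain-tensor path: a torch.Tensor that is
    # not one of the new Tensor subclasses, or any Tensor while JIT-scripting.
    return isinstance(inpt, torch.Tensor) and (
        torch.jit.is_scripting() or not isinstance(inpt, _Feature)
    )
```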
I'll give it a try. If JIT is OK with it, I'll create a util method.
Yeah, we probably should. And we should add the loggers as well.
@pmeier FYI I tried moving the lengthy if statement into a utility method, but JIT doesn't like it. As far as I recall, this is the reason we couldn't put the logging into a utility method either and instead have this ugly method call. It doesn't even like it if I move … Perhaps we can find a way to simplify, but for now I'm proceeding without it. Edit: it seems that it doesn't even allow me to move the …
As @vfdev-5 noticed, we probably shouldn't support tracing, as many of our transforms are not made to support it. I will thus remove it from the check.
🚀 The feature
The mid-layer kernels of the new Transforms API (`F`) are currently not JIT-scriptable because JIT doesn't support Tensor subclassing. This is a major BC-breaking change that could potentially be avoided by refactoring our mid-level kernels. Here is one way this could be achieved:
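(The original snippet is not reproduced here; the following is a minimal sketch of the kind of dispatch being described, using placeholder names such as `kernel`, `_FT_kernel`, and a stand-in `Image` class in place of the real modules.)

```python
from typing import List

import torch
from torch import Tensor


class Image(Tensor):
    # plays the role of _features.Image (a Tensor subclass) in this sketch
    pass


def _FT_kernel(image: Tensor, size: List[int]) -> Tensor:
    # plays the role of the low-level plain-tensor kernel in _FT; stays scriptable.
    # Expects a float (C, H, W) image.
    return torch.nn.functional.interpolate(image.unsqueeze(0), size=size).squeeze(0)


def kernel(inpt: Tensor, size: List[int]) -> Tensor:
    # Mid-level kernel of F: while scripting, or for plain tensors, dispatch straight
    # to the _FT kernel; the Tensor-subclass dispatch only exists in eager Python mode.
    if isinstance(inpt, Tensor) and (
        torch.jit.is_scripting() or not isinstance(inpt, Image)
    ):
        return _FT_kernel(inpt, size)
    if isinstance(inpt, Image):
        return _FT_kernel(inpt, size).as_subclass(Image)
    raise TypeError(f"kernel() does not support inputs of type {type(inpt)}")
```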
Effectively, the above makes `kernel()` operate exactly as before on the stable `F`: if a plain `torch.Tensor` is provided to the kernel, it is dispatched directly to the `_FT` kernel. The assumption is that while JIT-scripting, only image tensors are allowed. As long as the `_features.Image` and the `torch.Tensor` kernels are identical (which is true in our implementation), we won't have any discrepancies between the two types. The extension to the new tensor subclasses (Image, BBox, Mask, Label, etc.) is only available in Python mode.

Unfortunately, it's not possible to check the exact type of Tensor at run-time and throw an appropriate error if a user accidentally passes a BBox or Label during JIT-scripting. But for these types, other checks (such as the dimension checks) will raise errors.
Proof of Concept
The below implementation is a proof of concept. We intentionally make the kernels simulate different functionality in order to test that the dispatches work as expected.
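The original PoC code is not reproduced in this copy; below is a self-contained sketch in the same spirit, with made-up kernels (`add_one_tensor`, `add_one_image`) that deliberately return different values so the chosen path is visible, and a guarded `torch.jit.script` call, since scriptability is exactly what the PoC needs to confirm.

```python
import torch
from torch import Tensor


class Image(Tensor):
    # stand-in for the _features.Image Tensor subclass
    pass


def add_one_tensor(inpt: Tensor) -> Tensor:
    # plain-tensor kernel; simulated behaviour: add 1
    return inpt + 1


def add_one_image(inpt: Tensor) -> Tensor:
    # Image kernel; simulated *different* behaviour (add 10) so the dispatch is visible
    return inpt + 10


def add_one(inpt: Tensor) -> Tensor:
    # mid-level dispatcher following the scheme sketched above
    if isinstance(inpt, Tensor) and (
        torch.jit.is_scripting() or not isinstance(inpt, Image)
    ):
        return add_one_tensor(inpt)
    return add_one_image(inpt)


if __name__ == "__main__":
    t = torch.zeros(1)
    img = torch.zeros(1).as_subclass(Image)

    print(add_one(t).item())    # plain-tensor path -> 1.0
    print(add_one(img).item())  # Image path        -> 10.0

    # Whether this compiles is the open question of the RFC, so the call is guarded.
    try:
        scripted = torch.jit.script(add_one)
        print(scripted(t).item())  # scripted path uses the plain-tensor kernel
    except Exception as exc:
        print(f"scripting failed: {exc}")
```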
cc @vfdev-5 @pmeier