Clean up prototype transforms before migration #5626
This comment only applies in case we decide to reinstate the high-level dispatchers. Leaving this here so it doesn't get lost.

Since torchscript is statically compiled, we need special handling whenever we rely on duck typing. The most notable example is sequences. I propose to use "correct" annotations for the low-level kernels. The only users affected by this are JIT users, since everyone else either uses the higher-level abstractions or relies on duck typing in the low-level kernels, which still works. That should clean up our code quite a bit.

Somewhat more controversially, I would also remove the type checks from the low-level kernels. Something like `if not isinstance(arg, (tuple, list)): raise ValueError(...)` is very restrictive, since the kernel would work with any sequence. It should be …
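To illustrate the duck-typing point, here is a minimal stdlib-only sketch (the names `Pair`, `kernel_strict`, and `kernel_duck` are invented for illustration, not torchvision APIs):

```python
import collections.abc
from typing import Sequence


class Pair(collections.abc.Sequence):
    """A hypothetical user-defined sequence holding two floats."""

    def __init__(self, low: float, high: float):
        self._items = (low, high)

    def __getitem__(self, index):
        return self._items[index]

    def __len__(self) -> int:
        return len(self._items)


def kernel_strict(arg) -> float:
    # The restrictive check argued against above: it rejects perfectly
    # usable sequences that happen to be neither tuple nor list.
    if not isinstance(arg, (tuple, list)):
        raise ValueError(f"expected tuple or list, got {type(arg).__name__}")
    return sum(arg)


def kernel_duck(arg: Sequence[float]) -> float:
    # Duck typing: anything iterable with float items works.
    return sum(arg)


p = Pair(0.5, 1.5)
print(kernel_duck(p))  # works: 2.0
# kernel_strict(p) would raise ValueError, although Pair behaves like a sequence
```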
This might no longer be relevant if #6584 is successful. Keeping it open in case that fails; we can close later.
Regardless of #6584, the transforms will not be JIT scriptable. Thus, cleaning them up is still relevant.
Good point, it's not only about the functional, it's also about the …
We discussed this offline and the consensus is that we can have a go at it for the constructors.
I wrestled with this the last few days and the result is somewhat sobering. Imagine we currently have a transform like

```python
import collections.abc
from typing import *

import torch
from torchvision.prototype import transforms


class Transform(transforms.Transform):
    def __init__(self, range: List[float]):
        super().__init__()
        if not (
            isinstance(range, collections.abc.Sequence)
            and all(isinstance(item, float) for item in range)
            and len(range) == 2
            and range[0] < range[1]
        ):
            raise TypeError(f"`range` should be a sequence of two increasing float values, but got {range}")
        self._dist = torch.distributions.Uniform(range[0], range[1])
```

This works perfectly fine for … As discussed above, … Wrong 😞 Unfortunately, … Still, this makes it really hard to use the right annotations here while keeping the type checks. Switching the annotation to …
The consequence from … is … I currently see three ways forward here:

1. …
2. …
3. …
Thoughts?
I would go with option 3. IMO the current typing annotations provide sufficient info without introducing complex solutions like option 2.
The transforms in `torchvision.transforms` are fully JIT scriptable. To achieve this, we needed a lot of helper methods and unnecessarily strict or plain wrong annotations. The revamped transforms in `torchvision.prototype.transforms` will no longer be JIT scriptable due to their ability to handle more complex inputs. Thus, we should clean up our code, removing everything that was added just to appease `torch.jit.script`. This should only happen as one of the last steps before the migration to the main area.

cc @vfdev-5 @datumbox @bjuncek