RFC: Make prototype F jit-scriptable again? #6553

Closed
datumbox opened this issue Sep 9, 2022 · 5 comments · Fixed by #6584

Comments

@datumbox
Contributor

datumbox commented Sep 9, 2022

🚀 The feature

The mid-level kernels F of the new Transforms API are currently not JIT-scriptable because JIT doesn't support Tensor subclassing. This is a major BC-breaking change that could potentially be avoided by refactoring our mid-level kernels.

Here is one way this could be achieved (_FT / _FP are the tensor / PIL functional backends and features is the prototype features module, as used in the proof of concept further down):

def kernel(x: torch.Tensor) -> torch.Tensor:
    if isinstance(x, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(x, features._Feature)):
        return _FT.kernel(x)  # image tensor dispatch
    elif isinstance(x, features._Feature):
        return x.kernel()  # _Feature dispatch
    else:
        return _FP.kernel(x)  # PIL dispatch

Effectively, the above makes kernel() operate exactly like the stable F: if a plain torch.Tensor input is provided, it is dispatched directly to the _FT kernel. The assumption is that while JIT-scripting, only image tensors are allowed. As long as the features.Image and the torch.Tensor kernels are identical (which is true in our implementation), there will be no discrepancies between the two types. The extended dispatch for the new tensor subclasses (Image, BoundingBox, Mask, Label, etc.) is only available in eager Python mode.

Unfortunately it's not possible to check the exact subtype of the Tensor at run-time and throw an appropriate error if a user accidentally passes a BoundingBox or Label while JIT-scripting. For these types, however, other checks (such as the dimension checks in the low-level kernels) will raise errors.
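For illustration, here is a minimal standalone sketch (hypothetical kernel, not the actual torchvision code) of how such a secondary check surfaces the error: the low-level tensor kernel validates the input shape, so a 2D bounding-box-like tensor fails loudly even though JIT cannot see its subtype.

import torch


def _invert_image_tensor(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical low-level image kernel: an explicit shape check stands in
    # for the subtype check that JIT cannot perform.
    if x.dim() < 3:
        raise ValueError("Expected an image tensor of shape (..., C, H, W)")
    bound = 255 if x.dtype == torch.uint8 else 1
    return bound - x


scripted = torch.jit.script(_invert_image_tensor)
scripted(torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8))  # image tensor: OK
# scripted(torch.tensor([[0, 0, 10, 10]]))  # box-like 2D tensor: raises ValueError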

Proof of Concept

The implementation below is a proof of concept. We intentionally make the kernels simulate different functionality in order to verify that the dispatch works as expected.

import numpy as np
import torch
from PIL import Image
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from torchvision.transforms.functional import to_pil_image


# Dummy low-level kernels; we only provide them for the image_* case here but could have added them for all types
def invertedmeanpix_image_tensor(x: torch.Tensor) -> float:
    return -_FT.invert(x).float().mean().item()  # return negative mean to simulate different functionality


@torch.jit.unused
def invertedmeanpix_image_pil(x: Image.Image) -> float:
    return np.mean(_FP.invert(x)) / 255.0  # return 0-1 scaled mean to simulate different functionality


# Dummy mid-level dispatcher which annotates the input as a plain Tensor to keep the kernel JIT-scriptable
def invertedmeanpix(x: torch.Tensor) -> float:
    if isinstance(x, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(x, features._Feature)):
        return invertedmeanpix_image_tensor(x)
    elif isinstance(x, features._Feature):
        return x.invert().float().mean().item()  # compute the plain mean to differentiate from the other 2 cases
    else:
        return invertedmeanpix_image_pil(x)


# JIT scripted kernel
invertedmeanpix_scripted = torch.jit.script(invertedmeanpix)
print(invertedmeanpix_scripted.code)

# Create dummy data
image_tensor = torch.randint(0, 256, (3, 30, 20), dtype=torch.uint8)
image_feature = features.Image(image_tensor)
image_pil = to_pil_image(image_tensor)

bbox_feature = features.BoundingBox(
    torch.tensor([[12, 13, 19, 18], [1, 15, 8, 19]]),
    format=features.BoundingBoxFormat.XYXY,
    image_size=image_tensor.shape[-2:],
)

label_data = torch.tensor([1, 2])
label_feature = features.Label(label_data)
ohelabel_data = torch.nn.functional.one_hot(label_data, 3)
ohelabel_feature = features.OneHotLabel(ohelabel_data)

segmask_data = torch.zeros(1, 30, 20)
segmask_data[0, 13:19, 12:18] = 1
segmask_data[0, 15:19, 1:8] = 2
segmask_feature = features.SegmentationMask(segmask_data)

detmask_data = torch.zeros(2, 30, 20)
detmask_data[0, 13:19, 12:18] = 1
detmask_data[1, 15:19, 1:8] = 1
detmask_feature = features.SegmentationMask(detmask_data)


# Assertions
value = invertedmeanpix(image_tensor)
assert value < 0
torch.testing.assert_close(invertedmeanpix(image_feature), -value)
torch.testing.assert_close(invertedmeanpix(image_pil), -value / 255.0)
torch.testing.assert_close(invertedmeanpix_scripted(image_tensor), value)
torch.testing.assert_close(invertedmeanpix(label_feature), label_data.float().mean().item())
torch.testing.assert_close(invertedmeanpix(ohelabel_feature), ohelabel_data.float().mean().item())
torch.testing.assert_close(invertedmeanpix(segmask_feature), segmask_data.mean().item())
torch.testing.assert_close(invertedmeanpix(detmask_feature), detmask_data.mean().item())
print("OK")

Output (note that the scripted code retains only the image-tensor branch; the _Feature and PIL dispatches have been pruned):

def invertedmeanpix(x: Tensor) -> float:
  _0 = __torch__.invertedmeanpix_image_tensor
  x0 = unchecked_cast(Tensor, x)
  x1 = unchecked_cast(Tensor, x0)
  return _0(x1, )

OK

cc @vfdev-5 @pmeier

@datumbox
Contributor Author

datumbox commented Sep 9, 2022

I would love the expert opinion from @gmagogsfm, @suo or @eellison who know JIT inside-out to ensure this approach will work and won't lead to some weird behaviour.

Update:
I had an offline chat with Yanan Cao and he confirmed that this approach will work. He also confirmed that there is no way to detect the subtype of the tensor (_Feature, Image, BoundingBox, etc.) at JIT runtime, as these are all seen as torch.Tensor by JIT.

@pmeier
Collaborator

pmeier commented Sep 12, 2022

Other than this being dead ugly, I don't have any objections to it. Can we at least wrap

isinstance(x, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(x, features._Feature))

into a reusable function or does that not work with scripting?

Our old transformations also check for torch.jit.is_tracing:

if not torch.jit.is_scripting() and not torch.jit.is_tracing():

Do we need to include this here as well?
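
For concreteness, something like the following (hypothetical name; as noted further down, JIT may not accept the dispatcher once the check is factored out like this):

import torch
from torchvision.prototype import features


def _is_plain_tensor(x) -> bool:
    # Reusable version of the dispatch check, including the tracing case.
    return isinstance(x, torch.Tensor) and (
        torch.jit.is_scripting()
        or torch.jit.is_tracing()
        or not isinstance(x, features._Feature)
    )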

@datumbox
Contributor Author

Can we at least wrap ... into a reusable function or does that not work with scripting?

I'll give it a try. If JIT is OK with it, I'll create a util method.

Do we need to include this here as well?

Yeah, we probably should. We should also add the loggers.
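
For reference, the stable transforms guard the logging call the same way; a rough sketch of what this could look like in the new mid-level kernels, assuming the existing torchvision.utils._log_api_usage_once helper:

import torch
from torchvision.utils import _log_api_usage_once


def kernel(x: torch.Tensor) -> torch.Tensor:
    # Logging is skipped under scripting/tracing, mirroring the stable transforms.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(kernel)
    # ... dispatch as in the snippet from the issue description ...
    return x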

@datumbox
Contributor Author

datumbox commented Sep 16, 2022

@pmeier FYI I tried moving the lengthy if statement into a utility method but JIT doesn't like it, presumably because TorchScript only prunes the unscriptable _Feature branch when torch.jit.is_scripting() appears directly in the condition. As far as I recall this is the same reason we couldn't put the logging into a utility method and instead have this ugly method call. It doesn't even like it if I move torch.jit.is_scripting() or torch.jit.is_tracing() to a separate _is_jit() method.

Perhaps we can find a way to simplify but for now I'm proceeding without it.

Edit: it seems that it doesn't even allow me to move the torch.jit.is_scripting() or torch.jit.is_tracing() into a variable. That's pretty odd.

@datumbox
Contributor Author

As @vfdev-5 noticed, we probably shouldn't support tracing as many of our transforms are not made to support it. I will thus remove it from the check.
