Make prototype F JIT-scriptable #6584


Merged: 25 commits, Sep 20, 2022

Conversation

Contributor

@datumbox commented Sep 14, 2022

Fixes #6553

  • Adds limited JIT tests for mid-level kernels
  • Makes mid-level kernels JIT-scriptable:
    • augment
    • colour
    • deprecated
    • geometry
    • meta
    • misc
    • type conversion

Co-authored with @vfdev-5.

Contributor Author

@datumbox left a comment

Some early notes:

@@ -32,6 +32,31 @@ def from_pil_mode(cls, mode: str) -> ColorSpace:
else:
return cls.OTHER

@staticmethod
def from_tensor_shape(shape: List[int]) -> ColorSpace:
Contributor Author

I think this makes more sense on the Enum than as a method elsewhere. In any case, it was moved because JIT didn't like it being a classmethod. We don't have to keep it on the enum; we just need to have it as a standalone method.
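To make the refactor being discussed concrete, here is a minimal sketch of the pattern, with hypothetical member names and values (the real `ColorSpace` enum lives in torchvision's prototype features module and may differ): a `staticmethod` takes no `cls` parameter, which is what keeps it TorchScript-friendly, whereas JIT rejected the original `classmethod` form.

```python
from enum import Enum
from typing import List


class ColorSpace(Enum):
    # Hypothetical members/values for illustration only.
    OTHER = 0
    GRAYSCALE = 1
    RGB = 3
    RGBA = 4

    # A staticmethod has no implicit `cls`/`self` argument, which is what
    # makes the equivalent torchvision helper scriptable; the classmethod
    # version was not.
    @staticmethod
    def from_tensor_shape(shape: List[int]) -> "ColorSpace":
        # Guess the colorspace from the number of channels (shape[-3] in CHW).
        num_channels = shape[-3] if len(shape) >= 3 else 1
        if num_channels == 1:
            return ColorSpace.GRAYSCALE
        elif num_channels == 3:
            return ColorSpace.RGB
        elif num_channels == 4:
            return ColorSpace.RGBA
        else:
            return ColorSpace.OTHER
```

Call sites stay identical (`ColorSpace.from_tensor_shape([3, 224, 224])`), so moving from a classmethod costs nothing at the API surface.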

Collaborator

So we can't even have classmethods on objects that are not covered by JIT? 🤯

Collaborator

Plus, I intentionally put the term "guess" in the name, since the number of channels is not sufficient to pick the right colorspace. For example, CMYK also has 4 channels, but would be classified as RGBA. However, this is not a problem now since we don't support it yet and maybe never will.

Contributor Author

I think the issue is that this specific class is definitely not JIT-scriptable. JIT complains about the tensor_contents: Any = None argument on __repr__. Perhaps we can remove this?

I'm happy to make any changes to the name, or to try a different approach with JIT. Let me wrap up the rest of the kernels to see where we are, and then we can try a couple of options.

Collaborator

If it works as is, I wouldn't mess with this class any further. If needed, we can remove the : Any though.

Contributor Author

@pmeier I'm also really flexible. I don't have context on why tensor_contents was introduced and what it's supposed to be. Would you like to send a PR once this is merged to see if you can make it work in the original location?

Collaborator

@pmeier left a comment

Few comments inline. Otherwise LGTM if CI is green.


@datumbox requested a review from pmeier on September 16, 2022 15:24
@datumbox force-pushed the jit/prototype_transforms branch from ef7130c to 60000fa on September 16, 2022 17:26
@datumbox force-pushed the jit/prototype_transforms branch from aafb356 to 1e301b1 on September 17, 2022 08:39
@vfdev-5 marked this pull request as ready for review on September 19, 2022 15:15
@vfdev-5 changed the title from "[WIP] Make prototype F JIT-scriptable" to "Make prototype F JIT-scriptable" on Sep 19, 2022
Collaborator

vfdev-5 commented Sep 19, 2022

@datumbox I'll let you review your PR once you are back :)

Contributor Author

@datumbox left a comment

@vfdev-5 I might not be the best person to "review" this PR, but I can confirm that the modifications done on your side look great. It was a good idea to do most of the work needed on fill in separate PRs. I also had a look at your PRs #6595 and #6599, and the direction looks great. Keeping F JIT-scriptable is a major win; thanks a lot for your work on making it possible.

@pmeier Could you please give us an independent review on this PR?

FYI: as follow-up work we should ensure that the signatures of F and _Feature are aligned. After merging this PR, the signatures of the dispatchers on the base class don't match (one uses Sequence and the other uses List).

@@ -72,11 +72,10 @@ def _apply_image_transform(

# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we have to put fill as None if fill == 0
fill_: Optional[Union[int, float, Sequence[int], Sequence[float]]]
# This is due to BC with stable API which has fill = None by default
fill_ = F._geometry._convert_fill_arg(fill)
Contributor Author

We might want to consider making _convert_fill_arg() public in the future, as it is regularly used in Transforms as a utility method.
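For context, a rough sketch of what such a fill-conversion utility does. This is a hypothetical simplification written for this discussion, not the actual torchvision implementation (the real `_convert_fill_arg` lives in the prototype `_geometry` module and may behave differently): it normalizes the user-facing fill value into the `Optional[List[float]]` shape a scriptable kernel can accept.

```python
from typing import List, Optional, Sequence, Union


def convert_fill_arg(
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]
) -> Optional[List[float]]:
    # Hypothetical helper: normalize fill into None or a List[float],
    # the only forms a TorchScript-annotated kernel can take.
    if fill is None:
        return None
    if isinstance(fill, (int, float)):
        return [float(fill)]
    return [float(v) for v in fill]
```

Making a helper like this public would let transforms accept the friendlier `Sequence` types at the boundary while keeping the kernels strictly typed for JIT.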

Collaborator

vfdev-5 commented Sep 20, 2022

FYI: follow up work we should do is ensuring the signatures of F and _Feature are aligned.

@datumbox should they match? F is now JIT-scriptable and _Feature is not, so the latter can also accept Sequence[int] or Sequence[float] for user API simplicity:

Image.pad(..., fill=(1, 2, 3))
# instead of only List[float]
Image.pad(..., fill=[1.0, 2.0, 3.0])

Do we want that? If yes, I can update this PR and change every _Feature fill type hint to Optional[Union[int, float, List[float]]] = None

Contributor Author

datumbox commented Sep 20, 2022

@vfdev-5 One of the original proposals was that the _Feature class and F would have methods with aligned signatures. The former would receive the maximal possible set of parameters (even those not used) and do a simple redirection to the low-level kernel. The only reason I proposed to change them is to ensure this is upheld. We can certainly do this in a follow-up PR where we confirm that all APIs are aligned overall. What do you think?

Collaborator

vfdev-5 commented Sep 20, 2022

@datumbox OK by me. @pmeier please review so we can merge this PR, and I'll update the _Feature signatures in a follow-up PR

Collaborator

@pmeier left a comment

Mostly questions from my side. LGTM if CI is green. Thanks @datumbox and @vfdev-5!

# Enum<__torch__.torchvision.transforms.functional.InterpolationMode> interpolation=Enum<InterpolationMode.BILINEAR>,
# Union(float[], float, int, NoneType) fill=None) -> Tensor
#
# This is probably due to the fact that F.perspective does not have the same signature as F.perspective_image_tensor
Collaborator

Yes, if the signatures actually diverge, then this is the issue? I'm ok with disabling this here for now and looking at it later.

Still, is this something we want? Shouldn't the public kernels be in sync with the dispatcher?

Contributor Author

I believe @vfdev-5 made this change. As discussed previously, there are other places where we need to align the signatures, and this can happen in a follow-up PR to avoid making this one too long.

Personally I think it's worth aligning the signatures unless there is a good reason not to (perhaps the interpolation default value is one exception?). I'm open to discussing this, and I think we should agree on the policy soon.

@@ -9,7 +9,6 @@


class TestCommon:
@pytest.mark.xfail(reason="dispatchers are currently not scriptable")
Collaborator

🎉

Comment on lines +476 to +477
def test_scriptable_midlevel(midlevel):
jit.script(midlevel) # TODO: pass data through it
Collaborator

@vfdev-5 Isn't that superseded by the tests I've added?

Contributor Author

Maybe? I noticed that at one point early in the development, one of the kernels had a return type of Any and this test didn't complain. This is why, to be safe, we usually take the strategy of:

  1. JIT-scripting
  2. Passing data through the kernel and compare it with the non-JIT version
  3. Serializing/deserializing the method and confirming it still returns the right value.

See _check_jit_scriptable() from test_models.py for more info.
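The three steps above can be sketched roughly as follows, assuming torch is available. The kernel and helper names here are hypothetical stand-ins for illustration; the real `_check_jit_scriptable()` in `test_models.py` is more thorough.

```python
import io

import torch


def brightness_kernel(image: torch.Tensor, factor: float) -> torch.Tensor:
    # Hypothetical stand-in for a mid-level kernel; full type annotations
    # are what allow TorchScript to compile it.
    return (image * factor).clamp(0.0, 1.0)


def check_jit_scriptable(fn, args):
    # 1. JIT-script the kernel.
    scripted = torch.jit.script(fn)
    # 2. Pass data through it and compare with the eager (non-JIT) output.
    expected = fn(*args)
    torch.testing.assert_close(scripted(*args), expected)
    # 3. Serialize/deserialize and confirm it still returns the right value.
    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)
    buffer.seek(0)
    reloaded = torch.jit.load(buffer)
    torch.testing.assert_close(reloaded(*args), expected)
    return expected


image = torch.rand(3, 8, 8)
check_jit_scriptable(brightness_kernel, (image, 0.5))
```

Step 1 alone would not have caught the `Any` return type mentioned above, which is why the data-passing and round-trip steps matter.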

Collaborator

@datumbox in test/test_prototype_transforms_dispatchers.py the test should script and execute the midlevel op, so I think we can remove this test_scriptable_midlevel

Collaborator

I'll cleanup.


@@ -252,7 +252,14 @@ def __init__(

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)

# This cast does Sequence[int] -> List[int] and is required to make mypy happy
Collaborator

Well, mypy is not at fault here. The kernel could support Sequence[int], but we annotated List[int]. Since that is more specific, mypy is right to complain. We can always silence mypy with a # type: ignore[...] comment or use cast(...), since both have no runtime implications. Actually converting to a list when it is not needed would hurt performance, although probably not by a significant amount.

Contributor Author

At this point, where we try to improve speed by any margin, I would prefer using ignore statements rather than casting. I don't know how negligible the introduction of cast() is, but if there is no measurable slow-down I don't mind switching.
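For context on the distinction being discussed: `typing.cast` is a pure type-checker annotation with no runtime effect, while an actual `list(...)` conversion allocates a copy. A small sketch with hypothetical function names:

```python
from typing import List, Sequence, cast


def pad_kernel(padding: List[int]) -> List[int]:
    # Hypothetical stand-in for a kernel annotated with List[int]
    # (the stricter type TorchScript requires).
    return padding


def dispatcher(padding: Sequence[int]) -> List[int]:
    # typing.cast only informs mypy; at runtime the original object is
    # passed through unchanged, so a tuple argument stays a tuple.
    return pad_kernel(cast(List[int], padding))


def dispatcher_copy(padding: Sequence[int]) -> List[int]:
    # list(...) also satisfies mypy, but pays for a real copy every call.
    return pad_kernel(list(padding))
```

A `# type: ignore[arg-type]` comment on the call would likewise leave the runtime untouched; the concern in this thread is only avoiding the real `list(...)` copy on hot paths.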

@datumbox merged commit 841b9a1 into pytorch:main on Sep 20, 2022
@datumbox deleted the jit/prototype_transforms branch on September 20, 2022 11:29
facebook-github-bot pushed a commit that referenced this pull request Sep 23, 2022
Summary:
* Improve existing low kernel test.

* Add new midlevel jit-scriptability test (failing).

* Remove duplicate aliases from kernel tests.

* Fixing colour kernels.

* Fixing deprecated kernels.

* fix mypy

* Silence mypy instead of fixing to avoid performance penalty

* Fixing augment kernels.

* Fixing augment meta.

* Remove is_tracing calls.

* Add fake ImageType and DType

* Fixing type conversion kernels.

* Fixing misc kernels.

* partial fix geometry

* Remove mutable default from `_pad_with_vector_fill()` + all other unnecessary defaults.

* Fix geometry ops

* Fixing tests

* Removed xfail for jit tests on midlevel ops

Reviewed By: NicolasHug

Differential Revision: D39765297

fbshipit-source-id: 50ec9dc9d9e2f9c8dab6ab01337e01643dc0ab64

Co-authored-by: vfdev-5 <[email protected]>
Successfully merging this pull request may close these issues.

RFC: Make prototype F jit-scriptable again?
4 participants