RFC-0001: Add method __torch_function__ RFC. #3
Conversation
I would have expected a discussion of backwards compatibility on this proposal, both with […]

RFC 0001.md (Outdated):

```python
class SubTensor(torch.Tensor):
    def __torch_tensor__(self, func, types, args, kwargs):
```
---
You don't say what the old type of `__torch_tensor__` was, so I can't tell what the difference is.

@ngoldbaum do you remember why we didn't line up the type exactly with NumPy's type in the beginning?
---
IIRC that was me. Not 100% sure, but I believe I removed (or never added) `types` because it was extra complexity to handle subclasses, and `Tensor` subclasses aren't really a thing today.
---
No reason afaik. That API got hammered out before I started working on the feature, so I don't know why the API is different. It wouldn't be terribly hard to add `types` to the signature if we wanted to do that.
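To make the signature difference concrete, here is a pure-Python sketch of the protocol with the `types` argument added, mirroring NumPy's `__array_function__`. The `implements_dispatch` helper and `ScalarTensor` class are hypothetical stand-ins for PyTorch's real dispatch machinery, not actual PyTorch code:

```python
def implements_dispatch(func):
    """Toy stand-in for the dispatcher: collect the argument types that
    define __torch_function__ and try each implementation in turn."""
    def wrapper(*args, **kwargs):
        types = tuple({type(a) for a in args
                       if hasattr(type(a), "__torch_function__")})
        for a in args:
            if hasattr(type(a), "__torch_function__"):
                result = a.__torch_function__(wrapper, types, args, kwargs)
                if result is not NotImplemented:
                    return result
        raise TypeError("no implementation found")
    return wrapper

@implements_dispatch
def add(x, y):
    pass  # body unused here; the protocol supplies the implementation

class ScalarTensor:
    """Hypothetical duck tensor using the post-RFC signature,
    __torch_function__(self, func, types, args, kwargs)."""
    def __init__(self, value):
        self.value = value

    def __torch_function__(self, func, types, args, kwargs):
        # `types` lets an implementation see every participating type
        # and bail out cleanly when it doesn't recognize one of them.
        if not all(issubclass(t, ScalarTensor) for t in types):
            return NotImplemented
        if func is add:
            return ScalarTensor(args[0].value + args[1].value)
        return NotImplemented

print(add(ScalarTensor(1), ScalarTensor(2)).value)  # 3
```

Without `types`, an implementation has no direct way to inspect who else is participating before deciding whether to handle the call.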
---
I think it's safe to say we should do this sooner rather than later, even if other parts of this RFC change in their design.
---
As I wrote over in pytorch/pytorch#30730 (comment), my motivation for the `types` argument in NumPy's `__array_function__` was never about subclasses (which I agree are generally awful). My concern was making it as easy and idiomatic as possible for implementers of `__array_function__` to defer to unrecognized types that might also implement an operation. This is generally a good thing for the ecosystem, but people writing special methods tend not to bother, at least for builtin numeric protocols like `__add__` :).
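A pure-Python sketch of that deferral idiom (the `DuckA`/`DuckB` classes and the toy `dispatch` function are hypothetical, not NumPy or PyTorch internals): an implementation that checks `types` can return `NotImplemented` so another participating type gets a chance, instead of failing on arguments it does not understand.

```python
class DuckA:
    """Hypothetical duck array that only understands itself."""
    def __array_function__(self, func, types, args, kwargs):
        # Defer whenever an unrecognized type participates, rather than
        # blindly operating on arguments it doesn't understand.
        if not all(t is DuckA for t in types):
            return NotImplemented
        return "handled by DuckA"

class DuckB:
    """Hypothetical duck array that also knows how to handle DuckA."""
    def __array_function__(self, func, types, args, kwargs):
        if not all(t in (DuckA, DuckB) for t in types):
            return NotImplemented
        return "handled by DuckB"

def dispatch(func, *args):
    """Toy dispatcher: try each distinct argument type in order until
    one implementation does not return NotImplemented."""
    types = tuple(dict.fromkeys(type(a) for a in args))
    for t in types:
        impl = getattr(t, "__array_function__", None)
        if impl is not None:
            instance = next(a for a in args if type(a) is t)
            result = impl(instance, func, types, args, {})
            if result is not NotImplemented:
                return result
    raise TypeError("no implementation found")

# Mixed call: DuckA defers via NotImplemented, DuckB picks it up.
print(dispatch(None, DuckA(), DuckB()))  # handled by DuckB
```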
---
> was never about subclasses (which I agree are generally awful)

I kind of disagree. `numpy.ndarray` subclasses leave a bad taste in the mouth because of the history of `numpy.matrix` and other badly written subclasses, but in principle they make sense. Also, for PyTorch there's a lot of interest and the use cases are solid - and letting those users write a whole `torch.Tensor`-like object is definitely not feasible in most cases.
(although doing so would break any downstream users who are out there)

Procedural note: let's put a more descriptive filename on RFCs, so we can easily tell what's what :)
RFC 0001.md (Outdated):

> PyTorch `master` pointed to commit hash `957a07ffbd13d8a805f4d718e0282efc5d2bff85` at the time of writing. Any classes implementing `__torch_function__` based on the usage in this commit hash will break completely, due to the differing signature of the protocol. However, as a release hasn't been made with `__torch_function__` in it, this is a minor-impact issue.
>
> ### With NumPy
>
> As we are using a different protocol compared to NumPy (`__torch_function__` vs `__array_function__`), there is no difference to the usage for those using NumPy. We propose to delay the issue of allowing the usage of Torch tensors with NumPy functions to a separate RFC.
---
Sorry, I wasn't clear in my earlier comment. My question is not about allowing PyTorch tensors to be used with NumPy functions (although this is an interesting question to pose), but: if I am a NumPy user who is familiar with the `__array_function__` API, and I come to PyTorch expecting `__torch_function__` to work the same way, will my expectations be violated in any way?
---
I don't think there are more than a handful of users who have built up an intuition here, since it's so new; nor do I think including methods will be particularly surprising.
RFC 0001 — `__torch_function__` for methods of the `torch.Tensor` class.md (Outdated)
Umm, the new filename is better, but can we avoid putting hard-to-quote characters in the filename? :> Preferably no spaces either.
This is very interesting, thanks for the RFC. Am I right to assume this could go a long way in helping enable something like a […]
@Balandat I cannot comment more broadly, but certainly, the code example you show in the issue description should be possible; with my understanding of the problem it should all be possible.
As discussed in pytorch/pytorch#34369.
Summary: This PR adds the `types` argument to `__torch_function__` as per RFC 0001: pytorch/rfcs#3 Pull Request resolved: #34303 Differential Revision: D20474992 Pulled By: ezyang fbshipit-source-id: cdd40b3b38f3bda4ece8812a629f5db87e919d01
Summary: This is according to pytorch/rfcs#3. Pull Request resolved: #34369 Differential Revision: D20963929 Pulled By: ezyang fbshipit-source-id: e618af6fd36e1dfaeda617162314ad5840f55358
Summary: According to pytorch/rfcs#3. From the goals in the RFC:

1. Support subclassing `torch.Tensor` in Python (done here)
2. Preserve `torch.Tensor` subclasses when calling `torch` functions on them (done here)
3. Use the PyTorch API with `torch.Tensor`-like objects that are _not_ `torch.Tensor` subclasses (done in #30730)
4. Preserve `torch.Tensor` subclasses when calling `torch.Tensor` methods. (done here)
5. Propagating subclass instances correctly also with operators, using views/slices/indexing/etc. (done here)
6. Preserve subclass attributes when using methods or views/slices/indexing. (done here)
7. A way to insert code that operates on both functions and methods uniformly (so we can write a single function that overrides all operators). (done here)
8. The ability to give external libraries a way to also define functions/methods that follow the `__torch_function__` protocol. (will be addressed in a separate PR)

This PR makes the following changes:

1. Adds the `self` argument to the arg parser.
2. Dispatches on `self` as well if `self` is not `nullptr`.
3. Adds a `torch._C.DisableTorchFunction` context manager to disable `__torch_function__`.
4. Adds a `torch::torch_function_enabled()` and `torch._C._torch_function_enabled()` to check the state of `__torch_function__`.
5. Dispatches all `torch._C.TensorBase` and `torch.Tensor` methods via `__torch_function__`.

TODO:

- [x] Sequence Methods
- [x] Docs
- [x] Tests

Closes #28361. Benchmarks in #37091 (comment). Pull Request resolved: #37091. Reviewed By: ngimel. Differential Revision: D22765678. Pulled By: ezyang. fbshipit-source-id: 53f8aa17ddb8b1108c0997f6a7aa13cb5be73de0
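The subclass-preservation goals (1, 2, and 4 above) can be illustrated with a pure-Python sketch. The `Tensor` base class, `UnitTensor` subclass, and toy `dispatch` function below are hypothetical stand-ins, not PyTorch internals:

```python
class Tensor:
    """Hypothetical base class; not the real torch.Tensor."""
    def __init__(self, data):
        self.data = data

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        # Run the raw function on unwrapped data, then re-wrap the result
        # in the subclass that initiated the call (goals 2 and 4).
        raw = func(*(a.data if isinstance(a, Tensor) else a for a in args),
                   **kwargs)
        return cls(raw)

def dispatch(func, *args, **kwargs):
    """Toy dispatcher: hand control to the first tensor argument's
    __torch_function__, passing the tuple of participating types."""
    types = tuple(dict.fromkeys(type(a) for a in args
                                if isinstance(a, Tensor)))
    return types[0].__torch_function__(func, types, args, kwargs)

class UnitTensor(Tensor):
    """Hypothetical subclass (goal 1); preserved without extra code."""

def double(x):
    return x * 2

out = dispatch(double, UnitTensor(21))
print(type(out).__name__, out.data)  # UnitTensor 42
```

The point is that a subclass needs no per-function overrides: the default hook re-wraps results in the subclass type automatically.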
@hameerabbasi now that the code landed, can you make final updates to this RFC to adjust to reality? Then let's merge it.
* Document that `__torch_function__` may get methods passed to it even for non-subclasses.
* Document the `__getattr__` idiom.
* Add the double inheritance hierarchy diagram to the docs.
* Explain how to have a fallback route for things that you don't explicitly override for subclasses.
* Explain how to override single methods vs have a global hook.
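The last item (one global hook versus overriding single methods) can be sketched in pure Python. Everything here is hypothetical illustration, not PyTorch internals: a base class routes its operators through `__torch_function__`, so a subclass can intercept every operation in one place instead of overriding each method:

```python
class Tensor:
    """Hypothetical base class that routes methods through the hook."""
    def __init__(self, data):
        self.data = data

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        # Unwrap arguments, run the raw implementation, re-wrap in cls.
        raw = func(*(a.data if isinstance(a, Tensor) else a for a in args),
                   **kwargs)
        return cls(raw)

    @staticmethod
    def _raw_add(a, b):
        return a + b

    def __add__(self, other):
        # Operators also go through __torch_function__, so a single hook
        # sees functions and methods uniformly.
        return type(self).__torch_function__(
            Tensor._raw_add, (type(self),), (self, other), {})

class LoggingTensor(Tensor):
    """Global hook: log every intercepted call, then defer to the base."""
    calls = []

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        cls.calls.append(func.__name__)
        return super().__torch_function__(func, types, args, kwargs)

t = LoggingTensor(1) + LoggingTensor(2)
print(type(t).__name__, t.data, LoggingTensor.calls)
```

Overriding `__add__` alone would catch only that operator; the single `__torch_function__` hook observes every routed call.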
Okay, let's finally get this in! Thanks @hameerabbasi and @ezyang

Has this RFC been implemented? Related: pytorch/pytorch#52265
Yup! |