
Commit b8792c0

ezyang authored and facebook-github-bot committed
Revert D18645954: add __torch_function__ API override mechanism
Test Plan: revert-hammer

Differential Revision: D18645954

Original commit changeset: 54b5e4344d7a

fbshipit-source-id: 4a7aebb483e6b001130d6f384ccc53c5a808ab13
1 parent a68b790 commit b8792c0

12 files changed: +61 -1961 lines

docs/source/notes/extending.rst

Lines changed: 1 addition & 280 deletions
@@ -2,7 +2,7 @@ Extending PyTorch
 =================
 
 In this note we'll cover ways of extending :mod:`torch.nn`,
-:mod:`torch.autograd`, :mod:`torch`, and writing custom C extensions utilizing our C
+:mod:`torch.autograd`, and writing custom C extensions utilizing our C
 libraries.
 
 Extending :mod:`torch.autograd`
@@ -204,285 +204,6 @@ This is how a ``Linear`` module can be implemented::
             self.in_features, self.out_features, self.bias is not None
         )
 
-Extending :mod:`torch`
-----------------------
-
-You can create custom types that emulate :class:`Tensor` by defining a custom
-class with methods that match :class:`Tensor`. But what if you want to be able
-to pass these types to functions like :func:`torch.add` in the top-level
-:mod:`torch` namespace that accept :class:`Tensor` operands?
-
-If your custom python type defines a method named ``__torch_function__``, PyTorch
-will invoke your ``__torch_function__`` implementation when an instance of your
-custom class is passed to a function in the :mod:`torch` namespace. This makes
-it possible to define custom implementations for any of the functions in the
-:mod:`torch` namespace which your ``__torch_function__`` implementation can call,
-allowing your users to make use of your custom type with existing PyTorch
-workflows that they have already written for :class:`Tensor`. This works with
-"duck" types that are unrelated to :class:`Tensor` as well as user-defined
-subclasses of :class:`Tensor`.
-
-Extending :mod:`torch` with a :class:`Tensor`-like type
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-.. note:: This functionality is inspired by the NumPy ``__array_function__``
-          protocol. See `the NumPy documentation
-          <https://docs.scipy.org/doc/numpy/user/basics.dispatch.html#basics-dispatch>`_
-          and `NEP-0018
-          <https://numpy.org/neps/nep-0018-array-function-protocol.html>`_ for
-          more details.
-
-To make this concrete, let's begin with a simple example that illustrates the
-API dispatch mechanism. We'll create a custom type that represents a 2D scalar
-tensor, parametrized by the order ``N`` and value along the diagonal entries,
-``value``::
-
-    class ScalarTensor(object):
-        def __init__(self, N, value):
-            self._N = N
-            self._value = value
-
-        def __repr__(self):
-            return "ScalarTensor(N={}, value={})".format(self._N, self._value)
-
-        def tensor(self):
-            return self._value * torch.eye(self._N)
-
-This first iteration of the design isn't very useful. The main functionality of
-``ScalarTensor`` is to provide a more compact string representation of a scalar
-tensor than in the base tensor class::
-
-    >>> d = ScalarTensor(5, 2)
-    >>> d
-    ScalarTensor(N=5, value=2)
-    >>> d.tensor()
-    tensor([[2., 0., 0., 0., 0.],
-            [0., 2., 0., 0., 0.],
-            [0., 0., 2., 0., 0.],
-            [0., 0., 0., 2., 0.],
-            [0., 0., 0., 0., 2.]])
-
-If we try to use this object with the :mod:`torch` API, we will run
-into issues::
-
-    >>> import torch
-    >>> torch.mean(d)
-    TypeError: mean(): argument 'input' (position 1) must be Tensor, not ScalarTensor
-
-Adding a ``__torch_function__`` implementation to ``ScalarTensor`` makes it
-possible for the above operation to succeed. Let's re-do our implementation,
-this time adding a ``__torch_function__`` implementation::
-
-    HANDLED_FUNCTIONS = {}
-    class ScalarTensor(object):
-        def __init__(self, N, value):
-            self._N = N
-            self._value = value
-
-        def __repr__(self):
-            return "ScalarTensor(N={}, value={})".format(self._N, self._value)
-
-        def tensor(self):
-            return self._value * torch.eye(self._N)
-
-        def __torch_function__(self, func, args=(), kwargs=None):
-            if kwargs is None:
-                kwargs = {}
-            if func not in HANDLED_FUNCTIONS:
-                return NotImplemented
-            return HANDLED_FUNCTIONS[func](*args, **kwargs)
-
-The ``__torch_function__`` method takes three arguments: ``func``, a reference to
-the torch API function that is being overridden, ``args``, the tuple of arguments
-passed to the function, and ``kwargs``, the dict of keyword arguments passed to
-the function. It uses a global dispatch table named ``HANDLED_FUNCTIONS`` to
-store custom implementations. The keys of this dictionary are functions in the
-``torch`` namespace and the values are implementations for ``ScalarTensor``.
-
-.. note:: Using a global dispatch table is not a mandated part of the
-          ``__torch_function__`` API, it is just a useful design pattern for
-          structuring your override implementations.
-
-This class definition isn't quite enough to make ``torch.mean`` do the right
-thing when we pass it a ``ScalarTensor`` -- we also need to define an
-implementation for ``torch.mean`` for ``ScalarTensor`` operands and add the
-implementation to the ``HANDLED_FUNCTIONS`` dispatch table dictionary. One way
-of doing this is to define a decorator::
-
-    import functools
-    def implements(torch_function):
-        """Register a torch function override for ScalarTensor"""
-        @functools.wraps(torch_function)
-        def decorator(func):
-            HANDLED_FUNCTIONS[torch_function] = func
-            return func
-        return decorator
-
-which can be applied to the implementation of our override::
-
-    @implements(torch.mean)
-    def mean(input):
-        return float(input._value) / input._N
-
-With this change we can now use ``torch.mean`` with ``ScalarTensor``::
-
-    >>> d = ScalarTensor(5, 2)
-    >>> torch.mean(d)
-    0.4
-
-Of course ``torch.mean`` is an example of the simplest kind of function to
-override since it only takes one operand. We can use the same machinery to
-override a function that takes more than one operand, any one of which might be
-a tensor or tensor-like that defines ``__torch_function__``, for example for
-:func:`torch.add`::
-
-    def ensure_tensor(data):
-        if isinstance(data, ScalarTensor):
-            return data.tensor()
-        return torch.as_tensor(data)
-
-    @implements(torch.add)
-    def add(input, other):
-        try:
-            if input._N == other._N:
-                return ScalarTensor(input._N, input._value + other._value)
-            else:
-                raise ValueError("Shape mismatch!")
-        except AttributeError:
-            return torch.add(ensure_tensor(input), ensure_tensor(other))
-
-This version has a fast path for when both operands are ``ScalarTensor``
-instances and also a slower path which degrades to converting the data to
-tensors when either operand is not a ``ScalarTensor``. That makes the override
-function correctly when either operand is a ``ScalarTensor`` or a regular
-:class:`Tensor`::
-
-    >>> s = ScalarTensor(2, 2)
-    >>> torch.add(s, s)
-    ScalarTensor(N=2, value=4)
-    >>> t = torch.tensor([[1, 1,], [1, 1]])
-    >>> torch.add(s, t)
-    tensor([[3., 1.],
-            [1., 3.]])
-
-Note that our implementation of ``add`` does not take ``alpha`` or ``out`` as
-keyword arguments like :func:`torch.add` does::
-
-    >>> torch.add(s, s, alpha=2)
-    TypeError: add() got an unexpected keyword argument 'alpha'
-
-For speed and flexibility the ``__torch_function__`` dispatch mechanism does not
-check that the signature of an override function matches the signature of the
-function being overridden in the :mod:`torch` API. For some applications ignoring
-optional arguments would be fine but to ensure full compatibility with
-:class:`Tensor`, user implementations of torch API functions should take care to
-exactly emulate the API of the function that is being overridden.
-
-Functions in the :mod:`torch` API that do not have explicit overrides will
-return ``NotImplemented`` from ``__torch_function__``. If all operands with
-``__torch_function__`` defined on them return ``NotImplemented``, PyTorch will
-raise a ``TypeError``. This means that most of the time operations that do not
-have explicit overrides for a type will raise a ``TypeError`` when an instance
-of such a type is passed::
-
-    >>> torch.mul(s, 3)
-    TypeError: no implementation found for 'torch.mul' on types that
-    implement __torch_function__: [ScalarTensor]
-
-In practice this means that if you would like to implement your overrides using
-a ``__torch_function__`` implementation along these lines, you will need to
-explicitly implement the full :mod:`torch` API or the entire subset of the API
-that you care about for your use case. This may be a tall order as the full
-:mod:`torch` API is quite extensive.
-
-Another option is to not return ``NotImplemented`` for operations that are not
-handled but to instead pass a :class:`Tensor` to the original :mod:`torch`
-function when no override is available. For example, if we change our
-implementation of ``__torch_function__`` for ``ScalarTensor`` to the one below::
-
-    def __torch_function__(self, func, args=(), kwargs=None):
-        if kwargs is None:
-            kwargs = {}
-        if func not in HANDLED_FUNCTIONS:
-            args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
-            return func(*args, **kwargs)
-        return HANDLED_FUNCTIONS[func](*args, **kwargs)
-
-Then :func:`torch.mul` will work correctly, although the return type will always
-be a :class:`Tensor` rather than a :class:`ScalarTensor`, even if both operands
-are :class:`ScalarTensor` instances::
-
-    >>> s = ScalarTensor(2, 2)
-    >>> torch.mul(s, s)
-    tensor([[4., 0.],
-            [0., 4.]])
-
-Also see the ``MetadataTensor`` example below for another variation on this
-pattern that instead always returns a ``MetadataTensor`` to propagate metadata
-through operations in the :mod:`torch` API.
-
-Extending :mod:`torch` with a :class:`Tensor` wrapper type
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Another useful case is a type that wraps a :class:`Tensor`, either as an
-attribute or via subclassing. Below we implement a special case of this sort of
-type, a ``MetadataTensor`` that attaches a dictionary of metadata to a
-:class:`Tensor` that is propagated through :mod:`torch` operations. Since this
-is a generic sort of wrapping for the full :mod:`torch` API, we do not need to
-individually implement each override so we can make the ``__torch_function__``
-implementation more permissive about what operations are allowed::
-
-    class MetadataTensor(object):
-        def __init__(self, data, metadata=None, **kwargs):
-            self._t = torch.as_tensor(data, **kwargs)
-            self._metadata = metadata
-
-        def __repr__(self):
-            return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)
-
-        def __torch_function__(self, func, args=(), kwargs=None):
-            if kwargs is None:
-                kwargs = {}
-            args = [a._t if hasattr(a, '_t') else a for a in args]
-            ret = func(*args, **kwargs)
-            return MetadataTensor(ret, metadata=self._metadata)
-
-This simple implementation won't necessarily work with every function in the
-:mod:`torch` API but it is good enough to capture most common operations::
-
-    >>> metadata = {'owner': 'Ministry of Silly Walks'}
-    >>> m = MetadataTensor([[1, 2], [3, 4]], metadata=metadata)
-    >>> t = torch.tensor([[1, 2], [1, 2]])
-    >>> torch.add(t, m)
-    Metadata:
-    {'owner': 'Ministry of Silly Walks'}
-
-    data:
-    tensor([[2, 4],
-            [4, 6]])
-    >>> torch.mul(t, m)
-    Metadata:
-    {'owner': 'Ministry of Silly Walks'}
-
-    data:
-    tensor([[1, 4],
-            [3, 8]])
-
-Operations on multiple types that define ``__torch_function__``
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-It is possible to use the torch API with multiple distinct types that each have
-a ``__torch_function__`` implementation, but special care must be taken. In such
-a case the rules are:
-
-* The dispatch operation gathers all distinct implementations of
-  ``__torch_function__`` for each operand and calls them in order: subclasses
-  before superclasses, and otherwise left to right in the operator expression.
-* If any value other than ``NotImplemented`` is returned, that value is
-  returned as the result. Implementations can register that they do not
-  implement an operation by returning ``NotImplemented``.
-* If all of the ``__torch_function__`` implementations return
-  ``NotImplemented``, PyTorch raises a ``TypeError``.
 
 Writing custom C++ extensions
 -----------------------------
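For reference, a minimal sketch of the multi-type dispatch rules described in the removed "Operations on multiple types" text above. This is an illustration, not part of the commit: it only runs on a build that still ships the reverted ``__torch_function__`` mechanism (before this revert, or once D18645954 relands), and the ``Base``/``Child`` classes are invented for the example.

import torch

class Base(object):
    # Claims torch.add and returns a tagged string so the dispatch path is visible.
    def __torch_function__(self, func, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func is torch.add:
            return "handled by Base.__torch_function__"
        return NotImplemented

class Child(Base):
    # Declines every operation; per the documented rules, returning
    # NotImplemented lets dispatch fall through to the next implementation.
    def __torch_function__(self, func, args=(), kwargs=None):
        return NotImplemented

# Child's implementation is tried before Base's (subclasses before superclasses);
# it declines, so Base's result is what torch.add returns.
print(torch.add(Child(), Base()))   # expected: "handled by Base.__torch_function__"
# An operation neither type handles raises TypeError, as described above:
# torch.mul(Child(), Base())  ->  TypeError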

test/onnx/expect/TestOperators.test_frobenius_norm.expect

Lines changed: 3 additions & 3 deletions
@@ -3,8 +3,8 @@ producer_name: "pytorch"
 producer_version: "1.3"
 graph {
   node {
-    input: "0"
-    input: "0"
+    input: "x"
+    input: "x"
     output: "1"
     name: "Mul_0"
     op_type: "Mul"
@@ -34,7 +34,7 @@ graph {
   }
   name: "torch-jit-export"
   input {
-    name: "0"
+    name: "x"
     type {
       tensor_type {
         elem_type: 1

test/onnx/expect/TestOperators.test_meshgrid.expect

Lines changed: 6 additions & 6 deletions
@@ -17,7 +17,7 @@ graph {
     }
   }
   node {
-    input: "0"
+    input: "x"
     input: "3"
     output: "4"
     name: "Reshape_1"
@@ -38,7 +38,7 @@ graph {
     }
   }
   node {
-    input: "1"
+    input: "y"
     input: "5"
     output: "6"
     name: "Reshape_3"
@@ -59,7 +59,7 @@ graph {
     }
   }
   node {
-    input: "2"
+    input: "z"
     input: "7"
     output: "8"
     name: "Reshape_5"
@@ -221,7 +221,7 @@ graph {
   }
   name: "torch-jit-export"
  input {
-    name: "0"
+    name: "x"
     type {
      tensor_type {
        elem_type: 1
@@ -234,7 +234,7 @@ graph {
    }
  }
  input {
-    name: "1"
+    name: "y"
    type {
      tensor_type {
        elem_type: 1
@@ -247,7 +247,7 @@ graph {
    }
  }
  input {
-    name: "2"
+    name: "z"
    type {
      tensor_type {
        elem_type: 1

test/onnx/expect/TestOperators.test_unique.expect

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@ producer_name: "pytorch"
 producer_version: "1.3"
 graph {
   node {
-    input: "0"
+    input: "x"
     output: "1"
     output: "2"
     output: "3"
@@ -23,7 +23,7 @@ graph {
   }
   name: "torch-jit-export"
   input {
-    name: "0"
+    name: "x"
     type {
       tensor_type {
         elem_type: 1
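The three ONNX expect-file hunks above only swap the exported graphs' input names between symbolic names ("x", "y", "z") and positional defaults ("0", "1", "2"). For context, a hedged sketch of how input names can be controlled in a stand-alone export via the ``input_names`` argument of ``torch.onnx.export``; the ``Square`` module and its shape are invented for illustration, and this is not necessarily the mechanism the operator-test harness uses to generate these expect files.

import io
import torch

class Square(torch.nn.Module):
    def forward(self, x):
        return x * x

buf = io.BytesIO()
# With input_names, the first graph input is recorded as "x";
# without it, the exporter falls back to positional names such as "0".
torch.onnx.export(Square(), (torch.randn(2, 3),), buf, input_names=["x"])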

test/run_test.py

Lines changed: 0 additions & 1 deletion
@@ -62,7 +62,6 @@
     'type_promotion',
     'jit_disabled',
     'function_schema',
-    'overrides',
 ]
 
 # skip < 3.3 because mock is added in 3.3 and is used in rpc_spawn
