@@ -2,7 +2,7 @@ Extending PyTorch
=================

In this note we'll cover ways of extending :mod:`torch.nn`,
- :mod:`torch.autograd`, :mod:`torch`, and writing custom C extensions utilizing our C
+ :mod:`torch.autograd`, and writing custom C extensions utilizing our C
libraries.

Extending :mod:`torch.autograd`
@@ -204,285 +204,6 @@ This is how a ``Linear`` module can be implemented::
            self.in_features, self.out_features, self.bias is not None
        )

- Extending :mod:`torch`
- ----------------------
-
- You can create custom types that emulate :class:`Tensor` by defining a custom
- class with methods that match :class:`Tensor`. But what if you want to be able
- to pass these types to functions like :func:`torch.add` in the top-level
- :mod:`torch` namespace that accept :class:`Tensor` operands?
-
- If your custom Python type defines a method named ``__torch_function__``, PyTorch
- will invoke your ``__torch_function__`` implementation when an instance of your
- custom class is passed to a function in the :mod:`torch` namespace. This makes
- it possible to define custom implementations for any of the functions in the
- :mod:`torch` namespace which your ``__torch_function__`` implementation can call,
- allowing your users to make use of your custom type with existing PyTorch
- workflows that they have already written for :class:`Tensor`. This works with
- "duck" types that are unrelated to :class:`Tensor` as well as user-defined
- subclasses of :class:`Tensor`.
-
- Extending :mod:`torch` with a :class:`Tensor`-like type
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
- .. note:: This functionality is inspired by the NumPy ``__array_function__``
-    protocol. See `the NumPy documentation
-    <https://docs.scipy.org/doc/numpy/user/basics.dispatch.html#basics-dispatch>`_
-    and `NEP-0018
-    <https://numpy.org/neps/nep-0018-array-function-protocol.html>`_ for
-    more details.
-
- To make this concrete, let's begin with a simple example that illustrates the
- API dispatch mechanism. We'll create a custom type that represents a 2D scalar
- tensor, parametrized by the order ``N`` and the value ``value`` along the
- diagonal entries::
-
-     class ScalarTensor(object):
-         def __init__(self, N, value):
-             self._N = N
-             self._value = value
-
-         def __repr__(self):
-             return "ScalarTensor(N={}, value={})".format(self._N, self._value)
-
-         def tensor(self):
-             return self._value * torch.eye(self._N)
-
- This first iteration of the design isn't very useful. The main functionality of
- ``ScalarTensor`` is to provide a more compact string representation of a scalar
- tensor than the base tensor class does::
-
-     >>> d = ScalarTensor(5, 2)
-     >>> d
-     ScalarTensor(N=5, value=2)
-     >>> d.tensor()
-     tensor([[2., 0., 0., 0., 0.],
-             [0., 2., 0., 0., 0.],
-             [0., 0., 2., 0., 0.],
-             [0., 0., 0., 2., 0.],
-             [0., 0., 0., 0., 2.]])
-
- If we try to use this object with the :mod:`torch` API, we will run
- into issues::
-
-     >>> import torch
-     >>> torch.mean(d)
-     TypeError: mean(): argument 'input' (position 1) must be Tensor, not ScalarTensor
-
- Adding a ``__torch_function__`` implementation to ``ScalarTensor`` makes it
- possible for the above operation to succeed. Let's redo our implementation,
- this time adding a ``__torch_function__`` implementation::
-
-     HANDLED_FUNCTIONS = {}
-     class ScalarTensor(object):
-         def __init__(self, N, value):
-             self._N = N
-             self._value = value
-
-         def __repr__(self):
-             return "ScalarTensor(N={}, value={})".format(self._N, self._value)
-
-         def tensor(self):
-             return self._value * torch.eye(self._N)
-
-         def __torch_function__(self, func, args=(), kwargs=None):
-             if kwargs is None:
-                 kwargs = {}
-             if func not in HANDLED_FUNCTIONS:
-                 return NotImplemented
-             return HANDLED_FUNCTIONS[func](*args, **kwargs)
-
- The ``__torch_function__`` method takes three arguments: ``func``, a reference to
- the torch API function that is being overridden, ``args``, the tuple of arguments
- passed to the function, and ``kwargs``, the dict of keyword arguments passed to
- the function. It uses a global dispatch table named ``HANDLED_FUNCTIONS`` to
- store custom implementations. The keys of this dictionary are functions in the
- ``torch`` namespace and the values are implementations for ``ScalarTensor``.
-
- .. note:: Using a global dispatch table is not a mandated part of the
-    ``__torch_function__`` API; it is just a useful design pattern for
-    structuring your override implementations.
-
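- As one illustration, the table could equally well live on the class itself
- rather than at module level. The following is a hypothetical variant of the
- pattern, not a requirement of the ``__torch_function__`` API::
-
-     class ScalarTensor(object):
-         # per-class dispatch table instead of a module-level global
-         _HANDLED_FUNCTIONS = {}
-
-         def __torch_function__(self, func, args=(), kwargs=None):
-             if kwargs is None:
-                 kwargs = {}
-             if func not in type(self)._HANDLED_FUNCTIONS:
-                 return NotImplemented
-             return type(self)._HANDLED_FUNCTIONS[func](*args, **kwargs)
-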
- This class definition isn't quite enough to make ``torch.mean`` do the right
- thing when we pass it a ``ScalarTensor`` -- we also need to define an
- implementation of ``torch.mean`` for ``ScalarTensor`` operands and add that
- implementation to the ``HANDLED_FUNCTIONS`` dispatch table dictionary. One way
- of doing this is to define a decorator::
-
-     import functools
-     def implements(torch_function):
-         """Register a torch function override for ScalarTensor"""
-         def decorator(func):
-             # copy the torch function's metadata onto the override
-             functools.update_wrapper(func, torch_function)
-             HANDLED_FUNCTIONS[torch_function] = func
-             return func
-         return decorator
-
- which can be applied to the implementation of our override::
-
-     @implements(torch.mean)
-     def mean(input):
-         # N diagonal entries of value averaged over N * N total entries
-         return float(input._value) / input._N
-
- With this change we can now use ``torch.mean`` with ``ScalarTensor``::
-
-     >>> d = ScalarTensor(5, 2)
-     >>> torch.mean(d)
-     0.4
-
- Of course, ``torch.mean`` is an example of the simplest kind of function to
- override since it only takes one operand. We can use the same machinery to
- override a function that takes more than one operand, any one of which might be
- a tensor or tensor-like that defines ``__torch_function__``, for example
- :func:`torch.add`::
-
-     def ensure_tensor(data):
-         if isinstance(data, ScalarTensor):
-             return data.tensor()
-         return torch.as_tensor(data)
-
-     @implements(torch.add)
-     def add(input, other):
-         try:
-             if input._N == other._N:
-                 return ScalarTensor(input._N, input._value + other._value)
-             else:
-                 raise ValueError("Shape mismatch!")
-         except AttributeError:
-             return torch.add(ensure_tensor(input), ensure_tensor(other))
-
- This version has a fast path for when both operands are ``ScalarTensor``
- instances and also a slower path which degrades to converting the data to
- tensors when either operand is not a ``ScalarTensor``. That makes the override
- work correctly when either operand is a ``ScalarTensor`` or a regular
- :class:`Tensor`::
-
-     >>> s = ScalarTensor(2, 2)
-     >>> torch.add(s, s)
-     ScalarTensor(N=2, value=4)
-     >>> t = torch.tensor([[1, 1], [1, 1]])
-     >>> torch.add(s, t)
-     tensor([[3., 1.],
-             [1., 3.]])
-
- Note that our implementation of ``add`` does not take ``alpha`` or ``out`` as
- keyword arguments like :func:`torch.add` does::
-
-     >>> torch.add(s, s, alpha=2)
-     TypeError: add() got an unexpected keyword argument 'alpha'
-
- For speed and flexibility the ``__torch_function__`` dispatch mechanism does not
- check that the signature of an override function matches the signature of the
- function being overridden in the :mod:`torch` API. For some applications ignoring
- optional arguments would be fine, but to ensure full compatibility with
- :class:`Tensor`, user implementations of torch API functions should take care to
- exactly emulate the API of the function that is being overridden.
-
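- For instance, one possible sketch of a fuller ``torch.add`` override that also
- accepts the ``alpha`` keyword might look like the following (``out`` is still
- omitted here for brevity)::
-
-     @implements(torch.add)
-     def add(input, other, *, alpha=1):
-         # mirror torch.add's semantics: input + alpha * other
-         try:
-             if input._N == other._N:
-                 return ScalarTensor(input._N, input._value + alpha * other._value)
-             raise ValueError("Shape mismatch!")
-         except AttributeError:
-             return torch.add(ensure_tensor(input), ensure_tensor(other), alpha=alpha)
-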
- Functions in the :mod:`torch` API that do not have explicit overrides will
- return ``NotImplemented`` from ``__torch_function__``. If all operands with
- ``__torch_function__`` defined on them return ``NotImplemented``, PyTorch will
- raise a ``TypeError``. This means that most of the time operations that do not
- have explicit overrides for a type will raise a ``TypeError`` when an instance
- of such a type is passed::
-
-     >>> torch.mul(s, 3)
-     TypeError: no implementation found for 'torch.mul' on types that
-     implement __torch_function__: [ScalarTensor]
-
- In practice this means that if you would like to implement your overrides using
- a ``__torch_function__`` implementation along these lines, you will need to
- explicitly implement the full :mod:`torch` API or the entire subset of the API
- that you care about for your use case. This may be a tall order as the full
- :mod:`torch` API is quite extensive.
-
- Another option is to not return ``NotImplemented`` for operations that are not
- handled but to instead pass a :class:`Tensor` to the original :mod:`torch`
- function when no override is available. For example, if we change our
- implementation of ``__torch_function__`` for ``ScalarTensor`` to the one below::
-
-     def __torch_function__(self, func, args=(), kwargs=None):
-         if kwargs is None:
-             kwargs = {}
-         if func not in HANDLED_FUNCTIONS:
-             args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
-             return func(*args, **kwargs)
-         return HANDLED_FUNCTIONS[func](*args, **kwargs)
-
- Then :func:`torch.mul` will work correctly, although the return type will always
- be a :class:`Tensor` rather than a ``ScalarTensor``, even if both operands
- are ``ScalarTensor`` instances::
-
-     >>> s = ScalarTensor(2, 2)
-     >>> torch.mul(s, s)
-     tensor([[4., 0.],
-             [0., 4.]])
-
- Also see the ``MetadataTensor`` example below for another variation on this
- pattern that instead always returns a ``MetadataTensor`` to propagate metadata
- through operations in the :mod:`torch` API.
-
- Extending :mod:`torch` with a :class:`Tensor` wrapper type
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
- Another useful case is a type that wraps a :class:`Tensor`, either as an
- attribute or via subclassing. Below we implement a special case of this sort of
- type, a ``MetadataTensor`` that attaches a dictionary of metadata to a
- :class:`Tensor` that is propagated through :mod:`torch` operations. Since this
- is a generic sort of wrapping for the full :mod:`torch` API, we do not need to
- individually implement each override, so we can make the ``__torch_function__``
- implementation more permissive about what operations are allowed::
-
-     class MetadataTensor(object):
-         def __init__(self, data, metadata=None, **kwargs):
-             self._t = torch.as_tensor(data, **kwargs)
-             self._metadata = metadata
-
-         def __repr__(self):
-             return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)
-
-         def __torch_function__(self, func, args=(), kwargs=None):
-             if kwargs is None:
-                 kwargs = {}
-             args = [a._t if hasattr(a, '_t') else a for a in args]
-             ret = func(*args, **kwargs)
-             return MetadataTensor(ret, metadata=self._metadata)
-
- This simple implementation won't necessarily work with every function in the
- :mod:`torch` API, but it is good enough to capture most common operations::
-
-     >>> metadata = {'owner': 'Ministry of Silly Walks'}
-     >>> m = MetadataTensor([[1, 2], [3, 4]], metadata=metadata)
-     >>> t = torch.tensor([[1, 2], [1, 2]])
-     >>> torch.add(t, m)
-     Metadata:
-     {'owner': 'Ministry of Silly Walks'}
-
-     data:
-     tensor([[2, 4],
-             [4, 6]])
-     >>> torch.mul(t, m)
-     Metadata:
-     {'owner': 'Ministry of Silly Walks'}
-
-     data:
-     tensor([[1, 4],
-             [3, 8]])
-
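- Note that the sketch above only unwraps positional arguments. A hypothetical
- refinement that also unwraps ``MetadataTensor`` values passed as keyword
- arguments might look like this::
-
-     def __torch_function__(self, func, args=(), kwargs=None):
-         if kwargs is None:
-             kwargs = {}
-         # unwrap MetadataTensor operands in kwargs as well as in args
-         args = [a._t if hasattr(a, '_t') else a for a in args]
-         kwargs = {k: v._t if hasattr(v, '_t') else v for k, v in kwargs.items()}
-         ret = func(*args, **kwargs)
-         return MetadataTensor(ret, metadata=self._metadata)
-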
- Operations on multiple types that define ``__torch_function__``
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
- It is possible to use the torch API with multiple distinct types that each have
- a ``__torch_function__`` implementation, but special care must be taken. In such
- a case the rules are as follows (a short sketch illustrating them appears after
- the list):
-
- * The dispatch operation gathers all distinct implementations of
-   ``__torch_function__`` for each operand and calls them in order: subclasses
-   before superclasses, and otherwise left to right in the operator expression.
- * If any value other than ``NotImplemented`` is returned, that value is
-   returned as the result. Implementations can register that they do not
-   implement an operation by returning ``NotImplemented``.
- * If all of the ``__torch_function__`` implementations return
-   ``NotImplemented``, PyTorch raises a ``TypeError``.
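-
- As a concrete illustration of these rules, consider the following hypothetical
- pair of unrelated types (they do not appear in the examples above). The first
- operand declines every function, so dispatch falls through to the second::
-
-     class NoOverrides(object):
-         def __torch_function__(self, func, args=(), kwargs=None):
-             # declining with NotImplemented moves dispatch to the next operand
-             return NotImplemented
-
-     class HandlesAdd(object):
-         def __torch_function__(self, func, args=(), kwargs=None):
-             if kwargs is None:
-                 kwargs = {}
-             if func is torch.add:
-                 return "HandlesAdd saw torch.add"
-             return NotImplemented
-
- With these definitions, ``torch.add(NoOverrides(), HandlesAdd())`` would return
- ``"HandlesAdd saw torch.add"``, while ``torch.mul(NoOverrides(), HandlesAdd())``
- would raise a ``TypeError`` since both operands return ``NotImplemented``.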

Writing custom C++ extensions
-----------------------------