@@ -33,14 +33,9 @@ def _to_tensor(
     def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
         return tensor.as_subclass(cls)
 
-    _NO_WRAPPING_EXCEPTIONS = {
-        torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
-        torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
-        torch.Tensor.detach: lambda cls, input, output: cls.wrap_like(input, output),
-        # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
-        # retains the type automatically
-        torch.Tensor.requires_grad_: lambda cls, input, output: output,
-    }
+    # The ops in this set are those that should *preserve* the Datapoint type,
+    # i.e. they are exceptions to the "no wrapping" rule.
+    _NO_WRAPPING_EXCEPTIONS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
 
     @classmethod
     def __torch_function__(
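Roughly, the dict of per-op wrapper lambdas collapses into a plain set because every exception re-wraps its output via `wrap_like` anyway; `requires_grad_` is in-place, so re-wrapping its already correctly typed output is presumably harmless. A minimal sketch of the behavior this set controls, assuming the prototype `torchvision.prototype.datapoints` namespace this file lived in at the time (names and constructor usage below are illustrative assumptions, not part of the diff):

```python
import torch
from torchvision.prototype import datapoints  # assumed prototype namespace

image = datapoints.Image(torch.rand(3, 32, 32))

# Ops listed in `_NO_WRAPPING_EXCEPTIONS` keep the Datapoint subclass on their output ...
assert type(image.clone()) is datapoints.Image
assert type(image.detach()) is datapoints.Image

# ... while any other op is unwrapped to a plain tensor by `__torch_function__`.
assert type(image + 1) is torch.Tensor
```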
@@ -76,22 +71,21 @@ def __torch_function__(
         with DisableTorchFunctionSubclass():
             output = func(*args, **kwargs or dict())
 
-        wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)
-        # Apart from `func` needing to be an exception, we also require the primary operand, i.e. `args[0]`, to be
+        if func in cls._NO_WRAPPING_EXCEPTIONS and isinstance(args[0], cls):
+            # We also require the primary operand, i.e. `args[0]`, to be
             # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
             # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
             # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
             # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
             # be wrapped into a `datapoints.Image`.
-        if wrapper and isinstance(args[0], cls):
-            return wrapper(cls, args[0], output)
+            return cls.wrap_like(args[0], output)
 
-        # Inplace `func`'s, canonically identified with a trailing underscore in their name like `.add_(...)`,
-        # will retain the input type. Thus, we need to unwrap here.
-        if isinstance(output, cls):
-            return output.as_subclass(torch.Tensor)
+        if isinstance(output, cls):
+            # `DisableTorchFunctionSubclass` is ignored by inplace ops like `.add_(...)`,
+            # so for those, the output is still a Datapoint. Thus, we need to manually unwrap.
+            return output.as_subclass(torch.Tensor)
 
-        return output
+        return output
 
     def _make_repr(self, **kwargs: Any) -> str:
         # This is a poor man's implementation of the proposal in https://github.com/pytorch/pytorch/issues/76532.
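As a usage-level illustration of the new control flow, here is a sketch of what the `isinstance(args[0], cls)` guard and the final unwrapping branch imply; the namespace, constructor, and exact return types are assumptions inferred from the diff, not verified against the library:

```python
import torch
from torchvision.prototype import datapoints  # assumed prototype namespace

image = datapoints.Image(torch.rand(3, 32, 32))
plain = torch.rand(3, 32, 32)

# `plain.to(image)` also dispatches to `datapoints.Image.__torch_function__`, with
# `args = (plain, image)`. `Tensor.to` is in `_NO_WRAPPING_EXCEPTIONS`, but `args[0]`
# is not an `Image`, so the guard leaves the result as a plain tensor.
assert type(plain.to(image)) is torch.Tensor

# With the Datapoint as the primary operand, the output is re-wrapped via `wrap_like`.
assert type(image.to(torch.float64)) is datapoints.Image

# Inplace ops ignore `DisableTorchFunctionSubclass`, so their raw output is still an
# `Image`; the final `isinstance(output, cls)` branch unwraps the *returned* value,
# while `image` itself keeps its type.
assert type(image.add_(1)) is torch.Tensor
assert type(image) is datapoints.Image
```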