 USE_META = False
 
 
+@contextmanager
+def no_dispatch():
+    guard = torch._C._DisableTorchDispatch()
+    try:
+        yield
+    finally:
+        del guard
+
+
 @contextmanager
 def pythonkey_decompose(decomposition_table):
     global CURRENT_DECOMPOSITION_TABLE
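
For context on why this helper is needed: once `__torch_dispatch__` is overridden on a plain (non-wrapper) subclass, re-invoking the op inside the handler would dispatch straight back into the handler. `torch._C._DisableTorchDispatch()` is an RAII-style guard that suppresses that re-entry until it is deleted. A minimal sketch of the pattern, using a hypothetical `LoggingTensor` subclass that is not part of this patch:

```python
import torch
from contextlib import contextmanager

@contextmanager
def no_dispatch():
    # Guard object disables __torch_dispatch__ interception; deleting it
    # (in finally) restores normal dispatch even if the body raises.
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard

class LoggingTensor(torch.Tensor):
    # Hypothetical subclass, only to illustrate the re-entrancy problem.
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print(f"intercepted: {func.__name__}")
        with no_dispatch():
            # Without no_dispatch() this call would land right back in
            # __torch_dispatch__ and recurse forever.
            return func(*args, **(kwargs or {}))

t = torch.Tensor._make_subclass(LoggingTensor, torch.randn(3))
t + t  # prints "intercepted: add" (exact overload name varies by version)
```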
@@ -62,28 +71,30 @@ class PythonTensor(torch.Tensor):
 
     @staticmethod
     def __new__(cls, elem, proxy, device=None):
-        # The wrapping tensor (PythonTensor) is just a meta tensor, so it
-        # doesn't hold any memory (meta tensor is generally the preferred type
-        # of tensor you want to make a subclass from)...
-
-        r = torch.Tensor._make_wrapper_subclass(
-            cls, elem.size(),
-            strides=elem.stride(), storage_offset=elem.storage_offset(),
-            dtype=elem.dtype, layout=elem.layout, requires_grad=elem.requires_grad,
-            device=(elem.device if device is None else device),
-        )
-
-        # ...the real tensor is held as an element on the tensor.
-        if USE_META:
-            r.elem = elem.to('meta')
-        else:
-            r.elem = elem
+        # This is a hold-over from the (untested) meta codepath. Need to
+        # figure out what I want to do here.
+        assert device is None or device == elem.device
+
+        # Wrapping something in PythonTensor implicitly detaches
+        # gradients. If something required grad, we will collect it as if it
+        # were a leaf. A consequence of detaching in this way is you
+        # need to maintain a parameter cache when translating tensors
+        # into PythonTensor, so you don't create multiple copies of
+        # a gradient (they are aliased, but they would count as independent
+        # leaves). An alternate strategy would be to avoid implicitly
+        # detaching and instead "catch" gradients as they exit the
+        # PythonTensor boundary.
+        # assert not elem.requires_grad or not torch.is_grad_enabled()
+
+        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
         r.proxy = proxy
         proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
         return r
 
     def __repr__(self):
-        return f"PythonTensor({self.elem})"
+        # This is a bit goofy but whatever. Should fix up _tensor_str.py to
+        # work on subclasses when it calls tolist
+        return f"PythonTensor({torch.Tensor._make_subclass(torch.Tensor, self)})"
 
     __torch_function__ = _disabled_torch_function_impl
 
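
The comment in `__new__` above gestures at a parameter cache; a rough sketch of what that could look like (the `wrap_tensor` helper, the id-keyed dict, and `make_proxy_for` are assumptions for illustration, not code from this patch):

```python
# One PythonTensor per underlying parameter: aliased parameters must map
# to the same wrapper, or each alias would be collected as a separate leaf.
param_cache = {}

def wrap_tensor(t, make_proxy_for):
    if not isinstance(t, torch.Tensor):
        return t
    key = id(t)  # keyed by identity, not value, so aliases share an entry
    if key not in param_cache:
        param_cache[key] = PythonTensor(t, make_proxy_for(t))
    return param_cache[key]
```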
@@ -99,9 +110,6 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         def unwrap_proxy(e):
             return e.proxy if isinstance(e, PythonTensor) else e
 
-        def unwrap_tensor(e):
-            return e.elem if isinstance(e, PythonTensor) else e
-
         input_devices = [i.device for i in pytree.tree_flatten(args)[0] +
                          pytree.tree_flatten(kwargs)[0] if isinstance(i, torch.Tensor)]
 
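
As a reminder of what the `pytree` helpers used throughout this method do (standard `torch.utils._pytree` behavior, shown with made-up inputs):

```python
import torch
import torch.utils._pytree as pytree

args = (torch.randn(2), [torch.randn(3), 1.5])
flat, spec = pytree.tree_flatten(args)      # flat: [tensor, tensor, 1.5]
args2 = pytree.tree_unflatten(flat, spec)   # nesting restored from spec
doubled = pytree.tree_map(lambda x: x * 2, args)  # applied to every leaf
```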
@@ -115,17 +123,16 @@ def unwrap_tensor(e):
         # Kind of a hacky way to test if an op is in-place or not
         if func.__name__[-1] == "_" and func.__name__[0] != "_":
             args[0].proxy = proxy_out
-        args = pytree.tree_map(unwrap_tensor, args)
-        kwargs = pytree.tree_map(unwrap_tensor, kwargs)
-
-        try:
-            real_out = func(*args, **kwargs)
-        except NotImplementedError:
-            args = pytree.tree_map(lambda x: torch.ones_like(x, device=output_device)
-                                   if isinstance(x, torch.Tensor) else x, args)
-            kwargs = pytree.tree_map(lambda x: torch.ones_like(x, device=output_device)
-                                     if isinstance(x, torch.Tensor) else x, kwargs)
-            real_out = func(*args, **kwargs)
+
+        with no_dispatch():
+            try:
+                real_out = func(*args, **kwargs)
+            except NotImplementedError:
+                args = pytree.tree_map(lambda x: torch.ones_like(x, device=output_device)
+                                       if isinstance(x, torch.Tensor) else x, args)
+                kwargs = pytree.tree_map(lambda x: torch.ones_like(x, device=output_device)
+                                         if isinstance(x, torch.Tensor) else x, kwargs)
+                real_out = func(*args, **kwargs)
 
         def wrap_with_proxy(e, proxy):
             # Some ops (like native_batch_norm_backward) return undefined tensors that get
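
End to end, the effect of this change is visible through functorch's tracing entry point. A hedged sketch, assuming `make_fx` as the public API of this tracer (the exact import path may differ by version):

```python
import torch
from functorch import make_fx

def f(x):
    return torch.sin(x) * 2

# Each ATen op that f executes hits PythonTensor.__torch_dispatch__:
# the op is recorded as an FX proxy node, then run for real under
# no_dispatch() so tracing observes concrete values without recursing.
gm = make_fx(f)(torch.randn(3))
print(gm.graph)  # aten.sin / aten.mul nodes recorded by the tracer
```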