16
16
aten = torch .ops .aten
17
17
18
18
CURRENT_DECOMPOSITION_TABLE = {}
19
- USE_META = False
20
19
21
20
22
21
@contextmanager
def no_dispatch():
    """Temporarily disable ``__torch_dispatch__`` interposition.

    Holds a ``torch._C._DisableTorchDispatch`` guard for the duration of
    the ``with`` block, so tensor operations executed inside it are not
    intercepted by tensor subclasses (e.g. ``PythonTensor`` below) and run
    against the real kernels instead.
    """
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        # Dropping the guard re-enables dispatch even if the body raised.
        del guard
30
28
31
29
32
30
@contextmanager
def pythonkey_decompose(decomposition_table):
    """Install ``decomposition_table`` as the active decomposition table.

    While the ``with`` block is active, ``CURRENT_DECOMPOSITION_TABLE`` is
    set to the given table (tracing code consults this module-level global
    to decide which ops to decompose).  On exit the table is reset to an
    empty dict, even if the body raised.

    Yields:
        The installed decomposition table.
    """
    global CURRENT_DECOMPOSITION_TABLE
    CURRENT_DECOMPOSITION_TABLE = decomposition_table
    try:
        yield CURRENT_DECOMPOSITION_TABLE
    finally:
        # Always clear the global so a failed trace can't leak its table
        # into the next one.
        CURRENT_DECOMPOSITION_TABLE = {}
56
38
57
39
58
40
class PythonTensor (torch .Tensor ):
@@ -61,29 +43,27 @@ class PythonTensor(torch.Tensor):
61
43
__slots__ = ['elem' , 'proxy' ]
62
44
63
45
@staticmethod
64
- def __new__ (cls , elem , proxy , device = None ):
65
- # The wrapping tensor (PythonTensor) is just a meta tensor, so it
66
- # doesn't hold any memory (meta tensor is generally the preferred type
67
- # of tensor you want to make a subclass from)...
68
-
69
- r = torch .Tensor ._make_wrapper_subclass (
70
- cls , elem .size (),
71
- strides = elem .stride (), storage_offset = elem .storage_offset (),
72
- dtype = elem .dtype , layout = elem .layout , requires_grad = elem .requires_grad ,
73
- device = (elem .device if device is None else device ),
74
- )
75
-
76
- # ...the real tensor is held as an element on the tensor.
77
- if USE_META :
78
- r .elem = elem .to ('meta' )
79
- else :
80
- r .elem = elem
46
+ def __new__ (cls , elem , proxy ):
47
+ # Wrapping something in PythonTensor implicitly detaches
48
+ # gradients. If something required grad, we will collect it as if it
49
+ # were a leaf. A consequence of detaching in this way is you
50
+ # need to maintain a parameter cache when translating tensors
51
+ # into PythonTensor, so you don't create multiple copies of
52
+ # a gradient (they are aliased, but they would count as independent
53
+ # leaves). An alternate strategy would be to avoid implicitly
54
+ # detaching and instead "catch" gradients as they exit the
55
+ # PythonTensor boundary.
56
+ # assert not elem.requires_grad or not torch.is_grad_enabled()
57
+
58
+ r = torch .Tensor ._make_subclass (cls , elem , elem .requires_grad )
81
59
r .proxy = proxy
82
60
proxy .node .meta ['tensor_meta' ] = _extract_tensor_metadata (r )
83
61
return r
84
62
85
63
def __repr__ (self ):
86
- return f"PythonTensor({ self .elem } )"
64
+ # This is a bit goofy but whatever. Should fix up _tensor_str.py to
65
+ # work on subclasses when it calls tolist
66
+ return f"PythonTensor({ torch .Tensor ._make_subclass (torch .Tensor , self )} )"
87
67
88
68
__torch_function__ = _disabled_torch_function_impl
89
69
@@ -99,14 +79,6 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
99
79
def unwrap_proxy (e ):
100
80
return e .proxy if isinstance (e , PythonTensor ) else e
101
81
102
- def unwrap_tensor (e ):
103
- return e .elem if isinstance (e , PythonTensor ) else e
104
-
105
- input_devices = [i .device for i in pytree .tree_flatten (args )[0 ] +
106
- pytree .tree_flatten (kwargs )[0 ] if isinstance (i , torch .Tensor )]
107
-
108
- output_device = get_output_device (input_devices , func )
109
-
110
82
proxy_args = pytree .tree_map (unwrap_proxy , args )
111
83
proxy_kwargs = pytree .tree_map (unwrap_proxy , kwargs )
112
84
@@ -115,16 +87,8 @@ def unwrap_tensor(e):
115
87
# Kind of a hacky way to test if an op is in-place or not
116
88
if func .__name__ [- 1 ] == "_" and func .__name__ [0 ] != "_" :
117
89
args [0 ].proxy = proxy_out
118
- args = pytree .tree_map (unwrap_tensor , args )
119
- kwargs = pytree .tree_map (unwrap_tensor , kwargs )
120
90
121
- try :
122
- real_out = func (* args , ** kwargs )
123
- except NotImplementedError :
124
- args = pytree .tree_map (lambda x : torch .ones_like (x , device = output_device )
125
- if isinstance (x , torch .Tensor ) else x , args )
126
- kwargs = pytree .tree_map (lambda x : torch .ones_like (x , device = output_device )
127
- if isinstance (x , torch .Tensor ) else x , kwargs )
91
+ with no_dispatch ():
128
92
real_out = func (* args , ** kwargs )
129
93
130
94
def wrap_with_proxy (e , proxy ):
@@ -135,7 +99,7 @@ def wrap_with_proxy(e, proxy):
135
99
if e is None :
136
100
e = torch .empty (())
137
101
if type (e ) == torch .Tensor :
138
- return PythonTensor (e , proxy , output_device )
102
+ return PythonTensor (e , proxy )
139
103
else :
140
104
return e
141
105
if isinstance (real_out , tuple ):
0 commit comments