Commit 1ac65f7

ezyang authored and zou3519 committed
[functorch] Don't unnecessarily wrap the elem in PythonTensor (pytorch/functorch#554)

* Don't unnecessarily wrap the elem in PythonTensor

Instead of saying that a PythonTensor *has* a regular (e.g., CPU) tensor and an FX proxy, a PythonTensor now *is a* regular CPU tensor that also carries an FX proxy (one that updates as we go along). This should fix pytorch/functorch#465, and it also fixed some expected failures in the test suite.

This kills the meta variant logic entirely; maybe some other time we'll try to bring it back.

Signed-off-by: Edward Z. Yang <[email protected]>
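
For context, the heart of the change: PythonTensor used to be built with torch.Tensor._make_wrapper_subclass and stash the real tensor in self.elem; after this commit it is built with torch.Tensor._make_subclass, so the subclass instance *is* the real tensor and only the FX proxy rides along. A minimal, hedged sketch of the new shape (TracingTensor and the string proxy are illustrative stand-ins, not the real functorch code):

    import torch

    class TracingTensor(torch.Tensor):
        @staticmethod
        def __new__(cls, elem, proxy):
            # The subclass shares storage with `elem`: no wrapping, no `elem` slot.
            r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
            r.proxy = proxy  # extra tracing state carried on the tensor itself
            return r

    x = TracingTensor(torch.randn(3), proxy="placeholder-proxy")
    print(isinstance(x, torch.Tensor))  # True: it is a regular tensor
    print(x.proxy)                      # "placeholder-proxy"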
1 parent 60357cb

3 files changed: +28 −68 lines

functorch/functorch/_src/python_key.py

Lines changed: 27 additions & 63 deletions

@@ -16,43 +16,25 @@
 aten = torch.ops.aten
 
 CURRENT_DECOMPOSITION_TABLE = {}
-USE_META = False
 
 
 @contextmanager
-def pythonkey_decompose(decomposition_table):
-    global CURRENT_DECOMPOSITION_TABLE
-    CURRENT_DECOMPOSITION_TABLE = decomposition_table
+def no_dispatch():
+    guard = torch._C._DisableTorchDispatch()
     try:
-        yield CURRENT_DECOMPOSITION_TABLE
+        yield
     finally:
-        CURRENT_DECOMPOSITION_TABLE = {}
+        del guard
 
 
 @contextmanager
-def pythonkey_meta():
-    global USE_META
-    USE_META = True
+def pythonkey_decompose(decomposition_table):
+    global CURRENT_DECOMPOSITION_TABLE
+    CURRENT_DECOMPOSITION_TABLE = decomposition_table
     try:
-        yield USE_META
+        yield CURRENT_DECOMPOSITION_TABLE
     finally:
-        USE_META = False
-
-
-def get_output_device(devices, op):
-    # The device propagation is a bit sketchy.
-    # aten::index(CPU, CUDA) => CPU tensor
-    # aten::index(CUDA, CPU) => CUDA tensor
-    if op == aten.index:
-        return devices[0]
-    devices = list(set(devices))
-    if len(devices) == 1:
-        return devices[0]
-    else:
-        for device in devices:
-            if device.type == 'cuda':
-                return device
-        raise RuntimeError("Couldn't infer output device from input device")
+        CURRENT_DECOMPOSITION_TABLE = {}
 
 
 class PythonTensor(torch.Tensor):
@@ -61,29 +43,27 @@ class PythonTensor(torch.Tensor):
     __slots__ = ['elem', 'proxy']
 
     @staticmethod
-    def __new__(cls, elem, proxy, device=None):
-        # The wrapping tensor (PythonTensor) is just a meta tensor, so it
-        # doesn't hold any memory (meta tensor is generally the preferred type
-        # of tensor you want to make a subclass from)...
-
-        r = torch.Tensor._make_wrapper_subclass(
-            cls, elem.size(),
-            strides=elem.stride(), storage_offset=elem.storage_offset(),
-            dtype=elem.dtype, layout=elem.layout, requires_grad=elem.requires_grad,
-            device=(elem.device if device is None else device),
-        )
-
-        # ...the real tensor is held as an element on the tensor.
-        if USE_META:
-            r.elem = elem.to('meta')
-        else:
-            r.elem = elem
+    def __new__(cls, elem, proxy):
+        # Wrapping something in PythonTensor implicitly detaches
+        # gradients. If something required grad, we will collect it as if it
+        # were a leaf. A consequence of detaching in this way is you
+        # need to maintain a parameter cache when translating tensors
+        # into PythonTensor, so you don't create multiple copies of
+        # a gradient (they are aliased, but they would count as independent
+        # leaves). An alternate strategy would be to avoid implicitly
+        # detaching and instead "catch" gradients as they exit the
+        # PythonTensor boundary.
+        # assert not elem.requires_grad or not torch.is_grad_enabled()
+
+        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
         r.proxy = proxy
         proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
         return r
 
     def __repr__(self):
-        return f"PythonTensor({self.elem})"
+        # This is a bit goofy but whatever. Should fix up _tensor_str.py to
+        # work on subclasses when it calls tolist
+        return f"PythonTensor({torch.Tensor._make_subclass(torch.Tensor, self)})"
 
     __torch_function__ = _disabled_torch_function_impl
 
@@ -99,14 +79,6 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         def unwrap_proxy(e):
             return e.proxy if isinstance(e, PythonTensor) else e
 
-        def unwrap_tensor(e):
-            return e.elem if isinstance(e, PythonTensor) else e
-
-        input_devices = [i.device for i in pytree.tree_flatten(args)[0] +
-                         pytree.tree_flatten(kwargs)[0] if isinstance(i, torch.Tensor)]
-
-        output_device = get_output_device(input_devices, func)
-
         proxy_args = pytree.tree_map(unwrap_proxy, args)
         proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
 
@@ -115,16 +87,8 @@ def unwrap_tensor(e):
         # Kind of a hacky way to test if an op is in-place or not
         if func.__name__[-1] == "_" and func.__name__[0] != "_":
             args[0].proxy = proxy_out
-            args = pytree.tree_map(unwrap_tensor, args)
-            kwargs = pytree.tree_map(unwrap_tensor, kwargs)
 
-        try:
-            real_out = func(*args, **kwargs)
-        except NotImplementedError:
-            args = pytree.tree_map(lambda x: torch.ones_like(x, device=output_device)
-                                   if isinstance(x, torch.Tensor) else x, args)
-            kwargs = pytree.tree_map(lambda x: torch.ones_like(x, device=output_device)
-                                     if isinstance(x, torch.Tensor) else x, kwargs)
+        with no_dispatch():
             real_out = func(*args, **kwargs)
 
         def wrap_with_proxy(e, proxy):
@@ -135,7 +99,7 @@ def wrap_with_proxy(e, proxy):
             if e is None:
                 e = torch.empty(())
             if type(e) == torch.Tensor:
-                return PythonTensor(e, proxy, output_device)
+                return PythonTensor(e, proxy)
             else:
                 return e
         if isinstance(real_out, tuple):
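
Taken together, the new __torch_dispatch__ flow is: unwrap proxies and record the op in the FX graph, run the real op once with dispatch disabled (so the call doesn't recursively re-enter the tracer), then wrap each tensor output with its proxy. A rough sketch of that loop, assuming the PyTorch build this commit targets (torch._C._DisableTorchDispatch is quoted from the diff above; record_in_graph and wrap_with_proxy are hypothetical stand-ins for the real FX plumbing):

    import torch
    from contextlib import contextmanager

    @contextmanager
    def no_dispatch():
        # Same helper the diff adds: while the guard is alive, tensor ops
        # skip __torch_dispatch__ and execute as ordinary aten calls.
        guard = torch._C._DisableTorchDispatch()
        try:
            yield
        finally:
            del guard

    # Inside __torch_dispatch__ (pseudocode mirroring the diff):
    #     proxy_out = record_in_graph(func, proxy_args, proxy_kwargs)  # hypothetical helper
    #     with no_dispatch():
    #         real_out = func(*args, **kwargs)  # args may still be PythonTensors
    #     return wrap_with_proxy(real_out, proxy_out)  # per-output, as in the diff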

functorch/functorch/compile/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 from .._src.operator_authoring import pointwise_operator
-from .._src.python_key import pythonkey_decompose, pythonkey_meta
+from .._src.python_key import pythonkey_decompose
 from .._src.decompositions import register_decomposition, decomposition_table
 from .._src.fx_minifier import minifier, check_nvfuser_subprocess
 from .._src.aot_autograd import (

functorch/test/test_pythonkey.py

Lines changed: 0 additions & 4 deletions

@@ -192,7 +192,6 @@ def f(x):
     xfail('allclose'),
     xfail('nn.functional.dropout'),
     xfail('linalg.eigvals'),
-    xfail('nn.functional.ctc_loss'),
     xfail('nn.functional.fractional_max_pool3d', device_type='cpu'),
     xfail('randn_like'),  # randomness
     xfail('rand_like'),  # randomness
@@ -355,11 +354,8 @@ class TestEagerFusionOpInfo(TestCase):
     # entries in here don't work and need to be fixed.
     # Each one of these is a bug (or needs to be investigated)
     @skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_exhaustive', {
-        xfail('__rmatmul__'),
         xfail('linalg.cholesky'),
-        xfail('matmul'),
         skip('msort'),
-        xfail('nn.functional.linear'),
         xfail('nn.functional.dropout'),
         xfail('polar'),
         xfail('special.zeta', 'grad'),
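
On the test side, deleting an xfail entry asserts that the op now traces and differentiates correctly. For readers unfamiliar with the helpers, a sketch of the decorator shape used in this file, mirroring the calls visible in the hunk above (the surrounding OpInfo machinery and the test body are elided):

    # Entries left in the set are still expected failures or skips;
    # xfail(...) marks an op as an expected failure, skip(...) skips it.
    @skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_exhaustive', {
        xfail('linalg.cholesky'),  # still a known failure
        skip('msort'),             # skipped rather than xfailed
    })
    def test_aot_autograd_exhaustive(self, device, dtype, op):
        ...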
