
Commit fd01ddb

Committed Mar 2, 2022

Don't unnecessarily wrap the elem in PythonTensor

Instead of saying that a PythonTensor *has* a regular (e.g., CPU) tensor plus an FX proxy, a PythonTensor now *is* a regular CPU tensor that also carries an FX proxy (updated as tracing proceeds). This should fix #465, and it also fixed some expected failures in the test suite.

Signed-off-by: Edward Z. Yang <[email protected]>

1 parent a8ccab5 · commit fd01ddb
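
To make the shift concrete, here is a minimal sketch of the pattern the commit moves to. It is not the functorch code itself: TracerTensor and the bare proxy argument are illustrative stand-ins for PythonTensor and its FX proxy, and the only API assumed is torch.Tensor._make_subclass, which the diff below relies on.

import torch

class TracerTensor(torch.Tensor):
    # Illustrative stand-in for PythonTensor: the subclass *is* the tensor
    # (same storage, same device) and merely carries an extra attribute.
    @staticmethod
    def __new__(cls, elem, proxy):
        # _make_subclass reuses elem's storage rather than allocating a
        # separate wrapper, so there is no inner ".elem" tensor to unwrap.
        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
        r.proxy = proxy  # records how this tensor was produced during tracing
        return r

x = TracerTensor(torch.randn(3), proxy=None)  # a real FX proxy in actual use
assert isinstance(x, torch.Tensor) and x.shape == (3,)

The previous design built the subclass with _make_wrapper_subclass and stashed the real tensor on r.elem, which is why the old __torch_dispatch__ needed an unwrap_tensor step; the diff below deletes that step.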

File tree

2 files changed: +38 -35 lines

 

functorch/_src/python_key.py

Lines changed: 38 additions & 31 deletions
@@ -19,6 +19,15 @@
 USE_META = False
 
 
+@contextmanager
+def no_dispatch():
+    guard = torch._C._DisableTorchDispatch()
+    try:
+        yield
+    finally:
+        del guard
+
+
 @contextmanager
 def pythonkey_decompose(decomposition_table):
     global CURRENT_DECOMPOSITION_TABLE
@@ -62,28 +71,30 @@ class PythonTensor(torch.Tensor):
 
     @staticmethod
     def __new__(cls, elem, proxy, device=None):
-        # The wrapping tensor (PythonTensor) is just a meta tensor, so it
-        # doesn't hold any memory (meta tensor is generally the preferred type
-        # of tensor you want to make a subclass from)...
-
-        r = torch.Tensor._make_wrapper_subclass(
-            cls, elem.size(),
-            strides=elem.stride(), storage_offset=elem.storage_offset(),
-            dtype=elem.dtype, layout=elem.layout, requires_grad=elem.requires_grad,
-            device=(elem.device if device is None else device),
-        )
-
-        # ...the real tensor is held as an element on the tensor.
-        if USE_META:
-            r.elem = elem.to('meta')
-        else:
-            r.elem = elem
+        # This is a hold-over from the (untested) meta codepath. Need to
+        # figure out what I want to do here.
+        assert device is None or device == elem.device
+
+        # Wrapping something in PythonTensor implicitly detaches
+        # gradients. If something required grad, we will collect it as if it
+        # were a leaf. A consequence of detaching in this way is you
+        # need to maintain a parameter cache when translating tensors
+        # into PythonTensor, so you don't create multiple copies of
+        # a gradient (they are aliased, but they would count as independent
+        # leaves). An alternate strategy would be to avoid implicitly
+        # detaching and instead "catch" gradients as they exit the
+        # PythonTensor boundary.
+        # assert not elem.requires_grad or not torch.is_grad_enabled()
+
+        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
         r.proxy = proxy
         proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
         return r
 
     def __repr__(self):
-        return f"PythonTensor({self.elem})"
+        # This is a bit goofy but whatever. Should fix up _tensor_str.py to
+        # work on subclasses when it calls tolist
+        return f"PythonTensor({torch.Tensor._make_subclass(torch.Tensor, self)})"
 
     __torch_function__ = _disabled_torch_function_impl
 
@@ -99,9 +110,6 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         def unwrap_proxy(e):
             return e.proxy if isinstance(e, PythonTensor) else e
 
-        def unwrap_tensor(e):
-            return e.elem if isinstance(e, PythonTensor) else e
-
         input_devices = [i.device for i in pytree.tree_flatten(args)[0] +
                          pytree.tree_flatten(kwargs)[0] if isinstance(i, torch.Tensor)]
 
@@ -115,17 +123,16 @@ def unwrap_tensor(e):
         # Kind of a hacky way to test if an op is in-place or not
         if func.__name__[-1] == "_" and func.__name__[0] != "_":
            args[0].proxy = proxy_out
-        args = pytree.tree_map(unwrap_tensor, args)
-        kwargs = pytree.tree_map(unwrap_tensor, kwargs)
-
-        try:
-            real_out = func(*args, **kwargs)
-        except NotImplementedError:
-            args = pytree.tree_map(lambda x: torch.ones_like(x, device=output_device)
-                                   if isinstance(x, torch.Tensor) else x, args)
-            kwargs = pytree.tree_map(lambda x: torch.ones_like(x, device=output_device)
-                                     if isinstance(x, torch.Tensor) else x, kwargs)
-            real_out = func(*args, **kwargs)
+
+        with no_dispatch():
+            try:
+                real_out = func(*args, **kwargs)
+            except NotImplementedError:
+                args = pytree.tree_map(lambda x: torch.ones_like(x, device=output_device)
+                                       if isinstance(x, torch.Tensor) else x, args)
+                kwargs = pytree.tree_map(lambda x: torch.ones_like(x, device=output_device)
+                                         if isinstance(x, torch.Tensor) else x, kwargs)
+                real_out = func(*args, **kwargs)
 
         def wrap_with_proxy(e, proxy):
             # Some ops (like native_batch_norm_backward) return undefined tensors that get
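
One consequence of the is-a design shows up in the last hunk above: a PythonTensor now is the tensor holding the real data, so calling func on it inside __torch_dispatch__ would re-enter the handler. The new no_dispatch() guard prevents that recursion. Below is a rough, self-contained sketch of the pattern; the guard is copied from the diff, while LoggingTensor is a made-up subclass for illustration, and the snippet assumes a PyTorch build of the era this commit targets (where torch._C._DisableTorchDispatch and _disabled_torch_function_impl exist).

from contextlib import contextmanager

import torch
from torch._C import _disabled_torch_function_impl

@contextmanager
def no_dispatch():
    # While the guard is alive, __torch_dispatch__ is skipped and ops hit
    # the regular kernels directly.
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard

class LoggingTensor(torch.Tensor):
    # Made-up subclass: logs every op it sees, then defers to the real
    # implementation. Structured like the commit's PythonTensor.
    __torch_function__ = _disabled_torch_function_impl

    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"dispatching {func}")
        # Without no_dispatch() this call would land right back here,
        # because the subclass *is* the tensor doing the compute.
        with no_dispatch():
            return func(*args, **kwargs)

t = LoggingTensor(torch.randn(2))
out = t + 1  # prints one "dispatching ..." line, then runs the real add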

test/test_pythonkey.py

Lines changed: 0 additions & 4 deletions
@@ -192,7 +192,6 @@ def f(x):
         xfail('allclose'),
         xfail('nn.functional.dropout'),
         xfail('linalg.eigvals'),
-        xfail('nn.functional.ctc_loss'),
         xfail('nn.functional.fractional_max_pool3d', device_type='cpu'),
         xfail('randn_like'),  # randomness
         xfail('rand_like'),  # randomness
@@ -355,11 +354,8 @@ class TestEagerFusionOpInfo(TestCase):
     # entries in here need don't work and need to be fixed.
     # Each one of these is a bug (or needs to be investigated)
     @skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_exhaustive', {
-        xfail('__rmatmul__'),
         xfail('linalg.cholesky'),
-        xfail('matmul'),
         skip('msort'),
-        xfail('nn.functional.linear'),
         xfail('nn.functional.dropout'),
         xfail('polar'),
         xfail('special.zeta', 'grad'),
