Don't unnecessarily wrap the elem in PythonTensor #554


Merged: 4 commits, Mar 3, 2022
90 changes: 27 additions & 63 deletions functorch/_src/python_key.py
@@ -16,43 +16,25 @@
 aten = torch.ops.aten

 CURRENT_DECOMPOSITION_TABLE = {}
-USE_META = False


 @contextmanager
-def pythonkey_decompose(decomposition_table):
-    global CURRENT_DECOMPOSITION_TABLE
-    CURRENT_DECOMPOSITION_TABLE = decomposition_table
+def no_dispatch():
+    guard = torch._C._DisableTorchDispatch()
     try:
-        yield CURRENT_DECOMPOSITION_TABLE
+        yield
     finally:
-        CURRENT_DECOMPOSITION_TABLE = {}
+        del guard


 @contextmanager
-def pythonkey_meta():
-    global USE_META
-    USE_META = True
+def pythonkey_decompose(decomposition_table):
+    global CURRENT_DECOMPOSITION_TABLE
+    CURRENT_DECOMPOSITION_TABLE = decomposition_table
     try:
-        yield USE_META
+        yield CURRENT_DECOMPOSITION_TABLE
     finally:
-        USE_META = False
-
-
-def get_output_device(devices, op):
-    # The device propagation is a bit sketchy.
-    # aten::index(CPU, CUDA) => CPU tensor
-    # aten::index(CUDA, CPU) => CUDA tensor
-    if op == aten.index:
-        return devices[0]
-    devices = list(set(devices))
-    if len(devices) == 1:
-        return devices[0]
-    else:
-        for device in devices:
-            if device.type == 'cuda':
-                return device
-        raise RuntimeError("Couldn't infer output device from input device")
+        CURRENT_DECOMPOSITION_TABLE = {}


 class PythonTensor(torch.Tensor):
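For orientation, here is a rough usage sketch of the reworked context manager. It is not part of the diff; it assumes only the functorch.compile re-exports shown in the second file below, and the call inside the block is a stand-in for an actual python-key trace:

    import torch
    from functorch.compile import pythonkey_decompose, decomposition_table

    def f(x):
        return torch.tanh(x).sum()

    # While the block is active, python-key tracing consults the supplied
    # decompositions; CURRENT_DECOMPOSITION_TABLE is reset to {} on exit.
    with pythonkey_decompose(decomposition_table):
        out = f(torch.randn(4))  # stand-in: a real use would trace f here

The new no_dispatch helper, by contrast, is internal: it holds a _DisableTorchDispatch guard so that the real aten kernel can run without re-entering PythonTensor.__torch_dispatch__.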
@@ -61,29 +43,27 @@ class PythonTensor(torch.Tensor):
     __slots__ = ['elem', 'proxy']

     @staticmethod
-    def __new__(cls, elem, proxy, device=None):
-        # The wrapping tensor (PythonTensor) is just a meta tensor, so it
-        # doesn't hold any memory (meta tensor is generally the preferred type
-        # of tensor you want to make a subclass from)...
-
-        r = torch.Tensor._make_wrapper_subclass(
-            cls, elem.size(),
-            strides=elem.stride(), storage_offset=elem.storage_offset(),
-            dtype=elem.dtype, layout=elem.layout, requires_grad=elem.requires_grad,
-            device=(elem.device if device is None else device),
-        )
-
-        # ...the real tensor is held as an element on the tensor.
-        if USE_META:
-            r.elem = elem.to('meta')
-        else:
-            r.elem = elem
+    def __new__(cls, elem, proxy):
+        # Wrapping something in PythonTensor implicitly detaches
+        # gradients. If something required grad, we will collect it as if it
+        # were a leaf. A consequence of detaching in this way is you
+        # need to maintain a parameter cache when translating tensors
+        # into PythonTensor, so you don't create multiple copies of
+        # a gradient (they are aliased, but they would count as independent
+        # leaves). An alternate strategy would be to avoid implicitly
+        # detaching and instead "catch" gradients as they exit the
+        # PythonTensor boundary.
+        # assert not elem.requires_grad or not torch.is_grad_enabled()
+
+        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
Contributor:

This doesn't work for meta tensors since at this point, elem will be a meta tensor. So we're just gonna make a PythonTensor with the meta device anyways.

That's why I went through all of the shenanigans of inferring the output device - if we run with meta tensors, then at no point do we have the actual output device of the operator. All you have is the device of the input tensors.

So... ripping out the device inference logic will make the meta-tracing stuff not work at all, in which case we should just remove all of it :P

Contributor Author:

This seems like a bigger structural problem for meta tensors. Will need to think about this...
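A small self-contained illustration of the point being made here (a sketch, assuming current PyTorch semantics rather than anything in this PR): once a tensor is moved to the meta device, its original device is gone, so the output device of an operator cannot be recovered from meta inputs alone, which is exactly what get_output_device was compensating for.

    import torch

    x = torch.randn(3, 3)     # imagine this lived on CUDA in the real program
    m = x.to('meta')          # shape and dtype survive; storage and device do not
    print(m.device)           # device(type='meta')
    print((m @ m).device)     # still meta: the real output device is unrecoverable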

         r.proxy = proxy
         proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
         return r

     def __repr__(self):
-        return f"PythonTensor({self.elem})"
+        # This is a bit goofy but whatever. Should fix up _tensor_str.py to
+        # work on subclasses when it calls tolist
+        return f"PythonTensor({torch.Tensor._make_subclass(torch.Tensor, self)})"

     __torch_function__ = _disabled_torch_function_impl
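To unpack the comment in the new __new__ about maintaining a parameter cache, here is an illustrative standalone sketch (not functorch code): wrapping the same parameter twice with _make_subclass produces two tensors that alias one storage yet look like independent autograd leaves, so a tracer has to cache the parameter-to-wrapper mapping to avoid double-counting gradients.

    import torch

    p = torch.nn.Parameter(torch.randn(3))

    # Each wrap is a fresh leaf that shares p's storage.
    a = torch.Tensor._make_subclass(torch.Tensor, p, p.requires_grad)
    b = torch.Tensor._make_subclass(torch.Tensor, p, p.requires_grad)

    print(a is b)                        # False: two distinct wrapper objects
    print(a.data_ptr() == b.data_ptr())  # True: they alias the same memory
    print(a.is_leaf and b.is_leaf)       # True: autograd sees two separate leaves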

@@ -99,14 +79,6 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         def unwrap_proxy(e):
             return e.proxy if isinstance(e, PythonTensor) else e

-        def unwrap_tensor(e):
-            return e.elem if isinstance(e, PythonTensor) else e
-
-        input_devices = [i.device for i in pytree.tree_flatten(args)[0] +
-                         pytree.tree_flatten(kwargs)[0] if isinstance(i, torch.Tensor)]
-
-        output_device = get_output_device(input_devices, func)
-
         proxy_args = pytree.tree_map(unwrap_proxy, args)
         proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

@@ -115,16 +87,8 @@ def unwrap_tensor(e):
         # Kind of a hacky way to test if an op is in-place or not
         if func.__name__[-1] == "_" and func.__name__[0] != "_":
             args[0].proxy = proxy_out
-        args = pytree.tree_map(unwrap_tensor, args)
-        kwargs = pytree.tree_map(unwrap_tensor, kwargs)

-        try:
-            real_out = func(*args, **kwargs)
-        except NotImplementedError:
-            args = pytree.tree_map(lambda x: torch.ones_like(x, device=output_device)
-                                   if isinstance(x, torch.Tensor) else x, args)
-            kwargs = pytree.tree_map(lambda x: torch.ones_like(x, device=output_device)
-                                     if isinstance(x, torch.Tensor) else x, kwargs)
+        with no_dispatch():
+            real_out = func(*args, **kwargs)
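The in-place check above is purely a naming heuristic; a minimal sketch of the rule it encodes, with a hypothetical helper name:

    # Overloads such as "add_" or "relu_" mutate their first argument in place,
    # while dunder names like "__add__" start with "_" and must be excluded.
    def looks_inplace(name: str) -> bool:
        return name.endswith("_") and not name.startswith("_")

    assert looks_inplace("add_")
    assert not looks_inplace("__add__")
    assert not looks_inplace("add")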

def wrap_with_proxy(e, proxy):
Expand All @@ -135,7 +99,7 @@ def wrap_with_proxy(e, proxy):
             if e is None:
                 e = torch.empty(())
             if type(e) == torch.Tensor:
-                return PythonTensor(e, proxy, output_device)
+                return PythonTensor(e, proxy)
             else:
                 return e
         if isinstance(real_out, tuple):
2 changes: 1 addition & 1 deletion functorch/compile/__init__.py
@@ -1,5 +1,5 @@
 from .._src.operator_authoring import pointwise_operator
-from .._src.python_key import pythonkey_decompose, pythonkey_meta
+from .._src.python_key import pythonkey_decompose
 from .._src.decompositions import register_decomposition, decomposition_table
 from .._src.fx_minifier import minifier, check_nvfuser_subprocess
 from .._src.aot_autograd import (
4 changes: 0 additions & 4 deletions test/test_pythonkey.py
@@ -192,7 +192,6 @@ def f(x):
     xfail('allclose'),
     xfail('nn.functional.dropout'),
     xfail('linalg.eigvals'),
-    xfail('nn.functional.ctc_loss'),
     xfail('nn.functional.fractional_max_pool3d', device_type='cpu'),
     xfail('randn_like'),  # randomness
     xfail('rand_like'),  # randomness
@@ -355,11 +354,8 @@ class TestEagerFusionOpInfo(TestCase):
     # entries in here need don't work and need to be fixed.
     # Each one of these is a bug (or needs to be investigated)
     @skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_exhaustive', {
-        xfail('__rmatmul__'),
         xfail('linalg.cholesky'),
-        xfail('matmul'),
         skip('msort'),
-        xfail('nn.functional.linear'),
         xfail('nn.functional.dropout'),
         xfail('polar'),
         xfail('special.zeta', 'grad'),