diff --git a/functorch/_src/python_key.py b/functorch/_src/python_key.py
index 3532f3226..92c87245a 100644
--- a/functorch/_src/python_key.py
+++ b/functorch/_src/python_key.py
@@ -16,43 +16,25 @@
 aten = torch.ops.aten
 
 CURRENT_DECOMPOSITION_TABLE = {}
-USE_META = False
 
 
 @contextmanager
-def pythonkey_decompose(decomposition_table):
-    global CURRENT_DECOMPOSITION_TABLE
-    CURRENT_DECOMPOSITION_TABLE = decomposition_table
+def no_dispatch():
+    guard = torch._C._DisableTorchDispatch()
     try:
-        yield CURRENT_DECOMPOSITION_TABLE
+        yield
     finally:
-        CURRENT_DECOMPOSITION_TABLE = {}
+        del guard
 
 
 @contextmanager
-def pythonkey_meta():
-    global USE_META
-    USE_META = True
+def pythonkey_decompose(decomposition_table):
+    global CURRENT_DECOMPOSITION_TABLE
+    CURRENT_DECOMPOSITION_TABLE = decomposition_table
     try:
-        yield USE_META
+        yield CURRENT_DECOMPOSITION_TABLE
     finally:
-        USE_META = False
-
-
-def get_output_device(devices, op):
-    # The device propagation is a bit sketchy.
-    # aten::index(CPU, CUDA) => CPU tensor
-    # aten::index(CUDA, CPU) => CUDA tensor
-    if op == aten.index:
-        return devices[0]
-    devices = list(set(devices))
-    if len(devices) == 1:
-        return devices[0]
-    else:
-        for device in devices:
-            if device.type == 'cuda':
-                return device
-    raise RuntimeError("Couldn't infer output device from input device")
+        CURRENT_DECOMPOSITION_TABLE = {}
 
 
 class PythonTensor(torch.Tensor):
@@ -61,29 +43,27 @@ class PythonTensor(torch.Tensor):
     __slots__ = ['elem', 'proxy']
 
     @staticmethod
-    def __new__(cls, elem, proxy, device=None):
-        # The wrapping tensor (PythonTensor) is just a meta tensor, so it
-        # doesn't hold any memory (meta tensor is generally the preferred type
-        # of tensor you want to make a subclass from)...
-
-        r = torch.Tensor._make_wrapper_subclass(
-            cls, elem.size(),
-            strides=elem.stride(), storage_offset=elem.storage_offset(),
-            dtype=elem.dtype, layout=elem.layout, requires_grad=elem.requires_grad,
-            device=(elem.device if device is None else device),
-        )
-
-        # ...the real tensor is held as an element on the tensor.
-        if USE_META:
-            r.elem = elem.to('meta')
-        else:
-            r.elem = elem
+    def __new__(cls, elem, proxy):
+        # Wrapping something in PythonTensor implicitly detaches
+        # gradients. If something required grad, we will collect it as if it
+        # were a leaf. A consequence of detaching in this way is you
+        # need to maintain a parameter cache when translating tensors
+        # into PythonTensor, so you don't create multiple copies of
+        # a gradient (they are aliased, but they would count as independent
+        # leaves). An alternate strategy would be to avoid implicitly
+        # detaching and instead "catch" gradients as they exit the
+        # PythonTensor boundary.
+        # assert not elem.requires_grad or not torch.is_grad_enabled()
+
+        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
         r.proxy = proxy
         proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
         return r
 
     def __repr__(self):
-        return f"PythonTensor({self.elem})"
+        # This is a bit goofy but whatever. Should fix up _tensor_str.py to
+        # work on subclasses when it calls tolist
+        return f"PythonTensor({torch.Tensor._make_subclass(torch.Tensor, self)})"
 
     __torch_function__ = _disabled_torch_function_impl
 
@@ -99,14 +79,6 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         def unwrap_proxy(e):
             return e.proxy if isinstance(e, PythonTensor) else e
 
-        def unwrap_tensor(e):
-            return e.elem if isinstance(e, PythonTensor) else e
-
-        input_devices = [i.device for i in pytree.tree_flatten(args)[0] +
-                         pytree.tree_flatten(kwargs)[0] if isinstance(i, torch.Tensor)]
-
-        output_device = get_output_device(input_devices, func)
-
         proxy_args = pytree.tree_map(unwrap_proxy, args)
         proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
 
@@ -115,16 +87,8 @@ def unwrap_tensor(e):
         # Kind of a hacky way to test if an op is in-place or not
         if func.__name__[-1] == "_" and func.__name__[0] != "_":
             args[0].proxy = proxy_out
-        args = pytree.tree_map(unwrap_tensor, args)
-        kwargs = pytree.tree_map(unwrap_tensor, kwargs)
 
-        try:
-            real_out = func(*args, **kwargs)
-        except NotImplementedError:
-            args = pytree.tree_map(lambda x: torch.ones_like(x, device=output_device)
-                                   if isinstance(x, torch.Tensor) else x, args)
-            kwargs = pytree.tree_map(lambda x: torch.ones_like(x, device=output_device)
-                                     if isinstance(x, torch.Tensor) else x, kwargs)
+        with no_dispatch():
             real_out = func(*args, **kwargs)
 
         def wrap_with_proxy(e, proxy):
@@ -135,7 +99,7 @@ def wrap_with_proxy(e, proxy):
             if e is None:
                 e = torch.empty(())
             if type(e) == torch.Tensor:
-                return PythonTensor(e, proxy, output_device)
+                return PythonTensor(e, proxy)
             else:
                 return e
         if isinstance(real_out, tuple):
diff --git a/functorch/compile/__init__.py b/functorch/compile/__init__.py
index cb4a1eca0..95e161674 100644
--- a/functorch/compile/__init__.py
+++ b/functorch/compile/__init__.py
@@ -1,5 +1,5 @@
 from .._src.operator_authoring import pointwise_operator
-from .._src.python_key import pythonkey_decompose, pythonkey_meta
+from .._src.python_key import pythonkey_decompose
 from .._src.decompositions import register_decomposition, decomposition_table
 from .._src.fx_minifier import minifier, check_nvfuser_subprocess
 from .._src.aot_autograd import (
diff --git a/test/test_pythonkey.py b/test/test_pythonkey.py
index 40b93dc47..4cf31ec64 100644
--- a/test/test_pythonkey.py
+++ b/test/test_pythonkey.py
@@ -192,7 +192,6 @@ def f(x):
    xfail('allclose'),
    xfail('nn.functional.dropout'),
    xfail('linalg.eigvals'),
-   xfail('nn.functional.ctc_loss'),
    xfail('nn.functional.fractional_max_pool3d', device_type='cpu'),
    xfail('randn_like'),  # randomness
    xfail('rand_like'),  # randomness
@@ -355,11 +354,8 @@ class TestEagerFusionOpInfo(TestCase):
     # entries in here need don't work and need to be fixed.
     # Each one of these is a bug (or needs to be investigated)
     @skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_exhaustive', {
-        xfail('__rmatmul__'),
         xfail('linalg.cholesky'),
-        xfail('matmul'),
         skip('msort'),
-        xfail('nn.functional.linear'),
         xfail('nn.functional.dropout'),
         xfail('polar'),
         xfail('special.zeta', 'grad'),
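For reviewers less familiar with the dispatch machinery, below is a minimal, self-contained sketch of the pattern this patch moves to: a `_make_subclass`-based tensor subclass whose `__torch_dispatch__` intercepts each op and then re-enters the real kernel under a dispatch-disable guard. `LoggingTensor` and its printing are illustrative stand-ins for `PythonTensor` and its proxy bookkeeping (they are not part of this patch); `torch._C._DisableTorchDispatch` is the same private, unstable API the patch itself relies on.

```python
from contextlib import contextmanager

import torch
from torch.utils._pytree import tree_map


@contextmanager
def no_dispatch():
    # Same guard the patch introduces: while it is alive, ATen calls skip
    # __torch_dispatch__, so func(...) below runs the real kernel instead
    # of recursing back into the subclass.
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard


class LoggingTensor(torch.Tensor):
    """Hypothetical stand-in for PythonTensor: logs each op, then runs it."""

    @staticmethod
    def __new__(cls, elem):
        # _make_subclass shares elem's storage, so the subclass is a real
        # tensor on a real device -- no wrapper/meta tensor, and no need for
        # the device-propagation guesswork this patch deletes.
        return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"dispatch: {func}")
        with no_dispatch():
            real_out = func(*args, **kwargs)
        # Re-wrap plain tensor outputs so subsequent ops are intercepted too.
        return tree_map(
            lambda x: LoggingTensor(x) if type(x) is torch.Tensor else x,
            real_out,
        )


x = LoggingTensor(torch.randn(3))
y = x.mul(2).add(1)  # prints e.g. "dispatch: aten.mul.Tensor", then add
```

This is also why `get_output_device` and the `pythonkey_meta` path can go away: the wrapped element previously could be a meta tensor, so ops sometimes raised `NotImplementedError` and output devices had to be guessed, whereas here every op executes for real on real storage under `no_dispatch()`.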