From 97c6845d7d426a7754cc291ef21c2830e5df5b8d Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 25 Feb 2022 00:49:29 +0000 Subject: [PATCH] Trace the backward pass assuming contiguous tensors --- functorch/_src/aot_autograd.py | 7 +++---- test/test_pythonkey.py | 13 +++++++++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py index 88179c619..c65aff6bd 100644 --- a/functorch/_src/aot_autograd.py +++ b/functorch/_src/aot_autograd.py @@ -135,7 +135,7 @@ def forward(ctx, *flat_tensor_args): with torch.set_grad_enabled(grad_state): out = flat_fn(*flat_tensor_args) out = pytree.tree_map( - lambda x: x.detach() if isinstance(x, Tensor) else x, out + lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out ) if isinstance(out, (list, tuple)): @@ -164,9 +164,8 @@ def forward(ctx, *flat_tensor_args): @staticmethod def backward(ctx, *flat_args): - # hmm... this doesn't feel right. todo - # contiguous_args = [t.contiguous() for t in flat_args] - contiguous_args = [t for t in flat_args] + contiguous_args = [t.contiguous() for t in flat_args] + # contiguous_args = [t for t in flat_args] out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args)) return tuple(out) diff --git a/test/test_pythonkey.py b/test/test_pythonkey.py index 70afd9ebb..be34c8ce4 100644 --- a/test/test_pythonkey.py +++ b/test/test_pythonkey.py @@ -527,6 +527,19 @@ def f(a, b, c, d): self.assertEqual(get_num_ins_outs(bw_graph), (2, 4)) +class TestContiguous(TestCase): + def test_contiguous(self): + # The test simulates the condition where transpose followed by view + # happens in the backward pass. + # https://discuss.pytorch.org/t/error-on-transpose-and-view/434 + def f(x): + return x.view(2, 3).t() + + inp = torch.randn(6, requires_grad=True) + out = aot_function(f, nop)(inp) + torch.autograd.grad(out, inp, torch.randn(3, 2)) + + only_for = ("cpu") instantiate_device_type_tests( TestPythonKey,