
Commit f65f724

[DO NOT MERGE] [AOT Autograd] Trace the backward pass assuming contiguous tensors
1 parent: c7d9acc

File tree: 1 file changed (+3 −4)

functorch/_src/aot_autograd.py

Lines changed: 3 additions & 4 deletions
@@ -135,7 +135,7 @@ def forward(ctx, *flat_tensor_args):
         with torch.set_grad_enabled(grad_state):
             out = flat_fn(*flat_tensor_args)
             out = pytree.tree_map(
-                lambda x: x.detach() if isinstance(x, Tensor) else x, out
+                lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
             )

         if isinstance(out, (list, tuple)):
@@ -164,9 +164,8 @@ def forward(ctx, *flat_tensor_args):

     @staticmethod
     def backward(ctx, *flat_args):
-        # hmm... this doesn't feel right. todo
-        # contiguous_args = [t.contiguous() for t in flat_args]
-        contiguous_args = [t for t in flat_args]
+        contiguous_args = [t.contiguous() for t in flat_args]
+        # contiguous_args = [t for t in flat_args]
         out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
         return tuple(out)
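Why contiguity matters here (an illustrative sketch, not part of the commit): the backward graph is traced once, assuming contiguous inputs, but the gradients autograd actually passes into backward() can carry arbitrary strides, for example when a forward output was a transposed view. Making both the traced forward outputs and the runtime backward arguments contiguous keeps tensor layouts consistent with what the trace assumed. A minimal standalone demonstration, with all tensors and names invented for illustration:

import torch
from torch.utils import _pytree as pytree

# A gradient arriving in backward() can be non-contiguous, e.g. when the
# corresponding forward output was a transposed view:
grad_out = torch.ones(4, 3)
grad_in = grad_out.t()                   # shape (3, 4), strides (1, 3)
print(grad_in.is_contiguous())           # False

# What the commit now does at runtime before calling the compiled backward,
# mirroring: contiguous_args = [t.contiguous() for t in flat_args]
contiguous = grad_in.contiguous()
print(contiguous.is_contiguous(), contiguous.stride())  # True (4, 1)

# On the forward side, outputs are detached and made contiguous so the trace
# never records transposed strides (mirroring the pytree.tree_map change):
outs = (torch.randn(3, 4).t(), "metadata")
outs = pytree.tree_map(
    lambda x: x.detach().contiguous() if isinstance(x, torch.Tensor) else x,
    outs,
)
print(outs[0].is_contiguous())           # True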