File tree Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Original file line number Diff line number Diff line change @@ -135,7 +135,7 @@ def forward(ctx, *flat_tensor_args):
135
135
with torch .set_grad_enabled (grad_state ):
136
136
out = flat_fn (* flat_tensor_args )
137
137
out = pytree .tree_map (
138
- lambda x : x .detach () if isinstance (x , Tensor ) else x , out
138
+ lambda x : x .detach (). contiguous () if isinstance (x , Tensor ) else x , out
139
139
)
140
140
141
141
if isinstance (out , (list , tuple )):
@@ -164,9 +164,8 @@ def forward(ctx, *flat_tensor_args):
164
164
165
165
@staticmethod
166
166
def backward (ctx , * flat_args ):
167
- # hmm... this doesn't feel right. todo
168
- # contiguous_args = [t.contiguous() for t in flat_args]
169
- contiguous_args = [t for t in flat_args ]
167
+ contiguous_args = [t .contiguous () for t in flat_args ]
168
+ # contiguous_args = [t for t in flat_args]
170
169
out = normalize_as_list (compiled_bw (* ctx .saved_tensors , * contiguous_args ))
171
170
return tuple (out )
172
171
You can’t perform that action at this time.
0 commit comments