Skip to content

Commit 991cb15

Browse files
bdhirshzou3519
authored andcommitted
[functorch] fix functionalize(): properly propagate input mutations (pytorch/functorch#654)
1 parent 99946f3 commit 991cb15

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

functorch/functorch/_src/eager_transforms.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,10 +1224,10 @@ def wrapped(*args, **kwargs):
12241224
func_args = _wrap_all_tensors_to_functional(args, func_level)
12251225
func_kwargs = _wrap_all_tensors_to_functional(kwargs, func_level)
12261226

1227-
flattened_unwrapped_args = tree_flatten(args)
1228-
flattened_wrapped_args = tree_flatten(func_args)
1229-
flattened_unwrapped_kwargs = tree_flatten(kwargs)
1230-
flattened_wrapped_kwargs = tree_flatten(func_kwargs)
1227+
flattened_unwrapped_args, _ = tree_flatten(args)
1228+
flattened_wrapped_args, _ = tree_flatten(func_args)
1229+
flattened_unwrapped_kwargs, _ = tree_flatten(kwargs)
1230+
flattened_wrapped_kwargs, _ = tree_flatten(func_kwargs)
12311231

12321232
func_outputs = func(*func_args, **func_kwargs)
12331233
outputs = _unwrap_all_tensors_from_functional(func_outputs)

functorch/test/test_eager_transforms.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2776,6 +2776,9 @@ def _check_functionalize_correctness(self, f, inpt):
27762776
# Check that outputs are the same
27772777
self.assertEqual(actual_outputs, expected_outputs)
27782778

2779+
# Inputs might have been mutated by f: check that they were mutated properly
2780+
self.assertEqual(inpt1, inpt2)
2781+
27792782
def test_simple_view(self, device):
27802783
def f(x: torch.Tensor) -> torch.Tensor:
27812784
tmp = torch.ones(2, device=device)

0 commit comments

Comments
 (0)