Commit 1f1dfa9

ydwu4 authored and pytorchmergebot committed
Fix grad higher order handling TupleVariable (pytorch#106425)
Previously, we assumed argnums was a **ConstantVariable**. However, I accidentally triggered an error on CI where argnums could be a **TupleVariable**; in that case, we hit an AttributeError when accessing the .value of argnums. This PR adds support for TupleVariable. It allows the unit test to pass without falling back to eager: "PYTORCH_TEST_WITH_DYNAMO=1 python test/functorch/test_eager_transforms.py -k test_argnums_cpu"

Test Plan: see modified test.

Pull Request resolved: pytorch#106425
Approved by: https://github.com/yanboliang, https://github.com/anijain2305, https://github.com/kshitij12345
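For context, a minimal sketch of the pattern that used to fail (assumed setup, mirroring the new wrapper_fn_tuple_var test below rather than copied from it): when argnums arrives as a tuple captured from the enclosing scope, Dynamo traces it as a TupleVariable instead of a ConstantVariable.

```python
# Sketch only: mirrors the added test case, not taken verbatim from the PR.
import torch

nums = (0, 1)  # captured from the enclosing scope -> traced as a TupleVariable

def fn(x, y):
    return ((x.sin() + y).sum(), x.cos())  # (scalar loss, aux tensor)

def wrapper_fn(x, y):
    # Before this fix, Dynamo assumed argnums was a ConstantVariable and read
    # argnums.value directly, so tracing this call raised an AttributeError.
    return torch.func.grad(fn, argnums=nums, has_aux=True)(x, y)

x = torch.randn(3, 3, 3)
y = torch.randn(3, 3, 3)
out = torch.compile(wrapper_fn, backend="eager", fullgraph=True)(x, y)
```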
1 parent f998869 commit 1f1dfa9

File tree

2 files changed: +51 −17 lines changed

test/dynamo/test_higher_order_ops.py

Lines changed: 16 additions & 4 deletions
```diff
@@ -1851,15 +1851,21 @@ def forward(self, l_x_, l_y_):
     def test_grad_two_tensor_all_grad_has_aux(self):
         counters.clear()
 
+        nums = (0, 1)
+
         def fn(x, y):
             return ((x.sin() + y).sum(), x.cos())
 
-        def wrapper_fn(x, y):
+        def wrapper_fn_const_var(x, y):
             return torch.func.grad(fn, argnums=(0, 1), has_aux=True)(x, y)
 
+        def wrapper_fn_tuple_var(x, y):
+            return torch.func.grad(fn, argnums=nums, has_aux=True)(x, y)
+
         y = torch.randn(3, 3, 3)
         x = torch.randn(3, 3, 3)
-        wrapped_gm = self._grad_compile_check(wrapper_fn, (x, y))
+        wrapped_gm_const_var = self._grad_compile_check(wrapper_fn_const_var, (x, y))
+        wrapped_gm_tuple_var = self._grad_compile_check(wrapper_fn_tuple_var, (x, y))
 
         # Dynamic shapes produce a slightly different graph.
         if check_dynamic_shape_capture():
@@ -1894,8 +1900,14 @@ def forward(self, l_x_, l_y_):
         _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
         return (sum_1, cos)
 """
-        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
-        self.assertExpectedInline(actual, expected)
+        actual_const_var = normalize_gm(
+            wrapped_gm_const_var.print_readable(print_output=False)
+        )
+        actual_tuple_var = normalize_gm(
+            wrapped_gm_tuple_var.print_readable(print_output=False)
+        )
+        self.assertExpectedInline(actual_const_var, expected)
+        self.assertExpectedInline(actual_tuple_var, expected)
 
     def test_grad_over_grad(self):
         counters.clear()
```
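For reference, a short eager-mode sketch (not part of the diff) of what the two wrappers above compute: with argnums=(0, 1) and has_aux=True, torch.func.grad returns ((grad_x, grad_y), aux).

```python
import torch

def fn(x, y):
    return ((x.sin() + y).sum(), x.cos())  # (scalar loss, aux tensor)

x = torch.randn(3, 3, 3)
y = torch.randn(3, 3, 3)

(grad_x, grad_y), aux = torch.func.grad(fn, argnums=(0, 1), has_aux=True)(x, y)
assert torch.allclose(grad_x, x.cos())             # d/dx sum(sin(x) + y) = cos(x)
assert torch.allclose(grad_y, torch.ones_like(y))  # d/dy sum(sin(x) + y) = 1
assert torch.allclose(aux, x.cos())                # aux is returned unchanged
```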

torch/_dynamo/variables/higher_order_ops.py

Lines changed: 35 additions & 13 deletions
```diff
@@ -666,15 +666,37 @@ def call_function(
         # For has_aux=False, Tuple[gradients of inputs indicated by argnums].
         # For has_aux=True, Tuple[Tuple[gradients of inputs indicated by argnums], aux values]
         # NOTE: example_value should match `grad_output`.
-        if isinstance(argnums.value, int):
-            example_value = (
-                args[argnums.value].as_proxy().node.meta["example_value"].contiguous()
-            )
-        else:
-            example_value = tuple(
-                args[idx].as_proxy().node.meta["example_value"].contiguous()
-                for idx in argnums.value
-            )
+        def _from_args(idx):
+            return args[idx].as_proxy().node.meta["example_value"].contiguous()
+
+        def to_python_ints(argnums):
+            if not isinstance(argnums, (ConstantVariable, TupleVariable)):
+                raise UserError(
+                    UserErrorType.INVALID_INPUT,
+                    f"argnums is expected to be int or tuple of ints. Got {argnums}.",
+                )
+
+            if isinstance(argnums, ConstantVariable):
+                if not isinstance(argnums.value, (int, tuple)):
+                    raise UserError(
+                        UserErrorType.INVALID_INPUT,
+                        f"argnums is expected to be int or tuple of ints. Got {argnums}.",
+                    )
+                return argnums.value
+            else:
+                const_vars = argnums.unpack_var_sequence(tx)
+                if not all(
+                    isinstance(var, ConstantVariable) and isinstance(var.value, int)
+                    for var in const_vars
+                ):
+                    raise UserError(
+                        UserErrorType.INVALID_INPUT,
+                        f"argnums is expected to contain int only. Got {const_vars}.",
+                    )
+                return tuple(var.value for var in const_vars)
+
+        argnums_v = to_python_ints(argnums)
+        example_value = pytree.tree_map(_from_args, argnums_v)
 
         if has_aux.value:
             # case : has_aux = True
@@ -691,12 +713,12 @@ def call_function(
 
         # Call contiguous on all the computed grads.
         if not has_aux.value:
-            if isinstance(argnums.value, int):
+            if isinstance(argnums_v, int):
                 return fx_proxy.call_method(tx, "contiguous", (), {})
             else:
                 grads = fx_proxy
                 items = []
-                for idx in range(len(argnums.value)):
+                for idx in range(len(argnums_v)):
                     proxy = grads.call_method(
                         tx, "__getitem__", (ConstantVariable(idx),), {}
                     ).call_method(tx, "contiguous", (), {})
@@ -706,11 +728,11 @@ def call_function(
             # fx_proxy -> Tuple(grads, aux)
             grads = fx_proxy.call_method(tx, "__getitem__", (ConstantVariable(0),), {})
             aux = fx_proxy.call_method(tx, "__getitem__", (ConstantVariable(1),), {})
-            if isinstance(argnums.value, int):
+            if isinstance(argnums_v, int):
                 return TupleVariable([grads.call_method(tx, "contiguous", (), {}), aux])
             else:
                 items = []
-                for idx in range(len(argnums.value)):
+                for idx in range(len(argnums_v)):
                     proxy = grads.call_method(
                         tx, "__getitem__", (ConstantVariable(idx),), {}
                     ).call_method(tx, "contiguous", (), {})
```
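A note on the new `pytree.tree_map(_from_args, argnums_v)` line: because a plain int is a pytree leaf, the same call covers both the int and tuple-of-ints forms of argnums. A small illustrative sketch follows (the dict and values are stand-ins, not the real per-argument metadata; `pytree` is assumed to be `torch.utils._pytree`):

```python
import torch.utils._pytree as pytree  # assumed to match the `pytree` used in the diff

example_values = {0: "ex_val_0", 1: "ex_val_1", 2: "ex_val_2"}  # stand-in metadata

def _from_args(idx):
    return example_values[idx]

print(pytree.tree_map(_from_args, 1))       # ex_val_1 -> int argnums yields a single value
print(pytree.tree_map(_from_args, (0, 2)))  # ('ex_val_0', 'ex_val_2') -> tuple yields a tuple
```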
