Skip to content

Commit 1433db5

Browse files
Add test_make_fx_model_train example (#980)
Summary: Pull Request resolved: #980 Differential Revision: D38078694 Pulled By: mostafaelhoushi fbshipit-source-id: 4e8be4c653ac2c871d2e47064644b9a965e05187
1 parent e8a68f4 commit 1433db5

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

test/test_pythonkey.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import unittest
1212
import warnings
1313
import itertools
14+
import torch.nn.utils._stateless as stateless
1415
from functools import partial
1516
from torch.testing._internal.common_device_type import instantiate_device_type_tests
1617
from functorch import (
@@ -74,6 +75,26 @@ def f(x):
7475
new_inp = torch.randn(3)
7576
self.assertEqual(fx_f(new_inp), f(new_inp))
7677

78+
def test_make_fx_model_train(self, device):
79+
class Foo(nn.Module):
80+
def __init__(self):
81+
super().__init__()
82+
self.linear = nn.Linear(5, 5)
83+
84+
def forward(self, x):
85+
return self.linear(x).relu()
86+
mod = Foo()
87+
optim = torch.optim.SGD(mod.parameters(), lr=1e-4)
88+
89+
def f(x, params):
90+
out = stateless.functional_call(mod, params, x).sum()
91+
out.backward()
92+
return list(params.values())
93+
inp = torch.randn(3, 5, requires_grad=True)
94+
params = dict(mod.named_parameters())
95+
fx_f = make_fx(f)(inp, params)
96+
self.assertEqual(fx_f(inp, params), f(inp, params))
97+
7798
def test_scalar_device(self, device):
7899
def f(a, b):
79100
return a + b

0 commit comments

Comments
 (0)