Skip to content

Commit 2a852dd

Browse files
Add test_make_fx_model_train example (#82011)
Summary: X-link: pytorch/pytorch#82011 Pull Request resolved: #980 Test Plan: CI should pass Reviewed By: benoitsteiner Differential Revision: D38078694 Pulled By: mostafaelhoushi fbshipit-source-id: ccfbeb8531d49d0d503e728f997a6003c87f9eb1
1 parent e8a68f4 commit 2a852dd

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

test/test_pythonkey.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import unittest
1212
import warnings
1313
import itertools
14+
import torch.nn.utils._stateless as stateless
15+
from collections.abc import Iterable
1416
from functools import partial
1517
from torch.testing._internal.common_device_type import instantiate_device_type_tests
1618
from functorch import (
@@ -74,6 +76,53 @@ def f(x):
7476
new_inp = torch.randn(3)
7577
self.assertEqual(fx_f(new_inp), f(new_inp))
7678

79+
def test_make_fx_model_fwd_bwd(self, device):
    """Trace a forward + backward pass through a small module with make_fx
    and check that the traced graph reproduces the eager results."""
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(5, 5)

        def forward(self, x):
            return self.linear(x).relu()

    net = Net()
    weights = dict(net.named_parameters())
    inp = torch.randn(3, 5, requires_grad=True)

    def fwd_bwd(x, params):
        # Run the module functionally, reduce to a scalar, and backprop.
        loss = stateless.functional_call(net, params, x).sum()
        loss.backward()
        return list(params.values())

    traced = make_fx(fwd_bwd)(inp, weights)
    # fx may change the order of parameters in list, so using set() to compare
    self.assertEqual(set(traced(inp, weights)), set(fwd_bwd(inp, weights)))
100+
def test_make_fx_model_fwd_bwd_wgtupdate(self, device):
    """Trace forward, backward, and an SGD-style weight update with make_fx
    and verify the traced graph matches eager execution."""
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(5, 5)

        def forward(self, x):
            return self.linear(x).relu()

    net = Net()

    def step(args, params, buffers):
        # Normalize a single (non-iterable) argument into a list.
        if not isinstance(args, Iterable):
            args = [args]
        state = {**params, **buffers}
        loss = stateless.functional_call(net, state, args).sum()
        loss.backward()
        # One plain gradient-descent step with a fixed 1e-4 learning rate.
        return [p - 1e-4 * p.grad for p in params.values()]

    inp = torch.randn(3, 5, requires_grad=True)
    weights = dict(net.named_parameters())
    bufs = dict(net.named_buffers())
    traced = make_fx(step)(inp, weights, bufs)
    # fx may change the order of parameters in list, so using set() to compare
    self.assertEqual(set(traced(inp, weights, bufs)), set(step(inp, weights, bufs)))
77126
def test_scalar_device(self, device):
78127
def f(a, b):
79128
return a + b

0 commit comments

Comments
 (0)