Skip to content

Commit 0c10cfc

Browse files
Add test_make_fx_model_train example (#82011)
Summary: X-link: pytorch/pytorch#82011. Pull Request resolved: #980. Test Plan: CI should pass. Reviewed By: benoitsteiner. Differential Revision: D38078694. Pulled By: mostafaelhoushi. fbshipit-source-id: bce8ccebce3b604dc5b98168b783aa6e21938f95
1 parent e8a68f4 commit 0c10cfc

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed

test/test_pythonkey.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import unittest
1212
import warnings
1313
import itertools
14+
import torch.nn.utils._stateless as stateless
15+
from collections.abc import Iterable
1416
from functools import partial
1517
from torch.testing._internal.common_device_type import instantiate_device_type_tests
1618
from functorch import (
@@ -74,6 +76,83 @@ def f(x):
7476
new_inp = torch.randn(3)
7577
self.assertEqual(fx_f(new_inp), f(new_inp))
7678

79+
def test_make_fx_model_fwd_bwd(self, device):
    """Trace a forward+backward pass of a small model through make_fx.

    Builds a one-layer module, runs it via `stateless.functional_call`
    inside the traced function so that `backward()` is captured in the
    graph, and checks the traced graph still executes.
    """
    class Foo(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(5, 5)

        def forward(self, x):
            return self.linear(x).relu()

    model = Foo()

    def f(x, params):
        # Functional call keeps the parameters explicit so make_fx can
        # trace through them; backward() populates .grad on `params`.
        out = stateless.functional_call(model, params, x).sum()
        out.backward()
        return list(params.values())

    # Named `inp` rather than `input` to avoid shadowing the builtin.
    inp = torch.randn(3, 5, requires_grad=True)
    params = dict(model.named_parameters())
    fx_f = make_fx(f)(inp, params)
    # TODO: strengthen this check (e.g. compare traced vs. eager output).
    # A bare `assert` would be stripped under -O; use the unittest helper
    # for consistency with the rest of this file.
    self.assertIsNotNone(fx_f(inp, params))
100+
def test_make_fx_model_fwd_bwd_wgtupdate(self, device):
    """Trace forward + backward + a manual SGD-style weight update.

    The traced function returns the updated parameter tensors
    (`p - lr * p.grad`), exercising make_fx over gradient reads.
    """
    class Foo(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(5, 5)

        def forward(self, x):
            return self.linear(x).relu()

    model = Foo()

    def f(args, params, buffers):
        # functional_call expects positional args as an iterable.
        if not isinstance(args, Iterable):
            args = [args]
        params_and_buffers = {**params, **buffers}
        out = stateless.functional_call(model, params_and_buffers, args)
        out.sum().backward()
        # Manual gradient-descent step with a fixed learning rate.
        return [p - 1e-4 * p.grad for p in params.values()]

    # Named `inp` rather than `input` to avoid shadowing the builtin.
    inp = torch.randn(3, 5, requires_grad=True)
    params = dict(model.named_parameters())
    buffers = dict(model.named_buffers())
    fx_f = make_fx(f)(inp, params, buffers)
    # TODO: strengthen this check (e.g. compare traced vs. eager output).
    # Use the unittest helper instead of a bare `assert` (stripped under -O).
    self.assertIsNotNone(fx_f(inp, params, buffers))
125+
126+
def test_make_fx_model_train_with_optim(self, device):
    """Trace a training step that drives a real torch.optim optimizer.

    Unlike the manual-update variant, this calls `optim.step()` inside
    the traced function, so the optimizer's in-place parameter updates
    are part of what make_fx captures.
    """
    class Foo(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(5, 5)

        def forward(self, x):
            return self.linear(x).relu()

    model = Foo()
    optim = torch.optim.SGD(model.parameters(), lr=1e-4)

    def f(args, params, buffers):
        # functional_call expects positional args as an iterable.
        if not isinstance(args, Iterable):
            args = [args]
        params_and_buffers = {**params, **buffers}
        out = stateless.functional_call(model, params_and_buffers, args)
        out.sum().backward()
        optim.step()

        # TODO: this causes graph to show an output with many incoming
        # edges. Shall we try `return None` or simply don't return?
        return list(params.values())

    # Named `inp` rather than `input` to avoid shadowing the builtin.
    inp = torch.randn(3, 5, requires_grad=True)
    params = dict(model.named_parameters())
    buffers = dict(model.named_buffers())
    fx_f = make_fx(f)(inp, params, buffers)
    # TODO: strengthen this check (e.g. compare traced vs. eager output).
    # Use the unittest helper instead of a bare `assert` (stripped under -O).
    self.assertIsNotNone(fx_f(inp, params, buffers))
155+
77156
def test_scalar_device(self, device):
78157
def f(a, b):
79158
return a + b

0 commit comments

Comments
 (0)