
Commit 0894c49

mostafaelhoushi authored and pytorchmergebot committed
Add test_make_fx_model_train example (#980) (pytorch#82011)
Summary: Pull Request resolved: pytorch/functorch#980
Test Plan: CI should pass
Differential Revision: D38078694
Pulled By: mostafaelhoushi
Pull Request resolved: pytorch#82011
Approved by: https://github.com/Chillee
1 parent 02550bc commit 0894c49
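For context, here is a minimal sketch (not part of the commit) of the pattern the new tests exercise: stateless.functional_call runs a module against an explicitly supplied parameter dict, which makes the whole training step a pure function of its inputs, so make_fx can trace the forward pass and the backward pass triggered by .backward() into a single torch.fx graph. The train_step name is illustrative; import paths reflect this commit's era, and torch.nn.utils._stateless is a private module that was later superseded by the public torch.func.functional_call.

import torch
import torch.nn.utils._stateless as stateless
from torch.fx.experimental.proxy_tensor import make_fx

model = torch.nn.Linear(5, 5)

def train_step(x, params):
    # functional_call substitutes `params` for the module's own parameters,
    # so tracing sees plain tensor inputs rather than hidden module state
    loss = stateless.functional_call(model, params, x).sum()
    loss.backward()  # backward ops are captured into the trace as well
    return [p.grad for p in params.values()]

params = dict(model.named_parameters())
gm = make_fx(train_step)(torch.randn(3, 5), params)
print(gm.graph)  # an aten-level graph containing forward and backward ops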

File tree

1 file changed (+68, -0 lines)


test/test_proxy_tensor.py

Lines changed: 68 additions & 0 deletions
@@ -4,6 +4,8 @@
 import torch
 import unittest
 import warnings
+import torch.nn.utils._stateless as stateless
+from collections.abc import Iterable
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_methods_invocations import DecorateInfo
 from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
@@ -274,6 +276,72 @@ def fn(x):
 
         self.assertEqual(fx_module(x), decomposed_module(x))
 
+    def test_make_fx_model_fwd_bwd(self, device):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x).relu()
+
+        model = Foo()
+
+        def f(x, params):
+            out = stateless.functional_call(model, params, x).sum()
+            out.backward()
+            return list(params.values())
+
+        input = torch.randn(3, 5, requires_grad=True)
+        params = dict(model.named_parameters())
+        fx_f = make_fx(f)(input, params)
+        # fx may change the order of parameters in the returned list,
+        # so compare each traced output against both eager outputs
+        self.assertTrue(
+            torch.allclose(fx_f(input, params)[0], f(input, params)[0])
+            or
+            torch.allclose(fx_f(input, params)[0], f(input, params)[1])
+        )
+        self.assertTrue(
+            torch.allclose(fx_f(input, params)[1], f(input, params)[0])
+            or
+            torch.allclose(fx_f(input, params)[1], f(input, params)[1])
+        )
+
+    def test_make_fx_model_fwd_bwd_wgtupdate(self, device):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x).relu()
+
+        model = Foo()
+
+        def f(args, params, buffers):
+            if not isinstance(args, Iterable):
+                args = [args]
+            params_and_buffers = {**params, **buffers}
+            out = stateless.functional_call(model, params_and_buffers, args)
+            out.sum().backward()
+            return [p - 1e-4 * p.grad for p in params.values()]
+
+        input = torch.randn(3, 5, requires_grad=True)
+        params = dict(model.named_parameters())
+        buffers = dict(model.named_buffers())
+        fx_f = make_fx(f)(input, params, buffers)
+        # fx may change the order of parameters in the returned list,
+        # so compare each traced output against both eager outputs;
+        # results also differ numerically, so atol is relaxed from 1e-08 to 1e-03
+        self.assertTrue(
+            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03)
+            or
+            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03)
+        )
+        self.assertTrue(
+            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03)
+            or
+            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
+        )
+
 # TODO: Need to test the guards themselves specifically as well
 @skipIfNoSympy
 class TestSymbolicTracing(TestCase):
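As a follow-on sketch (again not from the commit; names are illustrative), the second test's traced function bakes the SGD-style update p - 1e-4 * p.grad into the graph itself, so one training step can be replayed purely through the resulting GraphModule:

import torch
import torch.nn.utils._stateless as stateless
from torch.fx.experimental.proxy_tensor import make_fx

model = torch.nn.Linear(5, 5)

def sgd_step(x, params):
    out = stateless.functional_call(model, params, x)
    out.sum().backward()
    # the gradient step is traced as ordinary arithmetic on the graph
    return [p - 1e-4 * p.grad for p in params.values()]

params = dict(model.named_parameters())
x = torch.randn(3, 5)
step = make_fx(sgd_step)(x, params)

# Replaying the step: the outputs are the updated parameter tensors
# (possibly reordered by fx, as the tests' comments note).
new_params = step(x, params)
print([p.shape for p in new_params])

# The captured graph contains forward, backward, and update ops:
for node in step.graph.nodes:
    print(node.op, node.target)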
