Skip to content

Add test_make_fx_model_train example #980

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions test/test_pythonkey.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import unittest
import warnings
import itertools
import torch.nn.utils._stateless as stateless
from collections.abc import Iterable
from functools import partial
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from functorch import (
Expand Down Expand Up @@ -74,6 +76,53 @@ def f(x):
new_inp = torch.randn(3)
self.assertEqual(fx_f(new_inp), f(new_inp))

def test_make_fx_model_fwd_bwd(self, device):
class Foo(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(5, 5)

def forward(self, x):
return self.linear(x).relu()

model = Foo()

def f(x, params):
out = stateless.functional_call(model, params, x).sum()
out.backward()
return list(params.values())
input = torch.randn(3, 5, requires_grad=True)
params = dict(model.named_parameters())
fx_f = make_fx(f)(input, params)
# fx may change the order of parameters in list, so using set() to compare
self.assertEqual(set(fx_f(input, params)), set(f(input, params)))

def test_make_fx_model_fwd_bwd_wgtupdate(self, device):
class Foo(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(5, 5)

def forward(self, x):
return self.linear(x).relu()

model = Foo()

def f(args, params, buffers):
if not isinstance(args, Iterable):
args = [args]
params_and_buffers = {**params, **buffers}
out = stateless.functional_call(model, params_and_buffers, args)
out.sum().backward()
return [p - 1e-4 * p.grad for p in params.values()]

input = torch.randn(3, 5, requires_grad=True)
params = dict(model.named_parameters())
buffers = dict(model.named_buffers())
fx_f = make_fx(f)(input, params, buffers)
# fx may change the order of parameters in list, so using set() to compare
self.assertEqual(set(fx_f(input, params, buffers)), set(f(input, params, buffers)))

def test_scalar_device(self, device):
def f(a, b):
return a + b
Expand Down