|
11 | 11 | import unittest
|
12 | 12 | import warnings
|
13 | 13 | import itertools
|
| 14 | +import torch.nn.utils._stateless as stateless |
| 15 | +from collections.abc import Iterable |
14 | 16 | from functools import partial
|
15 | 17 | from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
16 | 18 | from functorch import (
|
@@ -74,6 +76,53 @@ def f(x):
|
74 | 76 | new_inp = torch.randn(3)
|
75 | 77 | self.assertEqual(fx_f(new_inp), f(new_inp))
|
76 | 78 |
|
| 79 | + def test_make_fx_model_fwd_bwd(self, device): |
| 80 | + class Foo(nn.Module): |
| 81 | + def __init__(self): |
| 82 | + super().__init__() |
| 83 | + self.linear = nn.Linear(5, 5) |
| 84 | + |
| 85 | + def forward(self, x): |
| 86 | + return self.linear(x).relu() |
| 87 | + |
| 88 | + model = Foo() |
| 89 | + |
| 90 | + def f(x, params): |
| 91 | + out = stateless.functional_call(model, params, x).sum() |
| 92 | + out.backward() |
| 93 | + return list(params.values()) |
| 94 | + input = torch.randn(3, 5, requires_grad=True) |
| 95 | + params = dict(model.named_parameters()) |
| 96 | + fx_f = make_fx(f)(input, params) |
| 97 | + # fx may change the order of parameters in list, so using set() to compare |
| 98 | + self.assertEqual(set(fx_f(input, params)), set(f(input, params))) |
| 99 | + |
| 100 | + def test_make_fx_model_fwd_bwd_wgtupdate(self, device): |
| 101 | + class Foo(nn.Module): |
| 102 | + def __init__(self): |
| 103 | + super().__init__() |
| 104 | + self.linear = nn.Linear(5, 5) |
| 105 | + |
| 106 | + def forward(self, x): |
| 107 | + return self.linear(x).relu() |
| 108 | + |
| 109 | + model = Foo() |
| 110 | + |
| 111 | + def f(args, params, buffers): |
| 112 | + if not isinstance(args, Iterable): |
| 113 | + args = [args] |
| 114 | + params_and_buffers = {**params, **buffers} |
| 115 | + out = stateless.functional_call(model, params_and_buffers, args) |
| 116 | + out.sum().backward() |
| 117 | + return [p - 1e-4 * p.grad for p in params.values()] |
| 118 | + |
| 119 | + input = torch.randn(3, 5, requires_grad=True) |
| 120 | + params = dict(model.named_parameters()) |
| 121 | + buffers = dict(model.named_buffers()) |
| 122 | + fx_f = make_fx(f)(input, params, buffers) |
| 123 | + # fx may change the order of parameters in list, so using set() to compare |
| 124 | + self.assertEqual(set(fx_f(input, params, buffers)), set(f(input, params, buffers))) |
| 125 | + |
77 | 126 | def test_scalar_device(self, device):
|
78 | 127 | def f(a, b):
|
79 | 128 | return a + b
|
|
0 commit comments