|
@@ -11,6 +11,8 @@
 import unittest
 import warnings
 import itertools
+import torch.nn.utils._stateless as stateless
+from collections.abc import Iterable
 from functools import partial
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from functorch import (
@@ -74,6 +76,83 @@ def f(x):
         new_inp = torch.randn(3)
         self.assertEqual(fx_f(new_inp), f(new_inp))
 
+    def test_make_fx_model_fwd_bwd(self, device):
+        class Foo(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x).relu()
+
+        model = Foo()
+
+        def f(x, params):
+            out = stateless.functional_call(model, params, x).sum()
+            out.backward()
+            return list(params.values())
+        input = torch.randn(3, 5, requires_grad=True)
+        params = dict(model.named_parameters())
+        fx_f = make_fx(f)(input, params)
+        # TODO: what assert statement should we add here?
+        assert fx_f(input, params) is not None
+
+    def test_make_fx_model_fwd_bwd_wgtupdate(self, device):
+        class Foo(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x).relu()
+
+        model = Foo()
+
+        def f(args, params, buffers):
+            if not isinstance(args, Iterable):
+                args = [args]
+            params_and_buffers = {**params, **buffers}
+            out = stateless.functional_call(model, params_and_buffers, args)
+            out.sum().backward()
+            return [p - 1e-4 * p.grad for p in params.values()]
+
+        input = torch.randn(3, 5, requires_grad=True)
+        params = dict(model.named_parameters())
+        buffers = dict(model.named_buffers())
+        fx_f = make_fx(f)(input, params, buffers)
+        # TODO: what assert statement should we add here?
+        assert fx_f(input, params, buffers) is not None
+
+    def test_make_fx_model_train_with_optim(self, device):
+        class Foo(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x).relu()
+
+        model = Foo()
+        optim = torch.optim.SGD(model.parameters(), lr=1e-4)
+
+        def f(args, params, buffers):
+            if not isinstance(args, Iterable):
+                args = [args]
+            params_and_buffers = {**params, **buffers}
+            out = stateless.functional_call(model, params_and_buffers, args)
+            out.sum().backward()
+            optim.step()
+
+            # TODO: this causes the graph to show an output with many incoming edges. Should we try `return None` or simply not return?
+            return list(params.values())
+
+        input = torch.randn(3, 5, requires_grad=True)
+        params = dict(model.named_parameters())
+        buffers = dict(model.named_buffers())
+        fx_f = make_fx(f)(input, params, buffers)
+        # TODO: what assert statement should we add here?
+        assert fx_f(input, params, buffers) is not None
+
     def test_scalar_device(self, device):
         def f(a, b):
             return a + b
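
For readers unfamiliar with the API these tests lean on: `torch.nn.utils._stateless.functional_call` (the underscore-private spelling used in this commit; later PyTorch releases expose it publicly as `torch.nn.utils.stateless`) runs a module with caller-supplied parameters and buffers in place of the ones registered on the module, which is what lets `make_fx` see the parameters as ordinary graph inputs. A minimal sketch of the behavior, assuming the public spelling:

```python
import torch
import torch.nn as nn
from torch.nn.utils import stateless  # public spelling; this commit uses the private _stateless

model = nn.Linear(5, 5)
x = torch.randn(3, 5)

# Run the module with replacement parameters: zeros stand in for the
# registered weight and bias for the duration of this one call.
zeros = {name: torch.zeros_like(p) for name, p in model.named_parameters()}
out = stateless.functional_call(model, zeros, x)

# A Linear layer with zero weight and zero bias maps every input to zero...
assert torch.allclose(out, torch.zeros(3, 5))
# ...while the module's own (randomly initialized) parameters are untouched.
assert model.weight.abs().sum() > 0
```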
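Relatedly, the hand-written update in `test_make_fx_model_fwd_bwd_wgtupdate`, `p - 1e-4 * p.grad`, is exactly one step of plain SGD, i.e. the same state `optim.step()` reaches in `test_make_fx_model_train_with_optim`. A standalone sanity check of that equivalence (a sketch under default SGD settings, no momentum or weight decay; not part of this commit):

```python
import torch
import torch.nn as nn

model = nn.Linear(5, 5)
lr = 1e-4

# Populate .grad on every parameter.
model(torch.randn(3, 5)).sum().backward()

# Functional form of the update (new tensors, parameters left alone).
expected = {name: p - lr * p.grad for name, p in model.named_parameters()}

# In-place form via the optimizer.
torch.optim.SGD(model.parameters(), lr=lr).step()

for name, p in model.named_parameters():
    assert torch.allclose(p, expected[name])
```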
|
|