 import torch
 import unittest
 import warnings
+import torch.nn.utils._stateless as stateless
+from collections.abc import Iterable
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_methods_invocations import DecorateInfo
 from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
@@ -274,6 +276,72 @@ def fn(x):

         self.assertEqual(fx_module(x), decomposed_module(x))

+    def test_make_fx_model_fwd_bwd(self, device):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x).relu()
+
+        model = Foo()
+
+        def f(x, params):
+            out = stateless.functional_call(model, params, x).sum()
+            out.backward()
+            return list(params.values())
+        input = torch.randn(3, 5, requires_grad=True)
+        params = dict(model.named_parameters())
+        fx_f = make_fx(f)(input, params)
+        # fx may reorder the parameters in the returned list, so compare each traced output against both eager outputs
+        self.assertTrue(
+            torch.allclose(fx_f(input, params)[0], f(input, params)[0])
+            or
+            torch.allclose(fx_f(input, params)[0], f(input, params)[1])
+        )
+        self.assertTrue(
+            torch.allclose(fx_f(input, params)[1], f(input, params)[0])
+            or
+            torch.allclose(fx_f(input, params)[1], f(input, params)[1])
+        )
+
+    def test_make_fx_model_fwd_bwd_wgtupdate(self, device):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x).relu()
+
+        model = Foo()
+
+        def f(args, params, buffers):
+            if not isinstance(args, Iterable):
+                args = [args]
+            params_and_buffers = {**params, **buffers}
+            out = stateless.functional_call(model, params_and_buffers, args)
+            out.sum().backward()
+            return [p - 1e-4 * p.grad for p in params.values()]
+
+        input = torch.randn(3, 5, requires_grad=True)
+        params = dict(model.named_parameters())
+        buffers = dict(model.named_buffers())
+        fx_f = make_fx(f)(input, params, buffers)
+        # fx may reorder the parameters in the returned list, so compare each traced output against both eager outputs;
+        # the results also differ slightly numerically, so atol is relaxed from 1e-08 to 1e-03
+        self.assertTrue(
+            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03)
+            or
+            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03)
+        )
+        self.assertTrue(
+            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03)
+            or
+            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
+        )
+
 # TODO: Need to test the guards themselves specifically as well
 @skipIfNoSympy
 class TestSymbolicTracing(TestCase):