File tree 1 file changed +21
-0
lines changed 1 file changed +21
-0
lines changed Original file line number Diff line number Diff line change 11
11
import unittest
12
12
import warnings
13
13
import itertools
14
+ import torch .nn .utils ._stateless as stateless
14
15
from functools import partial
15
16
from torch .testing ._internal .common_device_type import instantiate_device_type_tests
16
17
from functorch import (
@@ -74,6 +75,26 @@ def f(x):
74
75
new_inp = torch .randn (3 )
75
76
self .assertEqual (fx_f (new_inp ), f (new_inp ))
76
77
78
+ def test_make_fx_model_train (self , device ):
79
+ class Foo (nn .Module ):
80
+ def __init__ (self ):
81
+ super ().__init__ ()
82
+ self .linear = nn .Linear (5 , 5 )
83
+
84
+ def forward (self , x ):
85
+ return self .linear (x ).relu ()
86
+ mod = Foo ()
87
+ optim = torch .optim .SGD (mod .parameters (), lr = 1e-4 )
88
+
89
+ def f (x , params ):
90
+ out = stateless .functional_call (mod , params , x ).sum ()
91
+ out .backward ()
92
+ return list (params .values ())
93
+ inp = torch .randn (3 , 5 , requires_grad = True )
94
+ params = dict (mod .named_parameters ())
95
+ fx_f = make_fx (f )(inp , params )
96
+ self .assertEqual (fx_f (inp , params ), f (inp , params ))
97
+
77
98
def test_scalar_device (self , device ):
78
99
def f (a , b ):
79
100
return a + b
You can’t perform that action at this time.
0 commit comments