@@ -70,33 +70,6 @@ def test_averaging(self):
70
70
self .assertAllClose (var_0 .read_value (), [0.8 , 0.8 ])
71
71
self .assertAllClose (var_1 .read_value (), [1.8 , 1.8 ])
72
72
73
- def test_fit_simple_linear_model (self ):
74
- seed = 0x2019
75
- np .random .seed (seed )
76
- tf .random .set_seed (seed )
77
- num_examples = 100000
78
- x = np .random .standard_normal ((num_examples , 3 ))
79
- w = np .random .standard_normal ((3 , 1 ))
80
- y = np .dot (x , w ) + np .random .standard_normal ((num_examples , 1 )) * 1e-4
81
-
82
- model = tf .keras .models .Sequential ()
83
- model .add (tf .keras .layers .Dense (input_shape = (3 ,), units = 1 ))
84
- # using num_examples - 1 since steps starts from 0.
85
- optimizer = SWA (
86
- "sgd" , start_averaging = num_examples // 32 - 1 , average_period = 100
87
- )
88
- model .compile (optimizer , loss = "mse" )
89
- model .fit (x , y , epochs = 2 )
90
- optimizer .assign_average_vars (model .variables )
91
-
92
- x = np .random .standard_normal ((100 , 3 ))
93
- y = np .dot (x , w )
94
-
95
- predicted = model .predict (x )
96
-
97
- max_abs_diff = np .max (np .abs (predicted - y ))
98
- self .assertLess (max_abs_diff , 1e-3 )
99
-
100
73
def test_optimizer_failure (self ):
101
74
with self .assertRaises (TypeError ):
102
75
_ = SWA (None , average_period = 10 )
@@ -128,5 +101,31 @@ def test_assign_batchnorm(self):
128
101
fit_bn (model , x , y )
129
102
130
103
104
+ def test_fit_simple_linear_model ():
105
+ seed = 0x2019
106
+ np .random .seed (seed )
107
+ tf .random .set_seed (seed )
108
+ num_examples = 100000
109
+ x = np .random .standard_normal ((num_examples , 3 ))
110
+ w = np .random .standard_normal ((3 , 1 ))
111
+ y = np .dot (x , w ) + np .random .standard_normal ((num_examples , 1 )) * 1e-4
112
+
113
+ model = tf .keras .models .Sequential ()
114
+ model .add (tf .keras .layers .Dense (input_shape = (3 ,), units = 1 ))
115
+ # using num_examples - 1 since steps starts from 0.
116
+ optimizer = SWA ("sgd" , start_averaging = num_examples // 32 - 1 , average_period = 100 )
117
+ model .compile (optimizer , loss = "mse" )
118
+ model .fit (x , y , epochs = 2 )
119
+ optimizer .assign_average_vars (model .variables )
120
+
121
+ x = np .random .standard_normal ((100 , 3 ))
122
+ y = np .dot (x , w )
123
+
124
+ predicted = model .predict (x )
125
+
126
+ max_abs_diff = np .max (np .abs (predicted - y ))
127
+ assert max_abs_diff < 1e-3
128
+
129
+
131
130
if __name__ == "__main__" :
132
131
sys .exit (pytest .main ([__file__ ]))
0 commit comments