Skip to content

Commit c799599

Browse files
Moved function out of run_all_in_graph_and_eager_mode. (#1332)
1 parent c204b71 commit c799599

File tree

1 file changed

+26
-27
lines changed

1 file changed

+26
-27
lines changed

tensorflow_addons/optimizers/stochastic_weight_averaging_test.py

+26-27
Original file line numberDiff line numberDiff line change
@@ -70,33 +70,6 @@ def test_averaging(self):
7070
self.assertAllClose(var_0.read_value(), [0.8, 0.8])
7171
self.assertAllClose(var_1.read_value(), [1.8, 1.8])
7272

73-
def test_fit_simple_linear_model(self):
74-
seed = 0x2019
75-
np.random.seed(seed)
76-
tf.random.set_seed(seed)
77-
num_examples = 100000
78-
x = np.random.standard_normal((num_examples, 3))
79-
w = np.random.standard_normal((3, 1))
80-
y = np.dot(x, w) + np.random.standard_normal((num_examples, 1)) * 1e-4
81-
82-
model = tf.keras.models.Sequential()
83-
model.add(tf.keras.layers.Dense(input_shape=(3,), units=1))
84-
# using num_examples - 1 since steps starts from 0.
85-
optimizer = SWA(
86-
"sgd", start_averaging=num_examples // 32 - 1, average_period=100
87-
)
88-
model.compile(optimizer, loss="mse")
89-
model.fit(x, y, epochs=2)
90-
optimizer.assign_average_vars(model.variables)
91-
92-
x = np.random.standard_normal((100, 3))
93-
y = np.dot(x, w)
94-
95-
predicted = model.predict(x)
96-
97-
max_abs_diff = np.max(np.abs(predicted - y))
98-
self.assertLess(max_abs_diff, 1e-3)
99-
10073
def test_optimizer_failure(self):
10174
with self.assertRaises(TypeError):
10275
_ = SWA(None, average_period=10)
@@ -128,5 +101,31 @@ def test_assign_batchnorm(self):
128101
fit_bn(model, x, y)
129102

130103

104+
def test_fit_simple_linear_model():
105+
seed = 0x2019
106+
np.random.seed(seed)
107+
tf.random.set_seed(seed)
108+
num_examples = 100000
109+
x = np.random.standard_normal((num_examples, 3))
110+
w = np.random.standard_normal((3, 1))
111+
y = np.dot(x, w) + np.random.standard_normal((num_examples, 1)) * 1e-4
112+
113+
model = tf.keras.models.Sequential()
114+
model.add(tf.keras.layers.Dense(input_shape=(3,), units=1))
115+
# using num_examples - 1 since steps starts from 0.
116+
optimizer = SWA("sgd", start_averaging=num_examples // 32 - 1, average_period=100)
117+
model.compile(optimizer, loss="mse")
118+
model.fit(x, y, epochs=2)
119+
optimizer.assign_average_vars(model.variables)
120+
121+
x = np.random.standard_normal((100, 3))
122+
y = np.dot(x, w)
123+
124+
predicted = model.predict(x)
125+
126+
max_abs_diff = np.max(np.abs(predicted - y))
127+
assert max_abs_diff < 1e-3
128+
129+
131130
if __name__ == "__main__":
132131
sys.exit(pytest.main([__file__]))

0 commit comments

Comments
 (0)