Skip to content

Commit addaa4a

Browse files
committed
Add ability to swap weights to MovingAverage.
This patch makes it easier to swap the model weights and the MovingAverage weights before eval and swap them back after eval.
1 parent 0264703 commit addaa4a

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

tensorflow_addons/optimizers/moving_average.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,57 @@ def _create_slots(self, var_list):
124124
) # pylint: disable=protected-access
125125
for var in var_list:
126126
self.add_slot(var, "average", var.read_value())
127+
128+
def shadow_copy(self, model_weights):
    """Create zero-initialized "average" slot variables for `model_weights`.

    After this call, `swap_weights` can exchange each model variable with
    its shadow (moving-average) counterpart.

    Args:
        model_weights: iterable of the model's `tf.Variable`s to shadow.
    """
    # NOTE(review): slots start at zero rather than the variables' current
    # values — presumably the averaging update fills them before any swap;
    # confirm against the averaging step.
    for weight in model_weights:
        self.add_slot(weight, "average", initializer="zeros")
    self._model_weights = model_weights
    self._average_weights = [
        self.get_slot(weight, "average") for weight in model_weights
    ]
134+
135+
@property
def has_shadow_copy(self):
    """Whether `shadow_copy` has been called on this optimizer."""
    shadowed = self._model_weights
    return shadowed is not None
139+
140+
def swap_weights(self):
    """Swap the average and moving weights.

    This is a convenience method to allow one to evaluate the averaged
    weights at test time. Loads the weights stored in the "average" slots
    into the model, keeping a copy of the original model weights. Swapping
    twice returns the original weights.
    """
    # The swap must be driven from the cross-replica context so the
    # per-variable updates can be mirrored to every device by the
    # active distribution strategy.
    if not tf.distribute.in_cross_replica_context():
        raise ValueError(
            "Swapping weights must occur under a " "tf.distribute.Strategy"
        )
    strategy = tf.distribute.get_strategy()
    return strategy.run(self._swap_weights, args=())
155+
156+
@tf.function
def _swap_weights(self):
    """Element-wise in-place swap of the average and model variables.

    Uses the add/subtract swap trick (a += b; b = a - b; a -= b) so no
    temporary variables are needed; each assignment is mirrored to all
    devices via `strategy.extended.update` inside a `merge_call`.
    """
    def fn_0(a, b):
        # a := a + b
        a.assign_add(b)
        return a

    def fn_1(b, a):
        # b := a - b  (b now holds a's original value)
        b.assign(a - b)
        return b

    def fn_2(a, b):
        # a := a - b  (a now holds b's original value)
        a.assign_sub(b)
        return a

    def swap(strategy, a_and_b):
        """Swap `a` and `b` and mirror to all devices."""
        for a, b in a_and_b:
            strategy.extended.update(a, fn_0, args=(b,))  # a = a + b
            strategy.extended.update(b, fn_1, args=(a,))  # b = a - b
            strategy.extended.update(a, fn_2, args=(b,))  # a = a - b

    # Escape the per-replica context once so the updates above run in the
    # cross-replica context and apply consistently on every replica.
    # NOTE(review): `zip(...)` is a one-shot iterator captured at trace
    # time — confirm this is safe if this tf.function is ever retraced.
    ctx = tf.distribute.get_replica_context()
    return ctx.merge_call(
        swap, args=(zip(self._average_weights, self._model_weights),)
    )

tensorflow_addons/optimizers/tests/moving_average_test.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,48 @@ def test_dynamic_decay(sequential_update):
247247
ema_var0 = opt.get_slot(var0, "average")
248248
if sequential_update:
249249
np.testing.assert_allclose(ema_var0.read_value(), [0.64, 1.64])
250+
251+
252+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("sequential_update", [True, False])
def test_swap_weights(sequential_update):
    """swap_weights exchanges model/average weights; a second swap restores them.

    BUG FIX: the body previously re-looped `for sequential_update in
    [True, False]:`, shadowing the parametrized argument and running every
    case twice — the `@pytest.mark.parametrize` decorator already supplies
    both values, so the inner loop is removed.
    """
    strategy = tf.distribute.OneDeviceStrategy("device:CPU:0")
    with strategy.scope():
        var = tf.Variable([1.0, 2.0])

        opt = MovingAverage(
            tf.keras.optimizers.SGD(lr=2.0),
            sequential_update=sequential_update,
            average_decay=0.5,
        )

    with strategy.scope():
        grads = tf.constant([0.1, 0.1])

        def apply_gradients():
            opt.apply_gradients([(grads, var)])

        strategy.run(apply_gradients)

        # One SGD step: var = [1, 2] - 2.0 * 0.1 = [0.8, 1.8].
        np.testing.assert_allclose(var.read_value(), [0.8, 1.8])
        ema_var = opt.get_slot(var, "average")
        if sequential_update:
            # EMA with decay 0.5: 0.5 * old + 0.5 * new.
            np.testing.assert_allclose(ema_var.read_value(), [0.9, 1.9])

        opt.shadow_copy([var])

    with strategy.scope():
        opt.swap_weights()

    # After one swap the slot holds the model weights and vice versa.
    np.testing.assert_allclose(ema_var.read_value(), [0.8, 1.8])
    if sequential_update:
        np.testing.assert_allclose(var.read_value(), [0.9, 1.9])

    with strategy.scope():
        opt.swap_weights()

    # Swapping twice restores the original assignment.
    np.testing.assert_allclose(var.read_value(), [0.8, 1.8])
    if sequential_update:
        np.testing.assert_allclose(ema_var.read_value(), [0.9, 1.9])

0 commit comments

Comments
 (0)