
Commit d9629a0

Add ability to swap weights to MovingAverage.
This patch makes it easier to swap the model weights with the MovingAverage weights before evaluation and to swap them back afterwards.
1 parent: 90c190b
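In outline, the workflow the commit message describes mirrors the new unit test: create the shadow variables once, then swap before and after evaluation. A rough sketch of that usage (`model`, `val_dataset`, and `evaluate` are hypothetical placeholders, not part of this commit; as in the test, the swap calls run under a `tf.distribute.Strategy` scope):

```python
import tensorflow as tf
from tensorflow_addons.optimizers import MovingAverage

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # `model` is a hypothetical tf.keras.Model built inside the scope.
    opt = MovingAverage(tf.keras.optimizers.SGD(lr=0.01), average_decay=0.99)

# ... train `model` with `opt` as usual ...

with strategy.scope():
    opt.shadow_copy(model.trainable_variables)  # register the "average" slots
    opt.swap_weights()  # model variables now hold the averaged weights

evaluate(model, val_dataset)  # hypothetical evaluation step

with strategy.scope():
    opt.swap_weights()  # swapping again restores the raw training weights
```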

2 files changed: 91 additions, 0 deletions

tensorflow_addons/optimizers/moving_average.py

Lines changed: 58 additions & 0 deletions
```diff
@@ -124,3 +124,61 @@ def _create_slots(self, var_list):
         )  # pylint: disable=protected-access
         for var in var_list:
             self.add_slot(var, "average", var.read_value())
+
+    def shadow_copy(self, model_weights):
+        """Creates shadow variables for the given model weights."""
+        for var in model_weights:
+            self.add_slot(var, "average", initializer="zeros")
+        self._average_weights = [self.get_slot(var, "average") for var in model_weights]
+        self._model_weights = model_weights
+
+    @property
+    def has_shadow_copy(self):
+        """Whether this optimizer has created shadow variables."""
+        return self._model_weights is not None
+
+    def swap_weights(self):
+        """Swap the average and moving weights.
+
+        This is a convenience method to allow one to evaluate the averaged weights
+        at test time. It loads the weights stored in `self._average_weights` into
+        the model, keeping a copy of the original model weights. Swapping twice
+        will return the original weights.
+        """
+        if tf.distribute.in_cross_replica_context():
+            strategy = tf.distribute.get_strategy()
+            return strategy.run(self._swap_weights, args=())
+        else:
+            raise ValueError(
+                "Swapping weights must occur under a tf.distribute.Strategy"
+            )
+
+    @tf.function
+    def _swap_weights(self):
+        def fn_0(a, b):
+            a.assign_add(b)
+            return a
+
+        def fn_1(b, a):
+            b.assign(a - b)
+            return b
+
+        def fn_2(a, b):
+            a.assign_sub(b)
+            return a
+
+        def swap(strategy, a, b):
+            """Swap `a` and `b` and mirror to all devices."""
+            for a_element, b_element in zip(a, b):
+                strategy.extended.update(
+                    a_element, fn_0, args=(b_element,)
+                )  # a = a + b
+                strategy.extended.update(
+                    b_element, fn_1, args=(a_element,)
+                )  # b = a - b
+                strategy.extended.update(
+                    a_element, fn_2, args=(b_element,)
+                )  # a = a - b
+
+        ctx = tf.distribute.get_replica_context()
+        return ctx.merge_call(swap, args=(self._average_weights, self._model_weights))
```
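The three inner functions implement an in-place arithmetic swap, the additive analogue of the XOR-swap trick, so no temporary tensor is allocated per variable: after `a += b`, computing `a - b` recovers the original `a`, and subtracting that from the new `a` recovers the original `b`. A minimal single-variable sketch of the same identity, independent of any distribution strategy:

```python
import tensorflow as tf

a = tf.Variable(1.0)  # stands in for one "average" slot
b = tf.Variable(4.0)  # stands in for the matching model weight

a.assign_add(b)  # a = a + b           -> a == 5.0
b.assign(a - b)  # b = (a + b) - b     -> b == 1.0, the original a
a.assign_sub(b)  # a = (a + b) - a_old -> a == 4.0, the original b

print(a.numpy(), b.numpy())  # 4.0 1.0
```

One trade-off of this trick is that the intermediate sum can lose floating-point precision when the two values differ greatly in magnitude; in exchange, no temporary copy of the weights is needed.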

tensorflow_addons/optimizers/tests/moving_average_test.py

Lines changed: 33 additions & 0 deletions
```diff
@@ -213,3 +213,36 @@ def test_dynamic_decay():
 
     ema_var0 = opt.get_slot(var0, "average")
     np.testing.assert_allclose(ema_var0.read_value(), [0.64, 1.64])
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+@pytest.mark.with_device([tf.distribute.MirroredStrategy])
+def test_swap_weights(device):
+    with device.scope():
+        var = tf.Variable([1.0, 2.0])
+        grads = tf.constant([0.1, 0.1])
+
+        opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5)
+
+    @tf.function
+    def apply_gradients():
+        opt.apply_gradients([(grads, var)])
+
+    device.run(apply_gradients)
+
+    np.testing.assert_allclose(var.read_value(), [0.8, 1.8])
+    ema_var = opt.get_slot(var, "average")
+    np.testing.assert_allclose(ema_var.read_value(), [0.85, 1.85])
+
+    with device.scope():
+        opt.shadow_copy([var])
+        opt.swap_weights()
+
+    np.testing.assert_allclose(ema_var.read_value(), [0.8, 1.8])
+    np.testing.assert_allclose(var.read_value(), [0.85, 1.85])
+
+    with device.scope():
+        opt.swap_weights()
+
+    np.testing.assert_allclose(var.read_value(), [0.8, 1.8])
+    np.testing.assert_allclose(ema_var.read_value(), [0.85, 1.85])
```
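`swap_weights` presumes `shadow_copy` has already been called, which is what the new `has_shadow_copy` property lets callers check. A hypothetical helper (not part of this commit) that guards an evaluation pass and restores the weights even if evaluation raises:

```python
def evaluate_with_averaged_weights(opt, model, evaluate_fn):
    """Hypothetical helper: run `evaluate_fn` with the EMA weights swapped in."""
    if not opt.has_shadow_copy:
        opt.shadow_copy(model.trainable_variables)
    opt.swap_weights()  # model now holds the averaged weights
    try:
        return evaluate_fn(model)
    finally:
        opt.swap_weights()  # always restore the raw training weights
```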
