MovingAverage: add dynamic decay and swap weights #1726

Merged: 2 commits, Jun 18, 2020

83 changes: 82 additions & 1 deletion tensorflow_addons/optimizers/moving_average.py
@@ -48,6 +48,8 @@ def __init__(
sequential_update: bool = True,
average_decay: types.FloatTensorLike = 0.99,
num_updates: Optional[str] = None,
start_step: int = 0,
dynamic_decay: bool = False,
name: str = "MovingAverage",
**kwargs
):
@@ -64,6 +66,10 @@ def __init__(
of trained variables.
num_updates: Optional count of the number of updates applied to
variables.
start_step: int. The step at which to start the moving average.
dynamic_decay: bool. Whether to change the decay based on the number
of optimizer updates. Decay will start at 0.1 and gradually
increase up to `average_decay` after each optimizer update.
name: Optional name for the operations created when applying
gradients. Defaults to "MovingAverage".
**kwargs: keyword arguments. Allowed to be {`clipnorm`,
@@ -82,15 +88,32 @@ def __init__(
)

self._set_hyper("average_decay", average_decay)
self._start_step = start_step
self._dynamic_decay = dynamic_decay

@tf.function
def _get_decay(self, step: tf.Tensor):
average_decay = self._get_hyper("average_decay", tf.dtypes.float32)

step = tf.cast(step, tf.float32)
if step < self._start_step:
return tf.constant(0.0, tf.float32)
elif self._dynamic_decay:
step_count = step - self._start_step
return tf.minimum(average_decay, (1.0 + step_count) / (10.0 + step_count))
else:
return average_decay

def average_op(self, var, average_var):
decay = self._get_hyper("average_decay", tf.dtypes.float32)
decay = self._get_decay(self._optimizer.iterations)
return assign_moving_average(average_var, var, decay, False)

def get_config(self):
config = {
"average_decay": self._serialize_hyperparameter("average_decay"),
"num_updates": self._num_updates,
"start_step": self._start_step,
"dynamic_decay": self._dynamic_decay,
}
base_config = super().get_config()
return {**base_config, **config}
@@ -101,3 +124,61 @@ def _create_slots(self, var_list):
) # pylint: disable=protected-access
for var in var_list:
self.add_slot(var, "average", var.read_value())

def shadow_copy(self, model_weights):
"""Creates shadow variables for the given model weights."""
for var in model_weights:
self.add_slot(var, "average", initializer="zeros")
self._average_weights = [self.get_slot(var, "average") for var in model_weights]
self._model_weights = model_weights

@property
def has_shadow_copy(self):
"""Whether this optimizer has created shadow variables."""
return self._model_weights is not None

def swap_weights(self):
"""Swap the average and moving weights.

This is a convenience method to allow one to evaluate the averaged weights
at test time. Loads the weights stored in `self._average_weights` into the model,
keeping a copy of the original model weights. Swapping twice will return
the original weights.
"""
if tf.distribute.in_cross_replica_context():
strategy = tf.distribute.get_strategy()
return strategy.run(self._swap_weights, args=())
else:
raise ValueError(
"Swapping weights must occur under a " "tf.distribute.Strategy"
)

@tf.function
def _swap_weights(self):
def fn_0(a, b):
a.assign_add(b)
return a

def fn_1(b, a):
b.assign(a - b)
return b

def fn_2(a, b):
a.assign_sub(b)
return a

def swap(strategy, a, b):
Inline review comment (Member): Love this trick for swapping.

"""Swap `a` and `b` and mirror to all devices."""
for a_element, b_element in zip(a, b):
strategy.extended.update(
a_element, fn_0, args=(b_element,)
) # a = a + b
strategy.extended.update(
b_element, fn_1, args=(a_element,)
) # b = a - b
strategy.extended.update(
a_element, fn_2, args=(b_element,)
) # a = a - b

ctx = tf.distribute.get_replica_context()
return ctx.merge_call(swap, args=(self._average_weights, self._model_weights,))
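
For intuition, here is a minimal standalone sketch of the schedule that `_get_decay` above implements (the helper name `decay_at_step` and the sample steps are made up for illustration): with `dynamic_decay=True` the effective decay is `min(average_decay, (1 + t) / (10 + t))`, where `t` counts optimizer steps past `start_step`, so it starts at 0.1 and ramps toward `average_decay`; before `start_step` the decay is 0 and the average simply tracks the raw variable.

import tensorflow as tf

def decay_at_step(step, start_step=0, average_decay=0.99):
    # Mirrors MovingAverage._get_decay: zero before start_step, then a
    # (1 + t) / (10 + t) ramp that is capped at average_decay.
    step = tf.cast(step, tf.float32)
    if step < start_step:
        return tf.constant(0.0, tf.float32)
    t = step - start_step
    return tf.minimum(average_decay, (1.0 + t) / (10.0 + t))

# The ramp: step 0 -> 0.1, step 9 -> ~0.53, step 90 -> 0.91, step 990 -> capped at 0.99.
for s in [0, 9, 90, 990]:
    print(s, float(decay_at_step(s)))

# With average_decay=0.5, the first two decays are 0.1 and 2/11 ~ 0.18, which is
# consistent with the [0.64, 1.64] average asserted in test_dynamic_decay below.
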
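
For reference, a minimal standalone sketch of the add/subtract trick applied per replica by `fn_0`, `fn_1`, and `fn_2` in `_swap_weights`: it exchanges the contents of two variables in place, with no temporary buffer (the variable names below are illustrative).

import tensorflow as tf

a = tf.Variable([1.0, 2.0])   # e.g. the averaged weights
b = tf.Variable([3.0, 4.0])   # e.g. the live model weights

a.assign_add(b)    # a = a + b
b.assign(a - b)    # b = (a + b) - b, i.e. the original a
a.assign_sub(b)    # a = (a + b) - original a, i.e. the original b

print(a.numpy(), b.numpy())   # [3. 4.] [1. 2.]
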
85 changes: 83 additions & 2 deletions tensorflow_addons/optimizers/tests/moving_average_test.py
@@ -117,11 +117,15 @@ def test_optimizer_string():

def test_config():
sgd_opt = tf.keras.optimizers.SGD(lr=2.0, nesterov=True, momentum=0.3, decay=0.1)
opt = MovingAverage(sgd_opt, average_decay=0.5, num_updates=None)
opt = MovingAverage(
sgd_opt, average_decay=0.5, num_updates=None, start_step=5, dynamic_decay=True,
)
config = opt.get_config()

assert config["average_decay"] == 0.5
assert config["num_updates"] is None
assert config["start_step"] == 5
assert config["dynamic_decay"] is True

new_opt = MovingAverage.from_config(config)
old_sgd_config = opt._optimizer.get_config()
@@ -161,7 +165,84 @@ def test_fit_simple_linear_model():

def test_serialization():
sgd_opt = tf.keras.optimizers.SGD(lr=2.0, nesterov=True, momentum=0.3, decay=0.1)
optimizer = MovingAverage(sgd_opt, average_decay=0.5, num_updates=None)
optimizer = MovingAverage(
sgd_opt, average_decay=0.5, num_updates=None, start_step=5, dynamic_decay=True,
)
config = tf.keras.optimizers.serialize(optimizer)
new_optimizer = tf.keras.optimizers.deserialize(config)
assert new_optimizer.get_config() == optimizer.get_config()


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_start_step():
var0 = tf.Variable([1.0, 2.0])
grads0 = tf.constant([0.1, 0.1])
grads_and_vars = [(grads0, var0)]

opt = MovingAverage(
tf.keras.optimizers.SGD(lr=1.0), average_decay=0.5, start_step=1,
)

opt.apply_gradients(grads_and_vars)

np.testing.assert_allclose(var0.read_value(), [0.9, 1.9])

ema_var0 = opt.get_slot(var0, "average")

opt.apply_gradients(grads_and_vars)

np.testing.assert_allclose(var0.read_value(), [0.8, 1.8])

np.testing.assert_allclose(ema_var0.read_value(), [0.85, 1.85])


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_dynamic_decay():
var0 = tf.Variable([1.0, 2.0])
grads0 = tf.constant([0.1, 0.1])
grads_and_vars = [(grads0, var0)]

opt = MovingAverage(
tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5, dynamic_decay=True,
)

opt.apply_gradients(grads_and_vars)
opt.apply_gradients(grads_and_vars)

np.testing.assert_allclose(var0.read_value(), [0.6, 1.6])

ema_var0 = opt.get_slot(var0, "average")
np.testing.assert_allclose(ema_var0.read_value(), [0.64, 1.64])


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.with_device([tf.distribute.MirroredStrategy])
def test_swap_weights(device):
with device.scope():
var = tf.Variable([1.0, 2.0])
grads = tf.constant([0.1, 0.1])

opt = MovingAverage(tf.keras.optimizers.SGD(lr=2.0), average_decay=0.5,)

@tf.function
def apply_gradients():
opt.apply_gradients([(grads, var)])

device.run(apply_gradients)

np.testing.assert_allclose(var.read_value(), [0.8, 1.8])
ema_var = opt.get_slot(var, "average")
np.testing.assert_allclose(ema_var.read_value(), [0.85, 1.85])

with device.scope():
opt.shadow_copy([var])
opt.swap_weights()

np.testing.assert_allclose(ema_var.read_value(), [0.8, 1.8])
np.testing.assert_allclose(var.read_value(), [0.85, 1.85])

with device.scope():
opt.swap_weights()

np.testing.assert_allclose(var.read_value(), [0.8, 1.8])
np.testing.assert_allclose(ema_var.read_value(), [0.85, 1.85])
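
For context, a rough end-to-end sketch of the evaluation pattern the swap test above exercises, assuming the usual `tfa.optimizers.MovingAverage` import path and a plain Keras training loop (the model, data, and hyperparameters are placeholders, not taken from the PR): train with the wrapped optimizer, swap the averaged weights into the model for evaluation, then swap the trained weights back.

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    opt = tfa.optimizers.MovingAverage(
        tf.keras.optimizers.SGD(0.1), average_decay=0.99, dynamic_decay=True
    )
    model.compile(optimizer=opt, loss="mse")

x = np.random.rand(256, 4).astype("float32")
y = np.random.rand(256, 1).astype("float32")
model.fit(x, y, epochs=1, verbose=0)

with strategy.scope():
    opt.shadow_copy(model.trainable_variables)  # expose the "average" slots
    opt.swap_weights()                          # averaged weights -> model
    model.evaluate(x, y, verbose=0)             # evaluate the averaged weights
    opt.swap_weights()                          # restore the trained weights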