From 0a57c39356e408a9d9bb87b40910e78ebe0de283 Mon Sep 17 00:00:00 2001 From: bhack Date: Thu, 27 Aug 2020 16:15:06 +0200 Subject: [PATCH 1/4] Update lookahead.py Inital fix of https://github.com/tensorflow/addons/issues/2094 https://github.com/tensorflow/addons/pull/2102 --- tensorflow_addons/optimizers/lookahead.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow_addons/optimizers/lookahead.py b/tensorflow_addons/optimizers/lookahead.py index 8f96dc9c62..48544c23b2 100644 --- a/tensorflow_addons/optimizers/lookahead.py +++ b/tensorflow_addons/optimizers/lookahead.py @@ -80,6 +80,7 @@ def __init__( self._set_hyper("sync_period", sync_period) self._set_hyper("slow_step_size", slow_step_size) self._initialized = False + self._track_trackable(self._optimizer, 'base_optimizer') def _create_slots(self, var_list): self._optimizer._create_slots( From 40fc5131acd739ca57f3580fc666994f7674f6e5 Mon Sep 17 00:00:00 2001 From: bhack Date: Thu, 27 Aug 2020 16:22:56 +0200 Subject: [PATCH 2/4] Fix linting --- tensorflow_addons/optimizers/lookahead.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/optimizers/lookahead.py b/tensorflow_addons/optimizers/lookahead.py index 48544c23b2..1cda031069 100644 --- a/tensorflow_addons/optimizers/lookahead.py +++ b/tensorflow_addons/optimizers/lookahead.py @@ -80,7 +80,7 @@ def __init__( self._set_hyper("sync_period", sync_period) self._set_hyper("slow_step_size", slow_step_size) self._initialized = False - self._track_trackable(self._optimizer, 'base_optimizer') + self._track_trackable(self._optimizer, "base_optimizer") def _create_slots(self, var_list): self._optimizer._create_slots( From 37049dc31ee7f80422dc246d0f8017a82dab785b Mon Sep 17 00:00:00 2001 From: bhack Date: Thu, 27 Aug 2020 16:47:17 +0200 Subject: [PATCH 3/4] Resolve name conflict with mixed prexision --- tensorflow_addons/optimizers/lookahead.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/optimizers/lookahead.py b/tensorflow_addons/optimizers/lookahead.py index 1cda031069..a52bc2df76 100644 --- a/tensorflow_addons/optimizers/lookahead.py +++ b/tensorflow_addons/optimizers/lookahead.py @@ -80,7 +80,7 @@ def __init__( self._set_hyper("sync_period", sync_period) self._set_hyper("slow_step_size", slow_step_size) self._initialized = False - self._track_trackable(self._optimizer, "base_optimizer") + self._track_trackable(self._optimizer, "lh_base_optimizer") def _create_slots(self, var_list): self._optimizer._create_slots( From 02fc62734ba8a5108361a05073958b4d992aa5c2 Mon Sep 17 00:00:00 2001 From: bhack Date: Sun, 30 Aug 2020 12:49:48 +0200 Subject: [PATCH 4/4] Track baseline optimizer in avg --- tensorflow_addons/optimizers/average_wrapper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow_addons/optimizers/average_wrapper.py b/tensorflow_addons/optimizers/average_wrapper.py index 86f815bb4d..d510116a02 100644 --- a/tensorflow_addons/optimizers/average_wrapper.py +++ b/tensorflow_addons/optimizers/average_wrapper.py @@ -46,6 +46,7 @@ def __init__( raise TypeError("sequential_update must be of bool type") self._optimizer = optimizer + self._track_trackable(self._optimizer, "awg_optimizer") if sequential_update is not None: warnings.warn(