Skip to content

Commit b035a22

Browse files
Allen Wang and tensorflower-gardener
Allen Wang
authored and committed
Internal change
PiperOrigin-RevId: 319058156
1 parent df89d3e commit b035a22

File tree

3 files changed

+23
-7
lines changed

3 files changed

+23
-7
lines changed

official/vision/image_classification/classifier_trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,8 @@ def train_and_eval(
339339
optimizer = optimizer_factory.build_optimizer(
340340
optimizer_name=params.model.optimizer.name,
341341
base_learning_rate=learning_rate,
342-
params=params.model.optimizer.as_dict())
342+
params=params.model.optimizer.as_dict(),
343+
model=model)
343344

344345
metrics_map = _get_metrics(one_hot)
345346
metrics = [metrics_map[metric] for metric in params.train.metrics]

official/vision/image_classification/optimizer_factory.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@
1818
# from __future__ import google_type_annotations
1919
from __future__ import print_function
2020

21+
from typing import Any, Dict, Text, List
22+
2123
from absl import logging
2224
import tensorflow as tf
2325
import tensorflow_addons as tfa
2426

25-
from typing import Any, Dict, Text, List
2627
from official.vision.image_classification import learning_rate
2728
from official.vision.image_classification.configs import base_configs
2829

@@ -250,7 +251,8 @@ def from_config(cls, config, custom_objects=None):
250251
def build_optimizer(
251252
optimizer_name: Text,
252253
base_learning_rate: tf.keras.optimizers.schedules.LearningRateSchedule,
253-
params: Dict[Text, Any]):
254+
params: Dict[Text, Any],
255+
model: tf.keras.Model = None):
254256
"""Build the optimizer based on name.
255257
256258
Args:
@@ -261,6 +263,8 @@ def build_optimizer(
261263
params: String -> Any dictionary representing the optimizer params.
262264
This should contain optimizer specific parameters such as
263265
`base_learning_rate`, `decay`, etc.
266+
model: The `tf.keras.Model`. This is used for the shadow copy if using
267+
`MovingAverage`.
264268
265269
Returns:
266270
A tf.keras.Optimizer.
@@ -322,10 +326,13 @@ def build_optimizer(
322326
# Moving average should be applied last, as it's applied at test time
323327
moving_average_decay = params.get('moving_average_decay', 0.)
324328
if moving_average_decay is not None and moving_average_decay > 0.:
329+
if model is None:
330+
raise ValueError('`model` must be provided if using `MovingAverage`.')
325331
logging.info('Including moving average decay.')
326332
optimizer = MovingAverage(
327-
optimizer,
333+
optimizer=optimizer,
328334
average_decay=moving_average_decay)
335+
optimizer.shadow_copy(model)
329336
return optimizer
330337

331338

official/vision/image_classification/optimizer_factory_test.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,21 @@
1919
# from __future__ import google_type_annotations
2020
from __future__ import print_function
2121

22-
import tensorflow as tf
23-
2422
from absl.testing import parameterized
23+
24+
import tensorflow as tf
2525
from official.vision.image_classification import optimizer_factory
2626
from official.vision.image_classification.configs import base_configs
2727

2828

2929
class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
3030

31+
def build_toy_model(self) -> tf.keras.Model:
32+
"""Creates a toy `tf.Keras.Model`."""
33+
model = tf.keras.Sequential()
34+
model.add(tf.keras.layers.Dense(1, input_shape=(1,)))
35+
return model
36+
3137
@parameterized.named_parameters(
3238
('sgd', 'sgd', 0., False),
3339
('momentum', 'momentum', 0., False),
@@ -40,6 +46,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
4046
('rmsprop_ema', 'rmsprop', 0.999, False))
4147
def test_optimizer(self, optimizer_name, moving_average_decay, lookahead):
4248
"""Smoke test to be sure no syntax errors."""
49+
model = self.build_toy_model()
4350
params = {
4451
'learning_rate': 0.001,
4552
'rho': 0.09,
@@ -51,7 +58,8 @@ def test_optimizer(self, optimizer_name, moving_average_decay, lookahead):
5158
optimizer = optimizer_factory.build_optimizer(
5259
optimizer_name=optimizer_name,
5360
base_learning_rate=params['learning_rate'],
54-
params=params)
61+
params=params,
62+
model=model)
5563
self.assertTrue(issubclass(type(optimizer), tf.keras.optimizers.Optimizer))
5664

5765
def test_unknown_optimizer(self):

0 commit comments

Comments
 (0)