Skip to content

Commit d2e24b6

Browse files
yifeifgunan
authored andcommitted
Don't assign device for the keras part of _saved_first_checkpoint. Fix tensorflow#14504. (tensorflow#17231)
PiperOrigin-RevId: 186526175
1 parent 0f52f44 commit d2e24b6

File tree

2 files changed

+38
-13
lines changed

2 files changed

+38
-13
lines changed

tensorflow/python/keras/_impl/keras/estimator.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -221,18 +221,18 @@ def _save_first_checkpoint(keras_model, estimator, custom_objects,
221221
Returns:
222222
The model_fn for a keras Estimator.
223223
"""
224-
with ops.Graph().as_default() as g, g.device(estimator._device_fn):
225-
random_seed.set_random_seed(estimator.config.tf_random_seed)
226-
training_util.create_global_step()
227-
model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model,
228-
custom_objects)
229-
230-
if isinstance(model, models.Sequential):
231-
model = model.model
232-
# Load weights and save to checkpoint if there is no checkpoint
233-
latest_path = saver_lib.latest_checkpoint(estimator.model_dir)
234-
if not latest_path:
235-
with session.Session() as sess:
224+
# Load weights and save to checkpoint if there is no checkpoint
225+
latest_path = saver_lib.latest_checkpoint(estimator.model_dir)
226+
if not latest_path:
227+
with ops.Graph().as_default():
228+
random_seed.set_random_seed(estimator.config.tf_random_seed)
229+
training_util.create_global_step()
230+
model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model,
231+
custom_objects)
232+
if isinstance(model, models.Sequential):
233+
model = model.model
234+
# save to checkpoint
235+
with session.Session(config=estimator._session_config) as sess:
236236
model.set_weights(keras_weights)
237237
# Make update ops and initialize all variables.
238238
if not model.train_function:

tensorflow/python/keras/_impl/keras/estimator_test.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
import json
2021
from math import log10
2122
import os
2223
import tempfile
@@ -62,7 +63,7 @@ def simple_functional_model():
6263
return model
6364

6465

65-
def get_resource_for_simple_model(is_sequential, is_evaluate):
66+
def get_resource_for_simple_model(is_sequential=True, is_evaluate=False):
6667
model = simple_sequential_model(
6768
) if is_sequential else simple_functional_model()
6869
if is_sequential:
@@ -352,6 +353,30 @@ def test_custom_objects(self):
352353
model_dir=tempfile.mkdtemp(dir=self._base_dir),
353354
custom_objects=custom_objects)
354355

356+
def test_tf_config(self):
357+
keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
358+
keras_model.compile(
359+
loss='categorical_crossentropy',
360+
optimizer='rmsprop',
361+
metrics=['mse', keras.metrics.categorical_accuracy])
362+
363+
tf_config = json.dumps({
364+
'cluster': {
365+
run_config_lib.TaskType.PS: ['localhost:1234'],
366+
run_config_lib.TaskType.WORKER: ['localhost:1236'],
367+
run_config_lib.TaskType.MASTER: ['localhost:1238']
368+
},
369+
'task': {
370+
'type': run_config_lib.TaskType.MASTER,
371+
'index': 0
372+
}
373+
})
374+
with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
375+
with self.test_session():
376+
keras.estimator.model_to_estimator(
377+
keras_model=keras_model,
378+
model_dir=tempfile.mkdtemp(dir=self._base_dir))
379+
355380

356381
if __name__ == '__main__':
357382
test.main()

0 commit comments

Comments
 (0)