Skip to content

Commit 0f52f44

Browse files
authored
Pull request for fixing warm-starting device placement (tensorflow#17312) (tensorflow#17314)
* Update checkpoint_utils.py
  Fix device allocation bug for warm-starting op
* Update checkpoint_utils_test.py
  Fix test
1 parent 72b5d12 commit 0f52f44

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

tensorflow/python/training/checkpoint_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,11 @@ def _set_checkpoint_initializer(variable,
289289
name: Name of the operation.
290290
"""
291291
base_type = variable.dtype.base_dtype
292-
with ops.colocate_with(variable):
292+
# Do not colocate with variable since RestoreV2 op only runs on CPU and
293+
# colocation will force variable (and other ops that colocate with variable)
294+
# to be on CPU as well. It is okay to place the variable's initializer op on
295+
# CPU since it will only be run once at the start.
296+
with ops.device(variable.device), ops.device("/cpu:0"):
293297
restore_op = io_ops.restore_v2(
294298
ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
295299
variable._initializer_op = state_ops.assign(variable, restore_op) # pylint:disable=protected-access

tensorflow/python/training/checkpoint_utils_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,9 @@ def testRestoreRunsOnSameDevice(self):
176176

177177
checkpoint_utils.init_from_checkpoint(checkpoint_dir,
178178
{"useful_scope/": "useful_scope/"})
179-
self.assertEqual(my4._initializer_op.op.inputs[1].device, "/job:ps")
179+
# initializer runs on the same task but always on CPU.
180+
self.assertEqual(my4._initializer_op.op.inputs[1].device,
181+
"/job:ps/device:CPU:0")
180182

181183
def testInitFromRootCheckpoint(self):
182184
checkpoint_dir = self.get_temp_dir()

0 commit comments

Comments
 (0)