[WIP] Testutils run distributed fix #1209

Closed. Wants to merge 28 commits.

Commits
ebd39eb
created testing file
hyang0129 Mar 2, 2020
e61a1f0
created run distributed test1
hyang0129 Mar 2, 2020
52dc73a
created more tests
hyang0129 Mar 2, 2020
c4a37de
fixed build file?
hyang0129 Mar 2, 2020
7bd6ac7
fixed typo in the training loop
hyang0129 Mar 2, 2020
9050b4f
updated test utils to no longer use experimental virtual device and i…
hyang0129 Mar 2, 2020
916b2b7
created function to get or create virtual devices and added error exp…
hyang0129 Mar 2, 2020
38e041c
added test to run first
hyang0129 Mar 2, 2020
55dd75a
trying to figure out correct virtual device initialization order
hyang0129 Mar 3, 2020
bd3c26e
trying to figure out correct virtual device initialization order
hyang0129 Mar 3, 2020
d78de6a
added logging to identify when devices first get initialized
hyang0129 Mar 3, 2020
b185251
identified that list logical devices actually initializes them
hyang0129 Mar 4, 2020
8cb7f5e
added test to confirm correct error is raised with correct message
hyang0129 Mar 4, 2020
613eac5
moved the wrapper to wrap within the scope of with self.assertRaises
hyang0129 Mar 4, 2020
62134ec
func was missing reference to self
hyang0129 Mar 4, 2020
ac07075
fixed exception message reference. now calls str() to get the string
hyang0129 Mar 4, 2020
d6fc38f
no longer checking message equals. this is because I don't want to de…
hyang0129 Mar 4, 2020
b78f8be
removed unused variable
hyang0129 Mar 4, 2020
f792d44
minor comment fixes
hyang0129 Mar 4, 2020
87293eb
trying various configurations to narrow down the issue.
hyang0129 Mar 4, 2020
6928aeb
changed to assert on test_utils
hyang0129 Mar 10, 2020
a1fe29e
fixed assertion typo, should be <= instead of <
hyang0129 Mar 10, 2020
58aa7a6
fixed test for confirming that error gets raised
hyang0129 Mar 10, 2020
e83ad62
fixed test for confirming that error gets raised by no longer pass…
hyang0129 Mar 10, 2020
fbd54d3
trying to figure out how to properly wrap this method so that it corr…
hyang0129 Mar 10, 2020
5a61612
assert should have been using ==
hyang0129 Mar 10, 2020
c88593e
Merge branch 'master' into pr/hyang0129/1209
seanpmorgan Mar 11, 2020
cba4cef
Merge branch 'master' into hyang0129_testutils_dist_test
gabrieldemarmiesse Mar 11, 2020
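
The key finding in the commit trail (b185251) is that merely listing logical devices initializes them. A minimal sketch of that gotcha, using only the public tf.config calls this PR relies on (the CPU device choice is illustrative):

    import tensorflow as tf

    cpus = tf.config.list_physical_devices("CPU")

    # Safe: only reads the configuration; returns None while no virtual
    # devices have been configured for this physical device.
    print(tf.config.get_logical_device_configuration(cpus[0]))

    # Initializes the runtime's devices as a side effect; after this,
    # tf.config.set_logical_device_configuration raises a RuntimeError.
    print(tf.config.list_logical_devices("CPU"))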
12 changes: 12 additions & 0 deletions tensorflow_addons/utils/BUILD
@@ -25,3 +25,15 @@ py_test(
         ":utils",
     ],
 )
+
+py_test(
+    name = "test_utils_test",
+    size = "small",
+    srcs = [
+        "test_utils_test.py",
+    ],
+    main = "test_utils_test.py",
+    deps = [
+        ":utils",
+    ],
+)
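
With this target registered, a sketch of running it locally under the repository's usual Bazel workflow (the exact flags Addons CI uses may differ):

    bazel test //tensorflow_addons/utils:test_utils_test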
70 changes: 63 additions & 7 deletions tensorflow_addons/utils/test_utils.py
@@ -18,7 +18,7 @@
 import inspect
 import time
 import unittest
-
+import logging
 import tensorflow as tf

 # TODO: find public API alternative to these
@@ -80,17 +80,73 @@ def create_virtual_devices(
     if device_type == "CPU":
         memory_limit_per_device = None

-    tf.config.experimental.set_virtual_device_configuration(
+    tf.config.set_logical_device_configuration(
         physical_devices[0],
         [
-            tf.config.experimental.VirtualDeviceConfiguration(
-                memory_limit=memory_limit_per_device
-            )
+            tf.config.LogicalDeviceConfiguration(memory_limit=memory_limit_per_device)
             for _ in range(num_devices)
         ],
     )

-    return tf.config.experimental.list_logical_devices(device_type)
+    return tf.config.list_logical_devices(device_type)


+def create_or_get_logical_devices(
+    num_devices, force_device=None, memory_limit_per_device=1024
+):
+    """Virtualize the physical device into logical devices, or fetch the
+    existing logical devices if virtualization has already occurred.
+
+    Args:
+        num_devices: The number of virtual devices needed.
+        force_device: 'CPU'/'GPU'. Defaults to None, in which case the
+            device type is selected based on the system.
+        memory_limit_per_device: Memory limit for each virtual GPU.
+            Only used for GPUs.
+
+    Returns:
+        logical_devices_out: A list of logical devices which can be passed to
+            tf.distribute.MirroredStrategy()
+    """
+    if force_device is None:
+        device_type = (
+            "GPU" if len(tf.config.list_physical_devices("GPU")) > 0 else "CPU"
+        )
+    else:
+        assert force_device in ["CPU", "GPU"]
+        device_type = force_device
+
+    physical_devices = tf.config.list_physical_devices(device_type)
+
+    # Check the logical device configuration. Do not use list_logical_devices,
+    # because that call actually initializes the devices.
+    logical_config = tf.config.get_logical_device_configuration(physical_devices[0])
+
+    # Explicitly confirm that no device configuration exists yet.
+    if logical_config is None:
+        # Create the virtual devices.
+        logical_devices = create_virtual_devices(
+            num_devices, force_device, memory_limit_per_device
+        )
+        logging.info("%i logical devices initialized" % num_devices)
+    else:
+        # A configuration already exists, so fetch the logical devices.
+        logical_devices = tf.config.list_logical_devices(device_type)
+
+    # Take at most num_devices logical devices.
+    logical_devices_out = logical_devices[:num_devices]
+
+    # Confirm that we are returning the requested number of logical devices.
+    assert (
+        len(logical_devices_out) == num_devices
+    ), """%i logical devices were initialized at an earlier stage,
+    but the current request is for %i logical devices. Please initialize more
+    logical devices at the earlier stage. You are seeing this error because
+    logical devices cannot be modified after initialization.
+    """ % (
+        len(logical_devices),
+        num_devices,
+    )
+
+    return logical_devices_out
+
+
 def run_all_distributed(num_devices):
@@ -119,7 +175,7 @@ def decorator(f):
             )

         def decorated(self, *args, **kwargs):
-            logical_devices = create_virtual_devices(num_devices)
+            logical_devices = create_or_get_logical_devices(num_devices)
             strategy = tf.distribute.MirroredStrategy(logical_devices)
             with strategy.scope():
                 f(self, *args, **kwargs)
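
To make the intended call order concrete, here is a minimal usage sketch of the new helper, mirroring how the decorator above wires it into MirroredStrategy (the CPU forcing and device counts are illustrative):

    import tensorflow as tf
    from tensorflow_addons.utils import test_utils

    # First request wins: four virtual CPUs are created because no logical
    # device configuration exists yet.
    test_utils.create_or_get_logical_devices(4, force_device="CPU")

    # A later, smaller request reuses the existing devices and returns two.
    devices = test_utils.create_or_get_logical_devices(2, force_device="CPU")
    strategy = tf.distribute.MirroredStrategy(devices)

    with strategy.scope():
        model = tf.keras.models.Sequential(
            [tf.keras.layers.Dense(1, input_shape=(1,))]
        )
        model.compile(loss="mse", optimizer="sgd")

Asking for more devices than the first request created trips the assertion above, which is exactly what test_train_dist_too_many in the new test file checks.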
86 changes: 86 additions & 0 deletions tensorflow_addons/utils/test_utils_test.py
@@ -0,0 +1,86 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for test utils."""

import tensorflow as tf
import numpy as np
from tensorflow_addons.utils import test_utils


def _train_something():
    # Run a simple training loop to confirm that run_distributed works.
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Dense(1, input_shape=(1,)))
    model.compile(loss="mse", optimizer="sgd")

    x = np.zeros(shape=(32, 1))
    y = np.zeros(shape=(32, 1))

    model.fit(x, y, batch_size=2, epochs=2)


class TestA(tf.test.TestCase):
    # Hopefully this test runs first so that the virtual devices are
    # initialized properly before any other test touches them.
    @test_utils.run_distributed(4)
    def test_training_dist(self):
        _train_something()


class TestUtilsTestMixed(tf.test.TestCase):
    # We should be able to mix distributed and non-distributed tests.
    def test_training(self):
        _train_something()

    @test_utils.run_distributed(4)
    def test_training_dist_many(self):
        _train_something()

    @test_utils.run_distributed(2)
    def test_training_dist_few(self):
        _train_something()

    @test_utils.run_in_graph_and_eager_modes
    def test_training_graph_eager(self):
        _train_something()

    @test_utils.run_in_graph_and_eager_modes
    @test_utils.run_distributed(2)
    def test_training_graph_eager_dist(self):
        _train_something()

    def test_train_dist_too_many(self):
        with self.assertRaises(AssertionError):
            # Wrap a plain function here. If we wrapped test_train_dist_too_many
            # itself, the error would be raised outside the scope of
            # self.assertRaises.
            func = test_utils.run_distributed(10)(self.__class__.test_training)
            # This should raise an AssertionError, because fewer than 10
            # logical devices were initialized earlier.
            func(self)


@test_utils.run_all_distributed(3)
class TestUtilsTest(tf.test.TestCase):
    # Test the class wrapper.
    def test_training(self):
        _train_something()

    def test_training_again(self):
        _train_something()


if __name__ == "__main__":
    tf.test.main()
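
For reference, the body of run_all_distributed is collapsed in the diff above. A rough sketch of how such a class-level wrapper can be built on top of run_distributed (an illustration under that assumption, not the PR's actual implementation):

    def run_all_distributed(num_devices):
        # Apply the run_distributed decorator to every test method of a
        # tf.test.TestCase subclass.
        def decorator(cls):
            for name, method in list(vars(cls).items()):
                if callable(method) and name.startswith("test"):
                    setattr(cls, name, run_distributed(num_devices)(method))
            return cls

        return decorator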