diff --git a/tensorflow_addons/utils/BUILD b/tensorflow_addons/utils/BUILD
index 2cca7537af..166ed25990 100644
--- a/tensorflow_addons/utils/BUILD
+++ b/tensorflow_addons/utils/BUILD
@@ -25,3 +25,15 @@ py_test(
         ":utils",
     ],
 )
+
+py_test(
+    name = "test_utils_test",
+    size = "small",
+    srcs = [
+        "test_utils_test.py",
+    ],
+    main = "test_utils_test.py",
+    deps = [
+        ":utils",
+    ],
+)
diff --git a/tensorflow_addons/utils/test_utils.py b/tensorflow_addons/utils/test_utils.py
index 2528a2c6d7..5e95957fda 100644
--- a/tensorflow_addons/utils/test_utils.py
+++ b/tensorflow_addons/utils/test_utils.py
@@ -18,7 +18,7 @@
 import inspect
 import time
 import unittest
-
+import logging
 import tensorflow as tf
 
 # TODO: find public API alternative to these
@@ -80,17 +80,73 @@ def create_virtual_devices(
     if device_type == "CPU":
         memory_limit_per_device = None
 
-    tf.config.experimental.set_virtual_device_configuration(
+    tf.config.set_logical_device_configuration(
         physical_devices[0],
         [
-            tf.config.experimental.VirtualDeviceConfiguration(
-                memory_limit=memory_limit_per_device
-            )
+            tf.config.LogicalDeviceConfiguration(memory_limit=memory_limit_per_device)
             for _ in range(num_devices)
         ],
     )
 
-    return tf.config.experimental.list_logical_devices(device_type)
+    return tf.config.list_logical_devices(device_type)
+
+
+def create_or_get_logical_devices(
+    num_devices, force_device=None, memory_limit_per_device=1024
+):
+    """Virtualize the physical device into logical devices, or return the existing
+    logical devices if virtualization has already occurred.
+
+    Args:
+        num_devices: The number of virtual devices needed.
+        force_device: 'CPU'/'GPU'. Defaults to None, where the
+            device is selected based on the system.
+        memory_limit_per_device: Memory limit (in MB) for each
+            virtual GPU. Only used for GPUs.
+
+    Returns:
+        logical_devices_out: A list of logical devices which can be passed to
+            tf.distribute.MirroredStrategy()
+    """
+    if force_device is None:
+        device_type = (
+            "GPU" if len(tf.config.list_physical_devices("GPU")) > 0 else "CPU"
+        )
+    else:
+        assert force_device in ["CPU", "GPU"]
+        device_type = force_device
+
+    physical_devices = tf.config.list_physical_devices(device_type)
+
+    # check the logical device configuration. Do not use list_logical_devices here because that initializes the devices.
+    logical_config = tf.config.get_logical_device_configuration(physical_devices[0])
+
+    # explicitly check whether a logical device configuration already exists.
+    if logical_config is None:
+        # no configuration yet, so create the virtual devices.
+        logical_devices = create_virtual_devices(
+            num_devices, force_device, memory_limit_per_device
+        )
+        logging.info("%i logical devices initialized", num_devices)
+    else:
+        # a configuration already exists, so just fetch the logical devices.
+        logical_devices = tf.config.list_logical_devices(device_type)
+
+    # take at most num_devices logical devices.
+    logical_devices_out = logical_devices[:num_devices]
+
+    # confirm that we are returning the correct number of logical devices.
+    assert (
+        len(logical_devices_out) == num_devices
+    ), """%i logical devices were initialized at an earlier stage, but the current
+    request is for %i logical devices. Initialize enough logical devices at that
+    earlier stage; logical devices cannot be modified after initialization.
+    """ % (
+        len(logical_devices),
+        num_devices,
+    )
+
+    return logical_devices_out
 
 
 def run_all_distributed(num_devices):
@@ -119,7 +175,7 @@ def decorator(f):
             )
 
         def decorated(self, *args, **kwargs):
-            logical_devices = create_virtual_devices(num_devices)
+            logical_devices = create_or_get_logical_devices(num_devices)
             strategy = tf.distribute.MirroredStrategy(logical_devices)
             with strategy.scope():
                 f(self, *args, **kwargs)
diff --git a/tensorflow_addons/utils/test_utils_test.py b/tensorflow_addons/utils/test_utils_test.py
new file mode 100644
index 0000000000..f8393be29b
--- /dev/null
+++ b/tensorflow_addons/utils/test_utils_test.py
@@ -0,0 +1,86 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for test utils."""
+
+import tensorflow as tf
+import numpy as np
+from tensorflow_addons.utils import test_utils
+
+
+def _train_something():
+    # run a simple training loop to confirm that run_distributed works.
+
+    model = tf.keras.models.Sequential()
+
+    model.add(tf.keras.layers.Dense(1, input_shape=(1,)))
+
+    model.compile(loss="mse", optimizer="sgd")
+
+    x = np.zeros(shape=(32, 1))
+    y = np.zeros(shape=(32, 1))
+
+    model.fit(x, y, batch_size=2, epochs=2)
+
+
+class TestA(tf.test.TestCase):
+    # ideally this test runs first so the logical devices are initialized before the other tests.
+    @test_utils.run_distributed(4)
+    def test_training_dist(self):
+        _train_something()
+
+
+class TestUtilsTestMixed(tf.test.TestCase):
+    # distributed and non-distributed tests should be able to coexist in one test case.
+    def test_training(self):
+        _train_something()
+
+    @test_utils.run_distributed(4)
+    def test_training_dist_many(self):
+        _train_something()
+
+    @test_utils.run_distributed(2)
+    def test_training_dist_few(self):
+        _train_something()
+
+    @test_utils.run_in_graph_and_eager_modes
+    def test_training_graph_eager(self):
+        _train_something()
+
+    @test_utils.run_in_graph_and_eager_modes
+    @test_utils.run_distributed(2)
+    def test_training_graph_eager_dist(self):
+        _train_something()
+
+    def test_train_dist_too_many(self):
+        with self.assertRaises(AssertionError):
+            # wrap a separate test method here; if test_train_dist_too_many itself were wrapped,
+            # the error would be raised outside the scope of self.assertRaises.
+            func = test_utils.run_distributed(10)(self.__class__.test_training)
+            func(self)
+            # requesting more logical devices than were initialized raises an AssertionError.
+
+
+@test_utils.run_all_distributed(3)
+class TestUtilsTest(tf.test.TestCase):
+    # test the class wrapper
+    def test_training(self):
+        _train_something()
+
+    def test_training_again(self):
+        _train_something()
+
+
+if __name__ == "__main__":
+    tf.test.main()
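Usage sketch (not part of the diff): a minimal example of how another test module might call the new create_or_get_logical_devices helper directly with tf.distribute.MirroredStrategy, mirroring what the run_distributed decorator does above. The device count, force_device="CPU", and the tiny model are illustrative assumptions, not part of this change.

import tensorflow as tf

from tensorflow_addons.utils import test_utils


def _train_with_mirrored_strategy():
    # reuse (or lazily create) two logical CPU devices; repeated calls are safe
    # because the helper only virtualizes the physical device once.
    devices = test_utils.create_or_get_logical_devices(2, force_device="CPU")
    strategy = tf.distribute.MirroredStrategy(devices)
    with strategy.scope():
        model = tf.keras.models.Sequential(
            [tf.keras.layers.Dense(1, input_shape=(1,))]
        )
        model.compile(loss="mse", optimizer="sgd")
    model.fit(tf.zeros((32, 1)), tf.zeros((32, 1)), batch_size=2, epochs=1)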