From 1ac966f8d92e4e7008554f6de53d56123fe0e782 Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Fri, 6 Mar 2020 11:26:17 +0100 Subject: [PATCH 1/3] Added file for metrics tests. --- tensorflow_addons/metrics/metrics_test.py | 88 +++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 tensorflow_addons/metrics/metrics_test.py diff --git a/tensorflow_addons/metrics/metrics_test.py b/tensorflow_addons/metrics/metrics_test.py new file mode 100644 index 0000000000..2b7a8cd506 --- /dev/null +++ b/tensorflow_addons/metrics/metrics_test.py @@ -0,0 +1,88 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import unittest + +import tensorflow as tf + + + +class MatthewsCorrelationCoefficientTest(tf.test.TestCase): + def test_config(self): + # mcc object + mcc1 = MatthewsCorrelationCoefficient(num_classes=1) + self.assertEqual(mcc1.num_classes, 1) + self.assertEqual(mcc1.dtype, tf.float32) + # check configure + mcc2 = MatthewsCorrelationCoefficient.from_config(mcc1.get_config()) + self.assertEqual(mcc2.num_classes, 1) + self.assertEqual(mcc2.dtype, tf.float32) + + def initialize_vars(self, n_classes): + mcc = MatthewsCorrelationCoefficient(num_classes=n_classes) + self.evaluate(tf.compat.v1.variables_initializer(mcc.variables)) + return mcc + + def update_obj_states(self, obj, gt_label, preds): + update_op = obj.update_state(gt_label, preds) + self.evaluate(update_op) + + def check_results(self, obj, value): + self.assertAllClose(value, self.evaluate(obj.result()), atol=1e-5) + + def test_binary_classes(self): + gt_label = tf.constant([[1.0], [1.0], [1.0], [0.0]], dtype=tf.float32) + preds = tf.constant([[1.0], [0.0], [1.0], [1.0]], dtype=tf.float32) + # Initialize + mcc = self.initialize_vars(n_classes=1) + # Update + self.update_obj_states(mcc, gt_label, preds) + # Check results + self.check_results(mcc, [-0.33333334]) + + def test_multiple_classes(self): + gt_label = tf.constant( + [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0]], + dtype=tf.float32, + ) + preds = tf.constant( + [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]], + dtype=tf.float32, + ) + # Initialize + mcc = self.initialize_vars(n_classes=3) + # Update + self.update_obj_states(mcc, gt_label, preds) + # Check results + self.check_results(mcc, [-0.33333334, 1.0, 0.57735026]) + + # Keras model API check + def test_keras_model(self): + model = tf.keras.Sequential() + model.add(tf.keras.layers.Dense(64, activation="relu")) + model.add(tf.keras.layers.Dense(64, activation="relu")) + model.add(tf.keras.layers.Dense(1, activation="softmax")) + mcc = MatthewsCorrelationCoefficient(num_classes=1) + model.compile( + optimizer="Adam", loss="binary_crossentropy", metrics=["accuracy", mcc] + ) + # data preparation + data = np.random.random((10, 1)) + labels = np.random.random((10, 1)) + labels = np.where(labels > 0.5, 1.0, 0.0) + model.fit(data, labels, epochs=1, batch_size=32, verbose=0) + + +if __name__ == "__main__": + unittest From f5af081885a6e19c461184856f42e1610765b69e Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Fri, 6 Mar 2020 12:12:04 +0100 Subject: [PATCH 2/3] Added a test for the metrics to check the signature. --- tensorflow_addons/metrics/BUILD | 12 +++ tensorflow_addons/metrics/metrics_test.py | 83 ++++--------------- .../metrics/multilabel_confusion_matrix.py | 10 ++- tensorflow_addons/metrics/r_square.py | 9 +- 4 files changed, 47 insertions(+), 67 deletions(-) diff --git a/tensorflow_addons/metrics/BUILD b/tensorflow_addons/metrics/BUILD index 4f0f06fd8c..0b43c74275 100644 --- a/tensorflow_addons/metrics/BUILD +++ b/tensorflow_addons/metrics/BUILD @@ -90,3 +90,15 @@ py_test( ":metrics", ], ) + +py_test( + name = "metrics_test", + size = "small", + srcs = [ + "metrics_test.py", + ], + main = "metrics_test.py", + deps = [ + ":metrics", + ], +) diff --git a/tensorflow_addons/metrics/metrics_test.py b/tensorflow_addons/metrics/metrics_test.py index 2b7a8cd506..407297d551 100644 --- a/tensorflow_addons/metrics/metrics_test.py +++ b/tensorflow_addons/metrics/metrics_test.py @@ -13,76 +13,29 @@ # limitations under the License. # ============================================================================== import unittest +import inspect -import tensorflow as tf +from tensorflow.keras.metrics import Metric +from tensorflow_addons import metrics +class MetricsTests(unittest.TestCase): + def test_update_state_signature(self): + for name, obj in inspect.getmembers(metrics): + if inspect.isclass(obj) and issubclass(obj, Metric): + check_update_state_signature(obj) -class MatthewsCorrelationCoefficientTest(tf.test.TestCase): - def test_config(self): - # mcc object - mcc1 = MatthewsCorrelationCoefficient(num_classes=1) - self.assertEqual(mcc1.num_classes, 1) - self.assertEqual(mcc1.dtype, tf.float32) - # check configure - mcc2 = MatthewsCorrelationCoefficient.from_config(mcc1.get_config()) - self.assertEqual(mcc2.num_classes, 1) - self.assertEqual(mcc2.dtype, tf.float32) - def initialize_vars(self, n_classes): - mcc = MatthewsCorrelationCoefficient(num_classes=n_classes) - self.evaluate(tf.compat.v1.variables_initializer(mcc.variables)) - return mcc - - def update_obj_states(self, obj, gt_label, preds): - update_op = obj.update_state(gt_label, preds) - self.evaluate(update_op) - - def check_results(self, obj, value): - self.assertAllClose(value, self.evaluate(obj.result()), atol=1e-5) - - def test_binary_classes(self): - gt_label = tf.constant([[1.0], [1.0], [1.0], [0.0]], dtype=tf.float32) - preds = tf.constant([[1.0], [0.0], [1.0], [1.0]], dtype=tf.float32) - # Initialize - mcc = self.initialize_vars(n_classes=1) - # Update - self.update_obj_states(mcc, gt_label, preds) - # Check results - self.check_results(mcc, [-0.33333334]) - - def test_multiple_classes(self): - gt_label = tf.constant( - [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0]], - dtype=tf.float32, - ) - preds = tf.constant( - [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]], - dtype=tf.float32, - ) - # Initialize - mcc = self.initialize_vars(n_classes=3) - # Update - self.update_obj_states(mcc, gt_label, preds) - # Check results - self.check_results(mcc, [-0.33333334, 1.0, 0.57735026]) - - # Keras model API check - def test_keras_model(self): - model = tf.keras.Sequential() - model.add(tf.keras.layers.Dense(64, activation="relu")) - model.add(tf.keras.layers.Dense(64, activation="relu")) - model.add(tf.keras.layers.Dense(1, activation="softmax")) - mcc = MatthewsCorrelationCoefficient(num_classes=1) - model.compile( - optimizer="Adam", loss="binary_crossentropy", metrics=["accuracy", mcc] - ) - # data preparation - data = np.random.random((10, 1)) - labels = np.random.random((10, 1)) - labels = np.where(labels > 0.5, 1.0, 0.0) - model.fit(data, labels, epochs=1, batch_size=32, verbose=0) +def check_update_state_signature(metric_class): + update_state_signature = inspect.signature(metric_class.update_state) + for expected_parameter in ["y_true", "y_pred", "sample_weight"]: + if expected_parameter not in update_state_signature.parameters.keys(): + raise ValueError( + "Class {} is missing the parameter {} in the `update_state` " + "method. If the method doesn't use this argument, declare " + "it anyway and raise a UserWarning if it is " + "not None.".format(metric_class.__name__, expected_parameter)) if __name__ == "__main__": - unittest + unittest.main() diff --git a/tensorflow_addons/metrics/multilabel_confusion_matrix.py b/tensorflow_addons/metrics/multilabel_confusion_matrix.py index 357327b985..5694c52681 100644 --- a/tensorflow_addons/metrics/multilabel_confusion_matrix.py +++ b/tensorflow_addons/metrics/multilabel_confusion_matrix.py @@ -14,6 +14,8 @@ # ============================================================================== """Implements Multi-label confusion matrix scores.""" +import warnings + import tensorflow as tf from tensorflow.keras.metrics import Metric import numpy as np @@ -104,7 +106,13 @@ def __init__( dtype=self.dtype, ) - def update_state(self, y_true, y_pred): + def update_state(self, y_true, y_pred, sample_weight=None): + if sample_weight is not None: + warnings.warn( + "`sample_weight` is not None. Be aware that MultiLabelConfusionMatrix " + "does not take `sample_weight` into account when computing the metric " + "value.") + y_true = tf.cast(y_true, tf.int32) y_pred = tf.cast(y_pred, tf.int32) # true positive diff --git a/tensorflow_addons/metrics/r_square.py b/tensorflow_addons/metrics/r_square.py index 161d8e6e4d..6e5c99e8c0 100644 --- a/tensorflow_addons/metrics/r_square.py +++ b/tensorflow_addons/metrics/r_square.py @@ -14,6 +14,8 @@ # ============================================================================== """Implements R^2 scores.""" +import warnings + import tensorflow as tf from tensorflow.keras.metrics import Metric @@ -53,7 +55,12 @@ def __init__( self.res = self.add_weight("residual", initializer="zeros") self.count = self.add_weight("count", initializer="zeros") - def update_state(self, y_true, y_pred): + def update_state(self, y_true, y_pred, sample_weight=None): + if sample_weight is not None: + warnings.warn( + "`sample_weight` is not None. Be aware that RSquare " + "does not take `sample_weight` into account when computing the metric " + "value.") y_true = tf.convert_to_tensor(y_true, tf.float32) y_pred = tf.convert_to_tensor(y_pred, tf.float32) self.squared_sum.assign_add(tf.reduce_sum(y_true ** 2)) From b1196563b2e883487c11d88d3c76a107edad075a Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Fri, 6 Mar 2020 12:15:08 +0100 Subject: [PATCH 3/3] Black. --- tensorflow_addons/metrics/metrics_test.py | 3 ++- tensorflow_addons/metrics/multilabel_confusion_matrix.py | 3 ++- tensorflow_addons/metrics/r_square.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tensorflow_addons/metrics/metrics_test.py b/tensorflow_addons/metrics/metrics_test.py index 407297d551..1fe25f56e6 100644 --- a/tensorflow_addons/metrics/metrics_test.py +++ b/tensorflow_addons/metrics/metrics_test.py @@ -34,7 +34,8 @@ def check_update_state_signature(metric_class): "Class {} is missing the parameter {} in the `update_state` " "method. If the method doesn't use this argument, declare " "it anyway and raise a UserWarning if it is " - "not None.".format(metric_class.__name__, expected_parameter)) + "not None.".format(metric_class.__name__, expected_parameter) + ) if __name__ == "__main__": diff --git a/tensorflow_addons/metrics/multilabel_confusion_matrix.py b/tensorflow_addons/metrics/multilabel_confusion_matrix.py index 5694c52681..c84c01e067 100644 --- a/tensorflow_addons/metrics/multilabel_confusion_matrix.py +++ b/tensorflow_addons/metrics/multilabel_confusion_matrix.py @@ -111,7 +111,8 @@ def update_state(self, y_true, y_pred, sample_weight=None): warnings.warn( "`sample_weight` is not None. Be aware that MultiLabelConfusionMatrix " "does not take `sample_weight` into account when computing the metric " - "value.") + "value." + ) y_true = tf.cast(y_true, tf.int32) y_pred = tf.cast(y_pred, tf.int32) diff --git a/tensorflow_addons/metrics/r_square.py b/tensorflow_addons/metrics/r_square.py index 6e5c99e8c0..1cd885f077 100644 --- a/tensorflow_addons/metrics/r_square.py +++ b/tensorflow_addons/metrics/r_square.py @@ -60,7 +60,8 @@ def update_state(self, y_true, y_pred, sample_weight=None): warnings.warn( "`sample_weight` is not None. Be aware that RSquare " "does not take `sample_weight` into account when computing the metric " - "value.") + "value." + ) y_true = tf.convert_to_tensor(y_true, tf.float32) y_pred = tf.convert_to_tensor(y_pred, tf.float32) self.squared_sum.assign_add(tf.reduce_sum(y_true ** 2))