Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions tensorflow_addons/metrics/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,15 @@ py_test(
":metrics",
],
)

py_test(
name = "metrics_test",
size = "small",
srcs = [
"metrics_test.py",
],
main = "metrics_test.py",
deps = [
":metrics",
],
)
42 changes: 42 additions & 0 deletions tensorflow_addons/metrics/metrics_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import unittest
import inspect

from tensorflow.keras.metrics import Metric
from tensorflow_addons import metrics


class MetricsTests(unittest.TestCase):
def test_update_state_signature(self):
for name, obj in inspect.getmembers(metrics):
if inspect.isclass(obj) and issubclass(obj, Metric):
check_update_state_signature(obj)


def check_update_state_signature(metric_class):
update_state_signature = inspect.signature(metric_class.update_state)
for expected_parameter in ["y_true", "y_pred", "sample_weight"]:
if expected_parameter not in update_state_signature.parameters.keys():
raise ValueError(
"Class {} is missing the parameter {} in the `update_state` "
"method. If the method doesn't use this argument, declare "
"it anyway and raise a UserWarning if it is "
"not None.".format(metric_class.__name__, expected_parameter)
)


if __name__ == "__main__":
unittest.main()
11 changes: 10 additions & 1 deletion tensorflow_addons/metrics/multilabel_confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================
"""Implements Multi-label confusion matrix scores."""

import warnings

import tensorflow as tf
from tensorflow.keras.metrics import Metric
import numpy as np
Expand Down Expand Up @@ -104,7 +106,14 @@ def __init__(
dtype=self.dtype,
)

def update_state(self, y_true, y_pred):
def update_state(self, y_true, y_pred, sample_weight=None):
if sample_weight is not None:
warnings.warn(
"`sample_weight` is not None. Be aware that MultiLabelConfusionMatrix "
"does not take `sample_weight` into account when computing the metric "
"value."
)

y_true = tf.cast(y_true, tf.int32)
y_pred = tf.cast(y_pred, tf.int32)
# true positive
Expand Down
10 changes: 9 additions & 1 deletion tensorflow_addons/metrics/r_square.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================
"""Implements R^2 scores."""

import warnings

import tensorflow as tf
from tensorflow.keras.metrics import Metric

Expand Down Expand Up @@ -53,7 +55,13 @@ def __init__(
self.res = self.add_weight("residual", initializer="zeros")
self.count = self.add_weight("count", initializer="zeros")

def update_state(self, y_true, y_pred):
def update_state(self, y_true, y_pred, sample_weight=None):
if sample_weight is not None:
warnings.warn(
"`sample_weight` is not None. Be aware that RSquare "
"does not take `sample_weight` into account when computing the metric "
"value."
)
y_true = tf.convert_to_tensor(y_true, tf.float32)
y_pred = tf.convert_to_tensor(y_pred, tf.float32)
self.squared_sum.assign_add(tf.reduce_sum(y_true ** 2))
Expand Down