Skip to content

Commit 9ed027d

Browse files
autoihseanpmorgan
authored andcommitted
avoid substraction ops to maintain precision (#557)
* avoid substraction ops to preserve the precision
1 parent 9485cf5 commit 9ed027d

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

tensorflow_addons/metrics/multilabel_confusion_matrix.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,16 +100,18 @@ def __init__(self,
100100
def update_state(self, y_true, y_pred):
101101
y_true = tf.cast(y_true, tf.int32)
102102
y_pred = tf.cast(y_pred, tf.int32)
103-
103+
# true positive
104104
true_positive = tf.math.count_nonzero(y_true * y_pred, 0)
105105
# predictions sum
106106
pred_sum = tf.math.count_nonzero(y_pred, 0)
107107
# true labels sum
108108
true_sum = tf.math.count_nonzero(y_true, 0)
109109
false_positive = pred_sum - true_positive
110110
false_negative = true_sum - true_positive
111-
true_negative = y_true.get_shape(
112-
)[0] - true_positive - false_positive - false_negative
111+
y_true_negative = tf.math.not_equal(y_true, 1)
112+
y_pred_negative = tf.math.not_equal(y_pred, 1)
113+
true_negative = tf.math.count_nonzero(
114+
tf.math.logical_and(y_true_negative, y_pred_negative), axis=0)
113115

114116
# true positive state update
115117
self.true_positives.assign_add(tf.cast(true_positive, self.dtype))

0 commit comments

Comments
 (0)