File tree Expand file tree Collapse file tree 1 file changed +5
-3
lines changed
tensorflow_addons/metrics Expand file tree Collapse file tree 1 file changed +5
-3
lines changed Original file line number Diff line number Diff line change @@ -100,16 +100,18 @@ def __init__(self,
100
100
def update_state (self , y_true , y_pred ):
101
101
y_true = tf .cast (y_true , tf .int32 )
102
102
y_pred = tf .cast (y_pred , tf .int32 )
103
-
103
+ # true positive
104
104
true_positive = tf .math .count_nonzero (y_true * y_pred , 0 )
105
105
# predictions sum
106
106
pred_sum = tf .math .count_nonzero (y_pred , 0 )
107
107
# true labels sum
108
108
true_sum = tf .math .count_nonzero (y_true , 0 )
109
109
false_positive = pred_sum - true_positive
110
110
false_negative = true_sum - true_positive
111
- true_negative = y_true .get_shape (
112
- )[0 ] - true_positive - false_positive - false_negative
111
+ y_true_negative = tf .math .not_equal (y_true , 1 )
112
+ y_pred_negative = tf .math .not_equal (y_pred , 1 )
113
+ true_negative = tf .math .count_nonzero (
114
+ tf .math .logical_and (y_true_negative , y_pred_negative ), axis = 0 )
113
115
114
116
# true positive state update
115
117
self .true_positives .assign_add (tf .cast (true_positive , self .dtype ))
You can’t perform that action at this time.
0 commit comments