Skip to content

Commit f6404ee

Browse files
Moved test out of run_in_graph_and_eager_mode in contrastive_test.py (#1445)
See #1328
1 parent 53f26d1 commit f6404ee

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

tensorflow_addons/losses/contrastive_test.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import sys
1818

19+
import numpy as np
1920
import pytest
2021
import tensorflow as tf
2122
from tensorflow_addons.losses import contrastive
@@ -132,21 +133,22 @@ def test_no_reduction(self):
132133

133134
self.assertAllClose(loss, [0.81, 0.49, 1.69, 0.49, 0.0, 0.25])
134135

135-
def test_sum_reduction(self):
136-
cl_obj = contrastive.ContrastiveLoss(reduction=tf.keras.losses.Reduction.SUM)
137-
y_true = tf.constant([0, 0, 1, 1, 0, 1], dtype=tf.dtypes.int64)
138-
y_pred = tf.constant([0.1, 0.3, 1.3, 0.7, 1.1, 0.5], dtype=tf.dtypes.float32)
139-
loss = cl_obj(y_true, y_pred)
140136

141-
# Loss = y * (y`)^2 + (1 - y) * (max(m - y`, 0))^2
142-
# = [max(1 - 0.1, 0)^2, max(1 - 0.3, 0)^2,
143-
# 1.3^2, 0.7^2, max(1 - 1.1, 0)^2, 0.5^2]
144-
# = [0.9^2, 0.7^2, 1.3^2, 0.7^2, 0^2, 0.5^2]
145-
# = [0.81, 0.49, 1.69, 0.49, 0, 0.25]
146-
# Reduced loss = 0.81 + 0.49 + 1.69 + 0.49 + 0 + 0.25
147-
# = 3.73
137+
def test_sum_reduction():
138+
cl_obj = contrastive.ContrastiveLoss(reduction=tf.keras.losses.Reduction.SUM)
139+
y_true = tf.constant([0, 0, 1, 1, 0, 1], dtype=tf.dtypes.int64)
140+
y_pred = tf.constant([0.1, 0.3, 1.3, 0.7, 1.1, 0.5], dtype=tf.dtypes.float32)
141+
loss = cl_obj(y_true, y_pred)
148142

149-
self.assertAllClose(loss, 3.73)
143+
# Loss = y * (y`)^2 + (1 - y) * (max(m - y`, 0))^2
144+
# = [max(1 - 0.1, 0)^2, max(1 - 0.3, 0)^2,
145+
# 1.3^2, 0.7^2, max(1 - 1.1, 0)^2, 0.5^2]
146+
# = [0.9^2, 0.7^2, 1.3^2, 0.7^2, 0^2, 0.5^2]
147+
# = [0.81, 0.49, 1.69, 0.49, 0, 0.25]
148+
# Reduced loss = 0.81 + 0.49 + 1.69 + 0.49 + 0 + 0.25
149+
# = 3.73
150+
151+
np.testing.assert_allclose(loss, 3.73)
150152

151153

152154
if __name__ == "__main__":

0 commit comments

Comments
 (0)