|
16 | 16 |
|
17 | 17 | import sys
|
18 | 18 |
|
| 19 | +import numpy as np |
19 | 20 | import pytest
|
20 | 21 | import tensorflow as tf
|
21 | 22 | from tensorflow_addons.losses import contrastive
|
@@ -132,21 +133,22 @@ def test_no_reduction(self):
|
132 | 133 |
|
133 | 134 | self.assertAllClose(loss, [0.81, 0.49, 1.69, 0.49, 0.0, 0.25])
|
134 | 135 |
|
135 |
| - def test_sum_reduction(self): |
136 |
| - cl_obj = contrastive.ContrastiveLoss(reduction=tf.keras.losses.Reduction.SUM) |
137 |
| - y_true = tf.constant([0, 0, 1, 1, 0, 1], dtype=tf.dtypes.int64) |
138 |
| - y_pred = tf.constant([0.1, 0.3, 1.3, 0.7, 1.1, 0.5], dtype=tf.dtypes.float32) |
139 |
| - loss = cl_obj(y_true, y_pred) |
140 | 136 |
|
141 |
| - # Loss = y * (y`)^2 + (1 - y) * (max(m - y`, 0))^2 |
142 |
| - # = [max(1 - 0.1, 0)^2, max(1 - 0.3, 0)^2, |
143 |
| - # 1.3^2, 0.7^2, max(1 - 1.1, 0)^2, 0.5^2] |
144 |
| - # = [0.9^2, 0.7^2, 1.3^2, 0.7^2, 0^2, 0.5^2] |
145 |
| - # = [0.81, 0.49, 1.69, 0.49, 0, 0.25] |
146 |
| - # Reduced loss = 0.81 + 0.49 + 1.69 + 0.49 + 0 + 0.25 |
147 |
| - # = 3.73 |
| 137 | +def test_sum_reduction(): |
| 138 | + cl_obj = contrastive.ContrastiveLoss(reduction=tf.keras.losses.Reduction.SUM) |
| 139 | + y_true = tf.constant([0, 0, 1, 1, 0, 1], dtype=tf.dtypes.int64) |
| 140 | + y_pred = tf.constant([0.1, 0.3, 1.3, 0.7, 1.1, 0.5], dtype=tf.dtypes.float32) |
| 141 | + loss = cl_obj(y_true, y_pred) |
148 | 142 |
|
149 |
| - self.assertAllClose(loss, 3.73) |
| 143 | + # Loss = y * (y`)^2 + (1 - y) * (max(m - y`, 0))^2 |
| 144 | + # = [max(1 - 0.1, 0)^2, max(1 - 0.3, 0)^2, |
| 145 | + # 1.3^2, 0.7^2, max(1 - 1.1, 0)^2, 0.5^2] |
| 146 | + # = [0.9^2, 0.7^2, 1.3^2, 0.7^2, 0^2, 0.5^2] |
| 147 | + # = [0.81, 0.49, 1.69, 0.49, 0, 0.25] |
| 148 | + # Reduced loss = 0.81 + 0.49 + 1.69 + 0.49 + 0 + 0.25 |
| 149 | + # = 3.73 |
| 150 | + |
| 151 | + np.testing.assert_allclose(loss, 3.73) |
150 | 152 |
|
151 | 153 |
|
152 | 154 | if __name__ == "__main__":
|
|
0 commit comments