Skip to content

Commit 2cd85ae

Browse files
Moved test out of run_in_graph_and_eager_mode in sparsemax. (#1408)
1 parent 9993804 commit 2cd85ae

File tree

1 file changed

+20
-22
lines changed

1 file changed

+20
-22
lines changed

tensorflow_addons/activations/sparsemax_test.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -230,28 +230,6 @@ def test_permutation(self, dtype=None):
230230
)
231231
self.assertShapeEqual(p_expected, tf_sparsemax_op)
232232

233-
def test_diffrence(self, dtype=None):
234-
"""check sparsemax proposition 4."""
235-
random = np.random.RandomState(7)
236-
237-
z = random.uniform(low=-3, high=3, size=(test_obs, 10))
238-
_, p = self._tf_sparsemax(z, dtype)
239-
240-
etol = {"float16": 1e-2, "float32": 1e-6, "float64": 1e-9}[dtype]
241-
242-
for val in range(0, test_obs):
243-
for i in range(0, 10):
244-
for j in range(0, 10):
245-
# check condition, the obesite pair will be checked anyway
246-
if z[val, i] > z[val, j]:
247-
continue
248-
249-
self.assertTrue(
250-
0 <= p[val, j] - p[val, i] <= z[val, j] - z[val, i] + etol,
251-
"0 <= %.10f <= %.10f"
252-
% (p[val, j] - p[val, i], z[val, j] - z[val, i] + etol),
253-
)
254-
255233

256234
@pytest.mark.parametrize("dtype", ["float32", "float64"])
257235
def test_two_dimentional(dtype):
@@ -270,6 +248,26 @@ def test_two_dimentional(dtype):
270248
assert z.shape == tf_sparsemax_out.shape
271249

272250

251+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
252+
def test_diffrence(dtype):
253+
"""check sparsemax proposition 4."""
254+
random = np.random.RandomState(7)
255+
256+
z = random.uniform(low=-3, high=3, size=(test_obs, 10))
257+
p = sparsemax(z.astype(dtype)).numpy()
258+
259+
etol = {np.float32: 1e-6, np.float64: 1e-9}[dtype]
260+
261+
for val in range(0, test_obs):
262+
for i in range(0, 10):
263+
for j in range(0, 10):
264+
# check condition, the obesite pair will be checked anyway
265+
if z[val, i] > z[val, j]:
266+
continue
267+
268+
assert 0 <= p[val, j] - p[val, i] <= z[val, j] - z[val, i] + etol
269+
270+
273271
@pytest.mark.parametrize("dtype", ["float32", "float64"])
274272
def test_gradient_against_estimate(dtype):
275273
"""check sparsemax Rop, against estimated Rop."""

0 commit comments

Comments
 (0)