diff --git a/tensorflow_addons/activations/sparsemax_test.py b/tensorflow_addons/activations/sparsemax_test.py index 7a41a275a6..0715506f96 100644 --- a/tensorflow_addons/activations/sparsemax_test.py +++ b/tensorflow_addons/activations/sparsemax_test.py @@ -210,26 +210,6 @@ def test_constant_add(self, dtype=None): tf_sparsemax_zpc, tf_sparsemax_z, half_atol=5e-3 ) - def test_permutation(self, dtype=None): - """check sparsemax proposition 3.""" - random = np.random.RandomState(6) - - z = random.uniform(low=-3, high=3, size=(test_obs, 10)) - _, p = self._tf_sparsemax(z, dtype) - - for i in range(test_obs): - per = random.permutation(10) - - tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax( - z[i, per].reshape(1, -1), dtype - ) - p_expected = p[i, per].reshape(1, -1) - - self.assertAllCloseAccordingToType( - p_expected, tf_sparsemax_out, half_atol=5e-3 - ) - self.assertShapeEqual(p_expected, tf_sparsemax_op) - @pytest.mark.parametrize("dtype", ["float32", "float64"]) def test_two_dimentional(dtype): @@ -268,6 +248,26 @@ def test_diffrence(dtype): assert 0 <= p[val, j] - p[val, i] <= z[val, j] - z[val, i] + etol +@pytest.mark.parametrize("dtype", ["float32", "float64"]) +def test_permutation(dtype): + """check sparsemax proposition 3.""" + random = np.random.RandomState(6) + + z = random.uniform(low=-3, high=3, size=(test_obs, 10)) + p = sparsemax(z.astype(dtype)).numpy() + + for i in range(test_obs): + per = random.permutation(10) + + tf_sparsemax_out = sparsemax(z[i, per].reshape(1, -1).astype(dtype)) + p_expected = p[i, per].reshape(1, -1) + + test_utils.assert_allclose_according_to_type( + p_expected, tf_sparsemax_out, half_atol=5e-3 + ) + assert p_expected.shape == tf_sparsemax_out.shape + + @pytest.mark.parametrize("dtype", ["float32", "float64"]) def test_gradient_against_estimate(dtype): """check sparsemax Rop, against estimated Rop."""