Skip to content

Commit 6779473

Browse files
Removed run_all_in_graph_and_eager_mode in sparsemax. (#1689)
* Removed run_all_in_graph_and_eager_mode in sparsemax. * Removed old code. * Removed useless astype.
1 parent dc2bdfb commit 6779473

File tree

1 file changed

+101
-95
lines changed

1 file changed

+101
-95
lines changed

tensorflow_addons/activations/tests/sparsemax_test.py

+101-95
Original file line numberDiff line numberDiff line change
@@ -79,140 +79,146 @@ def test_sparsemax_against_numpy_low_rank(dtype):
7979
assert np_sparsemax.shape == tf_sparsemax_out.shape
8080

8181

82-
@test_utils.run_all_with_types(["float32", "float64"])
83-
@test_utils.run_all_in_graph_and_eager_modes
84-
class SparsemaxTest(tf.test.TestCase):
85-
def _tf_sparsemax(self, z, dtype, **kwargs):
86-
tf_sparsemax_op = sparsemax(z.astype(dtype), **kwargs)
87-
tf_sparsemax_out = self.evaluate(tf_sparsemax_op)
88-
89-
return tf_sparsemax_op, tf_sparsemax_out
82+
@pytest.mark.parametrize("dtype", ["float32", "float64"])
83+
def test_sparsemax_against_numpy(dtype):
84+
"""check sparsemax kernel against numpy."""
85+
random = np.random.RandomState(1)
9086

91-
def test_sparsemax_against_numpy(self, dtype=None):
92-
"""check sparsemax kernel against numpy."""
93-
random = np.random.RandomState(1)
87+
z = random.uniform(low=-3, high=3, size=(test_obs, 10))
9488

95-
z = random.uniform(low=-3, high=3, size=(test_obs, 10))
89+
tf_sparsemax_out = sparsemax(z.astype(dtype))
90+
np_sparsemax = _np_sparsemax(z).astype(dtype)
9691

97-
tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype)
98-
np_sparsemax = _np_sparsemax(z).astype(dtype)
92+
test_utils.assert_allclose_according_to_type(np_sparsemax, tf_sparsemax_out)
9993

100-
self.assertAllCloseAccordingToType(np_sparsemax, tf_sparsemax_out)
101-
self.assertShapeEqual(np_sparsemax, tf_sparsemax_op)
10294

103-
def test_sparsemax_against_numpy_high_rank(self, dtype=None):
104-
"""check sparsemax kernel against numpy."""
105-
random = np.random.RandomState(1)
95+
@pytest.mark.parametrize("dtype", ["float32", "float64"])
96+
def test_sparsemax_against_numpy_high_rank(dtype):
97+
"""check sparsemax kernel against numpy."""
98+
random = np.random.RandomState(1)
10699

107-
z = random.uniform(low=-3, high=3, size=(test_obs, test_obs, 10))
100+
z = random.uniform(low=-3, high=3, size=(test_obs, test_obs, 10))
108101

109-
tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype)
110-
np_sparsemax = np.reshape(
111-
_np_sparsemax(np.reshape(z, [test_obs * test_obs, 10])),
112-
[test_obs, test_obs, 10],
113-
).astype(dtype)
102+
tf_sparsemax_out = sparsemax(z.astype(dtype))
103+
np_sparsemax = np.reshape(
104+
_np_sparsemax(np.reshape(z, [test_obs * test_obs, 10])),
105+
[test_obs, test_obs, 10],
106+
).astype(dtype)
114107

115-
self.assertAllCloseAccordingToType(np_sparsemax, tf_sparsemax_out)
116-
self.assertShapeEqual(np_sparsemax, tf_sparsemax_op)
108+
test_utils.assert_allclose_according_to_type(np_sparsemax, tf_sparsemax_out)
117109

118-
def test_sparsemax_of_nan(self, dtype=None):
119-
"""check sparsemax transfers nan."""
120-
z_nan = np.asarray(
121-
[[0, np.nan, 0], [0, np.nan, np.nan], [np.nan, np.nan, np.nan],]
122-
).astype(dtype)
123110

124-
_, tf_sparsemax_nan = self._tf_sparsemax(z_nan, dtype)
125-
self.assertAllEqual(
111+
@pytest.mark.parametrize("dtype", ["float32", "float64"])
112+
def test_sparsemax_of_nan(dtype):
113+
"""check sparsemax transfers nan."""
114+
z_nan = np.asarray(
115+
[[0, np.nan, 0], [0, np.nan, np.nan], [np.nan, np.nan, np.nan],]
116+
).astype(dtype)
117+
118+
tf_sparsemax_nan = sparsemax(z_nan)
119+
np.testing.assert_equal(
120+
np.array(
126121
[
127122
[np.nan, np.nan, np.nan],
128123
[np.nan, np.nan, np.nan],
129124
[np.nan, np.nan, np.nan],
130-
],
131-
tf_sparsemax_nan,
132-
)
125+
]
126+
),
127+
tf_sparsemax_nan,
128+
)
133129

134-
def test_sparsemax_of_inf(self, dtype=None):
135-
"""check sparsemax is infinity safe."""
136-
z_neg = np.asarray(
137-
[[0, -np.inf, 0], [0, -np.inf, -np.inf], [-np.inf, -np.inf, -np.inf],]
138-
).astype(dtype)
139-
z_pos = np.asarray(
140-
[[0, np.inf, 0], [0, np.inf, np.inf], [np.inf, np.inf, np.inf]]
141-
).astype(dtype)
142-
z_mix = np.asarray(
143-
[[0, np.inf, 0], [0, np.inf, -np.inf], [-np.inf, np.inf, -np.inf]]
144-
).astype(dtype)
145-
146-
_, tf_sparsemax_neg = self._tf_sparsemax(z_neg, dtype)
147-
self.assertAllEqual(
148-
[[0.5, 0, 0.5], [1, 0, 0], [np.nan, np.nan, np.nan]], tf_sparsemax_neg
149-
)
150130

151-
_, tf_sparsemax_pos = self._tf_sparsemax(z_pos, dtype)
152-
self.assertAllEqual(
131+
@pytest.mark.parametrize("dtype", ["float32", "float64"])
132+
def test_sparsemax_of_inf(dtype):
133+
"""check sparsemax is infinity safe."""
134+
z_neg = np.asarray(
135+
[[0, -np.inf, 0], [0, -np.inf, -np.inf], [-np.inf, -np.inf, -np.inf],]
136+
).astype(dtype)
137+
z_pos = np.asarray(
138+
[[0, np.inf, 0], [0, np.inf, np.inf], [np.inf, np.inf, np.inf]]
139+
).astype(dtype)
140+
z_mix = np.asarray(
141+
[[0, np.inf, 0], [0, np.inf, -np.inf], [-np.inf, np.inf, -np.inf]]
142+
).astype(dtype)
143+
144+
tf_sparsemax_neg = sparsemax(z_neg)
145+
np.testing.assert_equal(
146+
np.array([[0.5, 0, 0.5], [1, 0, 0], [np.nan, np.nan, np.nan]]), tf_sparsemax_neg
147+
)
148+
149+
tf_sparsemax_pos = sparsemax(z_pos)
150+
np.testing.assert_equal(
151+
np.array(
153152
[
154153
[np.nan, np.nan, np.nan],
155154
[np.nan, np.nan, np.nan],
156155
[np.nan, np.nan, np.nan],
157-
],
158-
tf_sparsemax_pos,
159-
)
156+
]
157+
),
158+
tf_sparsemax_pos,
159+
)
160160

161-
_, tf_sparsemax_mix = self._tf_sparsemax(z_mix, dtype)
162-
self.assertAllEqual(
161+
tf_sparsemax_mix = sparsemax(z_mix)
162+
np.testing.assert_equal(
163+
np.array(
163164
[
164165
[np.nan, np.nan, np.nan],
165166
[np.nan, np.nan, np.nan],
166167
[np.nan, np.nan, np.nan],
167-
],
168-
tf_sparsemax_mix,
169-
)
168+
]
169+
),
170+
tf_sparsemax_mix,
171+
)
172+
173+
174+
@pytest.mark.parametrize("dtype", ["float32", "float64"])
175+
def test_sparsemax_of_zero(dtype):
176+
"""check sparsemax proposition 1, part 1."""
177+
z = np.zeros((1, 10))
170178

171-
def test_sparsemax_of_zero(self, dtype=None):
172-
"""check sparsemax proposition 1, part 1."""
173-
z = np.zeros((1, 10))
179+
tf_sparsemax_out = sparsemax(z.astype(dtype))
180+
np_sparsemax = np.ones_like(z, dtype=dtype) / z.size
174181

175-
tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype)
176-
np_sparsemax = np.ones_like(z, dtype=dtype) / z.size
182+
test_utils.assert_allclose_according_to_type(np_sparsemax, tf_sparsemax_out)
177183

178-
self.assertAllCloseAccordingToType(np_sparsemax, tf_sparsemax_out)
179-
self.assertShapeEqual(np_sparsemax, tf_sparsemax_op)
180184

181-
def test_sparsemax_of_to_inf(self, dtype=None):
182-
"""check sparsemax proposition 1, part 2."""
183-
random = np.random.RandomState(4)
185+
@pytest.mark.parametrize("dtype", ["float32", "float64"])
186+
def test_sparsemax_of_to_inf(dtype):
187+
"""check sparsemax proposition 1, part 2."""
188+
random = np.random.RandomState(4)
184189

185-
z = random.uniform(low=-3, high=3, size=(test_obs, 10))
190+
z = random.uniform(low=-3, high=3, size=(test_obs, 10))
186191

187-
# assume |A(z)| = 1, as z is continues random
188-
z_sort_arg = np.argsort(z, axis=1)[:, ::-1]
189-
z_sort = np.sort(z, axis=-1)[:, ::-1]
190-
gamma_z = z_sort[:, 0] - z_sort[:, 1]
191-
epsilon = (0.99 * gamma_z * 1).reshape(-1, 1)
192+
# assume |A(z)| = 1, as z is continues random
193+
z_sort_arg = np.argsort(z, axis=1)[:, ::-1]
194+
z_sort = np.sort(z, axis=-1)[:, ::-1]
195+
gamma_z = z_sort[:, 0] - z_sort[:, 1]
196+
epsilon = (0.99 * gamma_z * 1).reshape(-1, 1)
192197

193-
# construct the expected 1_A(z) array
194-
p_expected = np.zeros((test_obs, 10), dtype=dtype)
195-
p_expected[np.arange(0, test_obs), z_sort_arg[:, 0]] = 1
198+
# construct the expected 1_A(z) array
199+
p_expected = np.zeros((test_obs, 10), dtype=dtype)
200+
p_expected[np.arange(0, test_obs), z_sort_arg[:, 0]] = 1
196201

197-
tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax((1 / epsilon) * z, dtype)
202+
tf_sparsemax_out = sparsemax(((1 / epsilon) * z).astype(dtype))
198203

199-
self.assertAllCloseAccordingToType(p_expected, tf_sparsemax_out)
200-
self.assertShapeEqual(p_expected, tf_sparsemax_op)
204+
test_utils.assert_allclose_according_to_type(p_expected, tf_sparsemax_out)
201205

202-
def test_constant_add(self, dtype=None):
203-
"""check sparsemax proposition 2."""
204-
random = np.random.RandomState(5)
205206

206-
z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)
207-
c = random.uniform(low=-3, high=3, size=(test_obs, 1)).astype(dtype)
207+
@pytest.mark.parametrize("dtype", ["float32", "float64"])
208+
def test_constant_add(dtype):
209+
"""check sparsemax proposition 2."""
210+
random = np.random.RandomState(5)
208211

209-
_, tf_sparsemax_zpc = self._tf_sparsemax(z + c, dtype)
212+
z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)
213+
c = random.uniform(low=-3, high=3, size=(test_obs, 1)).astype(dtype)
210214

211-
_, tf_sparsemax_z = self._tf_sparsemax(z, dtype)
215+
tf_sparsemax_zpc = sparsemax((z + c))
212216

213-
self.assertAllCloseAccordingToType(
214-
tf_sparsemax_zpc, tf_sparsemax_z, half_atol=5e-3
215-
)
217+
tf_sparsemax_z = sparsemax(z)
218+
219+
test_utils.assert_allclose_according_to_type(
220+
tf_sparsemax_zpc, tf_sparsemax_z, half_atol=5e-3
221+
)
216222

217223

218224
@pytest.mark.parametrize("dtype", ["float32", "float64"])

0 commit comments

Comments
 (0)