@@ -79,140 +79,146 @@ def test_sparsemax_against_numpy_low_rank(dtype):
79
79
assert np_sparsemax .shape == tf_sparsemax_out .shape
80
80
81
81
82
- @test_utils .run_all_with_types (["float32" , "float64" ])
83
- @test_utils .run_all_in_graph_and_eager_modes
84
- class SparsemaxTest (tf .test .TestCase ):
85
- def _tf_sparsemax (self , z , dtype , ** kwargs ):
86
- tf_sparsemax_op = sparsemax (z .astype (dtype ), ** kwargs )
87
- tf_sparsemax_out = self .evaluate (tf_sparsemax_op )
88
-
89
- return tf_sparsemax_op , tf_sparsemax_out
82
+ @pytest .mark .parametrize ("dtype" , ["float32" , "float64" ])
83
+ def test_sparsemax_against_numpy (dtype ):
84
+ """check sparsemax kernel against numpy."""
85
+ random = np .random .RandomState (1 )
90
86
91
- def test_sparsemax_against_numpy (self , dtype = None ):
92
- """check sparsemax kernel against numpy."""
93
- random = np .random .RandomState (1 )
87
+ z = random .uniform (low = - 3 , high = 3 , size = (test_obs , 10 ))
94
88
95
- z = random .uniform (low = - 3 , high = 3 , size = (test_obs , 10 ))
89
+ tf_sparsemax_out = sparsemax (z .astype (dtype ))
90
+ np_sparsemax = _np_sparsemax (z ).astype (dtype )
96
91
97
- tf_sparsemax_op , tf_sparsemax_out = self ._tf_sparsemax (z , dtype )
98
- np_sparsemax = _np_sparsemax (z ).astype (dtype )
92
+ test_utils .assert_allclose_according_to_type (np_sparsemax , tf_sparsemax_out )
99
93
100
- self .assertAllCloseAccordingToType (np_sparsemax , tf_sparsemax_out )
101
- self .assertShapeEqual (np_sparsemax , tf_sparsemax_op )
102
94
103
- def test_sparsemax_against_numpy_high_rank (self , dtype = None ):
104
- """check sparsemax kernel against numpy."""
105
- random = np .random .RandomState (1 )
95
+ @pytest .mark .parametrize ("dtype" , ["float32" , "float64" ])
96
+ def test_sparsemax_against_numpy_high_rank (dtype ):
97
+ """check sparsemax kernel against numpy."""
98
+ random = np .random .RandomState (1 )
106
99
107
- z = random .uniform (low = - 3 , high = 3 , size = (test_obs , test_obs , 10 ))
100
+ z = random .uniform (low = - 3 , high = 3 , size = (test_obs , test_obs , 10 ))
108
101
109
- tf_sparsemax_op , tf_sparsemax_out = self . _tf_sparsemax ( z , dtype )
110
- np_sparsemax = np .reshape (
111
- _np_sparsemax (np .reshape (z , [test_obs * test_obs , 10 ])),
112
- [test_obs , test_obs , 10 ],
113
- ).astype (dtype )
102
+ tf_sparsemax_out = sparsemax ( z . astype ( dtype ) )
103
+ np_sparsemax = np .reshape (
104
+ _np_sparsemax (np .reshape (z , [test_obs * test_obs , 10 ])),
105
+ [test_obs , test_obs , 10 ],
106
+ ).astype (dtype )
114
107
115
- self .assertAllCloseAccordingToType (np_sparsemax , tf_sparsemax_out )
116
- self .assertShapeEqual (np_sparsemax , tf_sparsemax_op )
108
+ test_utils .assert_allclose_according_to_type (np_sparsemax , tf_sparsemax_out )
117
109
118
- def test_sparsemax_of_nan (self , dtype = None ):
119
- """check sparsemax transfers nan."""
120
- z_nan = np .asarray (
121
- [[0 , np .nan , 0 ], [0 , np .nan , np .nan ], [np .nan , np .nan , np .nan ],]
122
- ).astype (dtype )
123
110
124
- _ , tf_sparsemax_nan = self ._tf_sparsemax (z_nan , dtype )
125
- self .assertAllEqual (
111
+ @pytest .mark .parametrize ("dtype" , ["float32" , "float64" ])
112
+ def test_sparsemax_of_nan (dtype ):
113
+ """check sparsemax transfers nan."""
114
+ z_nan = np .asarray (
115
+ [[0 , np .nan , 0 ], [0 , np .nan , np .nan ], [np .nan , np .nan , np .nan ],]
116
+ ).astype (dtype )
117
+
118
+ tf_sparsemax_nan = sparsemax (z_nan )
119
+ np .testing .assert_equal (
120
+ np .array (
126
121
[
127
122
[np .nan , np .nan , np .nan ],
128
123
[np .nan , np .nan , np .nan ],
129
124
[np .nan , np .nan , np .nan ],
130
- ],
131
- tf_sparsemax_nan ,
132
- )
125
+ ]
126
+ ),
127
+ tf_sparsemax_nan ,
128
+ )
133
129
134
- def test_sparsemax_of_inf (self , dtype = None ):
135
- """check sparsemax is infinity safe."""
136
- z_neg = np .asarray (
137
- [[0 , - np .inf , 0 ], [0 , - np .inf , - np .inf ], [- np .inf , - np .inf , - np .inf ],]
138
- ).astype (dtype )
139
- z_pos = np .asarray (
140
- [[0 , np .inf , 0 ], [0 , np .inf , np .inf ], [np .inf , np .inf , np .inf ]]
141
- ).astype (dtype )
142
- z_mix = np .asarray (
143
- [[0 , np .inf , 0 ], [0 , np .inf , - np .inf ], [- np .inf , np .inf , - np .inf ]]
144
- ).astype (dtype )
145
-
146
- _ , tf_sparsemax_neg = self ._tf_sparsemax (z_neg , dtype )
147
- self .assertAllEqual (
148
- [[0.5 , 0 , 0.5 ], [1 , 0 , 0 ], [np .nan , np .nan , np .nan ]], tf_sparsemax_neg
149
- )
150
130
151
- _ , tf_sparsemax_pos = self ._tf_sparsemax (z_pos , dtype )
152
- self .assertAllEqual (
131
+ @pytest .mark .parametrize ("dtype" , ["float32" , "float64" ])
132
+ def test_sparsemax_of_inf (dtype ):
133
+ """check sparsemax is infinity safe."""
134
+ z_neg = np .asarray (
135
+ [[0 , - np .inf , 0 ], [0 , - np .inf , - np .inf ], [- np .inf , - np .inf , - np .inf ],]
136
+ ).astype (dtype )
137
+ z_pos = np .asarray (
138
+ [[0 , np .inf , 0 ], [0 , np .inf , np .inf ], [np .inf , np .inf , np .inf ]]
139
+ ).astype (dtype )
140
+ z_mix = np .asarray (
141
+ [[0 , np .inf , 0 ], [0 , np .inf , - np .inf ], [- np .inf , np .inf , - np .inf ]]
142
+ ).astype (dtype )
143
+
144
+ tf_sparsemax_neg = sparsemax (z_neg )
145
+ np .testing .assert_equal (
146
+ np .array ([[0.5 , 0 , 0.5 ], [1 , 0 , 0 ], [np .nan , np .nan , np .nan ]]), tf_sparsemax_neg
147
+ )
148
+
149
+ tf_sparsemax_pos = sparsemax (z_pos )
150
+ np .testing .assert_equal (
151
+ np .array (
153
152
[
154
153
[np .nan , np .nan , np .nan ],
155
154
[np .nan , np .nan , np .nan ],
156
155
[np .nan , np .nan , np .nan ],
157
- ],
158
- tf_sparsemax_pos ,
159
- )
156
+ ]
157
+ ),
158
+ tf_sparsemax_pos ,
159
+ )
160
160
161
- _ , tf_sparsemax_mix = self ._tf_sparsemax (z_mix , dtype )
162
- self .assertAllEqual (
161
+ tf_sparsemax_mix = sparsemax (z_mix )
162
+ np .testing .assert_equal (
163
+ np .array (
163
164
[
164
165
[np .nan , np .nan , np .nan ],
165
166
[np .nan , np .nan , np .nan ],
166
167
[np .nan , np .nan , np .nan ],
167
- ],
168
- tf_sparsemax_mix ,
169
- )
168
+ ]
169
+ ),
170
+ tf_sparsemax_mix ,
171
+ )
172
+
173
+
174
+ @pytest .mark .parametrize ("dtype" , ["float32" , "float64" ])
175
+ def test_sparsemax_of_zero (dtype ):
176
+ """check sparsemax proposition 1, part 1."""
177
+ z = np .zeros ((1 , 10 ))
170
178
171
- def test_sparsemax_of_zero (self , dtype = None ):
172
- """check sparsemax proposition 1, part 1."""
173
- z = np .zeros ((1 , 10 ))
179
+ tf_sparsemax_out = sparsemax (z .astype (dtype ))
180
+ np_sparsemax = np .ones_like (z , dtype = dtype ) / z .size
174
181
175
- tf_sparsemax_op , tf_sparsemax_out = self ._tf_sparsemax (z , dtype )
176
- np_sparsemax = np .ones_like (z , dtype = dtype ) / z .size
182
+ test_utils .assert_allclose_according_to_type (np_sparsemax , tf_sparsemax_out )
177
183
178
- self .assertAllCloseAccordingToType (np_sparsemax , tf_sparsemax_out )
179
- self .assertShapeEqual (np_sparsemax , tf_sparsemax_op )
180
184
181
- def test_sparsemax_of_to_inf (self , dtype = None ):
182
- """check sparsemax proposition 1, part 2."""
183
- random = np .random .RandomState (4 )
185
+ @pytest .mark .parametrize ("dtype" , ["float32" , "float64" ])
186
+ def test_sparsemax_of_to_inf (dtype ):
187
+ """check sparsemax proposition 1, part 2."""
188
+ random = np .random .RandomState (4 )
184
189
185
- z = random .uniform (low = - 3 , high = 3 , size = (test_obs , 10 ))
190
+ z = random .uniform (low = - 3 , high = 3 , size = (test_obs , 10 ))
186
191
187
- # assume |A(z)| = 1, as z is continues random
188
- z_sort_arg = np .argsort (z , axis = 1 )[:, ::- 1 ]
189
- z_sort = np .sort (z , axis = - 1 )[:, ::- 1 ]
190
- gamma_z = z_sort [:, 0 ] - z_sort [:, 1 ]
191
- epsilon = (0.99 * gamma_z * 1 ).reshape (- 1 , 1 )
192
+ # assume |A(z)| = 1, as z is continues random
193
+ z_sort_arg = np .argsort (z , axis = 1 )[:, ::- 1 ]
194
+ z_sort = np .sort (z , axis = - 1 )[:, ::- 1 ]
195
+ gamma_z = z_sort [:, 0 ] - z_sort [:, 1 ]
196
+ epsilon = (0.99 * gamma_z * 1 ).reshape (- 1 , 1 )
192
197
193
- # construct the expected 1_A(z) array
194
- p_expected = np .zeros ((test_obs , 10 ), dtype = dtype )
195
- p_expected [np .arange (0 , test_obs ), z_sort_arg [:, 0 ]] = 1
198
+ # construct the expected 1_A(z) array
199
+ p_expected = np .zeros ((test_obs , 10 ), dtype = dtype )
200
+ p_expected [np .arange (0 , test_obs ), z_sort_arg [:, 0 ]] = 1
196
201
197
- tf_sparsemax_op , tf_sparsemax_out = self . _tf_sparsemax (( 1 / epsilon ) * z , dtype )
202
+ tf_sparsemax_out = sparsemax ((( 1 / epsilon ) * z ). astype ( dtype ) )
198
203
199
- self .assertAllCloseAccordingToType (p_expected , tf_sparsemax_out )
200
- self .assertShapeEqual (p_expected , tf_sparsemax_op )
204
+ test_utils .assert_allclose_according_to_type (p_expected , tf_sparsemax_out )
201
205
202
- def test_constant_add (self , dtype = None ):
203
- """check sparsemax proposition 2."""
204
- random = np .random .RandomState (5 )
205
206
206
- z = random .uniform (low = - 3 , high = 3 , size = (test_obs , 10 )).astype (dtype )
207
- c = random .uniform (low = - 3 , high = 3 , size = (test_obs , 1 )).astype (dtype )
207
+ @pytest .mark .parametrize ("dtype" , ["float32" , "float64" ])
208
+ def test_constant_add (dtype ):
209
+ """check sparsemax proposition 2."""
210
+ random = np .random .RandomState (5 )
208
211
209
- _ , tf_sparsemax_zpc = self ._tf_sparsemax (z + c , dtype )
212
+ z = random .uniform (low = - 3 , high = 3 , size = (test_obs , 10 )).astype (dtype )
213
+ c = random .uniform (low = - 3 , high = 3 , size = (test_obs , 1 )).astype (dtype )
210
214
211
- _ , tf_sparsemax_z = self . _tf_sparsemax ( z , dtype )
215
+ tf_sparsemax_zpc = sparsemax (( z + c ) )
212
216
213
- self .assertAllCloseAccordingToType (
214
- tf_sparsemax_zpc , tf_sparsemax_z , half_atol = 5e-3
215
- )
217
+ tf_sparsemax_z = sparsemax (z )
218
+
219
+ test_utils .assert_allclose_according_to_type (
220
+ tf_sparsemax_zpc , tf_sparsemax_z , half_atol = 5e-3
221
+ )
216
222
217
223
218
224
@pytest .mark .parametrize ("dtype" , ["float32" , "float64" ])
0 commit comments