@@ -230,28 +230,6 @@ def test_permutation(self, dtype=None):
230
230
)
231
231
self .assertShapeEqual (p_expected , tf_sparsemax_op )
232
232
233
- def test_diffrence (self , dtype = None ):
234
- """check sparsemax proposition 4."""
235
- random = np .random .RandomState (7 )
236
-
237
- z = random .uniform (low = - 3 , high = 3 , size = (test_obs , 10 ))
238
- _ , p = self ._tf_sparsemax (z , dtype )
239
-
240
- etol = {"float16" : 1e-2 , "float32" : 1e-6 , "float64" : 1e-9 }[dtype ]
241
-
242
- for val in range (0 , test_obs ):
243
- for i in range (0 , 10 ):
244
- for j in range (0 , 10 ):
245
- # check condition, the obesite pair will be checked anyway
246
- if z [val , i ] > z [val , j ]:
247
- continue
248
-
249
- self .assertTrue (
250
- 0 <= p [val , j ] - p [val , i ] <= z [val , j ] - z [val , i ] + etol ,
251
- "0 <= %.10f <= %.10f"
252
- % (p [val , j ] - p [val , i ], z [val , j ] - z [val , i ] + etol ),
253
- )
254
-
255
233
256
234
@pytest .mark .parametrize ("dtype" , ["float32" , "float64" ])
257
235
def test_two_dimentional (dtype ):
@@ -270,6 +248,26 @@ def test_two_dimentional(dtype):
270
248
assert z .shape == tf_sparsemax_out .shape
271
249
272
250
251
+ @pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 ])
252
+ def test_diffrence (dtype ):
253
+ """check sparsemax proposition 4."""
254
+ random = np .random .RandomState (7 )
255
+
256
+ z = random .uniform (low = - 3 , high = 3 , size = (test_obs , 10 ))
257
+ p = sparsemax (z .astype (dtype )).numpy ()
258
+
259
+ etol = {np .float32 : 1e-6 , np .float64 : 1e-9 }[dtype ]
260
+
261
+ for val in range (0 , test_obs ):
262
+ for i in range (0 , 10 ):
263
+ for j in range (0 , 10 ):
264
+ # check condition, the obesite pair will be checked anyway
265
+ if z [val , i ] > z [val , j ]:
266
+ continue
267
+
268
+ assert 0 <= p [val , j ] - p [val , i ] <= z [val , j ] - z [val , i ] + etol
269
+
270
+
273
271
@pytest .mark .parametrize ("dtype" , ["float32" , "float64" ])
274
272
def test_gradient_against_estimate (dtype ):
275
273
"""check sparsemax Rop, against estimated Rop."""
0 commit comments