|
28 | 28 | from tensorflow_addons.utils import test_utils
|
29 | 29 |
|
30 | 30 |
|
31 |
| -@test_utils.run_all_in_graph_and_eager_modes |
32 |
| -class SoftshrinkTest(tf.test.TestCase, parameterized.TestCase): |
33 |
| - def test_invalid(self): |
34 |
| - with self.assertRaisesOpError("lower must be less than or equal to upper."): |
35 |
| - y = _softshrink_custom_op(tf.ones(shape=(1, 2, 3)), lower=2.0, upper=-2.0) |
36 |
| - self.evaluate(y) |
37 |
| - |
38 |
| - @parameterized.named_parameters( |
39 |
| - ("float16", np.float16), ("float32", np.float32), ("float64", np.float64) |
| 31 | +def test_invalid(): |
| 32 | + with pytest.raises( |
| 33 | + tf.errors.OpError, match="lower must be less than or equal to upper." |
| 34 | + ): |
| 35 | + y = _softshrink_custom_op(tf.ones(shape=(1, 2, 3)), lower=2.0, upper=-2.0) |
| 36 | + y.numpy() |
| 37 | + |
| 38 | + |
| 39 | +@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) |
| 40 | +def test_softshrink(dtype): |
| 41 | + x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype) |
| 42 | + expected_result = tf.constant([-1.5, -0.5, 0.0, 0.5, 1.5], dtype=dtype) |
| 43 | + test_utils.assert_allclose_according_to_type(softshrink(x), expected_result) |
| 44 | + |
| 45 | + expected_result = tf.constant([-1.0, 0.0, 0.0, 0.0, 1.0], dtype=dtype) |
| 46 | + test_utils.assert_allclose_according_to_type( |
| 47 | + softshrink(x, lower=-1.0, upper=1.0), expected_result |
40 | 48 | )
|
41 |
| - def test_softshrink(self, dtype): |
42 |
| - x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype) |
43 |
| - expected_result = tf.constant([-1.5, -0.5, 0.0, 0.5, 1.5], dtype=dtype) |
44 |
| - self.assertAllCloseAccordingToType(softshrink(x), expected_result) |
45 | 49 |
|
46 |
| - expected_result = tf.constant([-1.0, 0.0, 0.0, 0.0, 1.0], dtype=dtype) |
47 |
| - self.assertAllCloseAccordingToType( |
48 |
| - softshrink(x, lower=-1.0, upper=1.0), expected_result |
49 |
| - ) |
50 | 50 |
|
| 51 | +@test_utils.run_all_in_graph_and_eager_modes |
| 52 | +class SoftshrinkTest(tf.test.TestCase, parameterized.TestCase): |
51 | 53 | @parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
|
52 | 54 | def test_theoretical_gradients(self, dtype):
|
53 | 55 | # Only test theoretical gradients for float32 and float64
|
|
0 commit comments