Skip to content

Commit 83c1ba2

Browse files
Moved test out of run_all_in_graph_and_eager_mode in softshrink. (#1404)
See #1328
1 parent 92df7c9 commit 83c1ba2

File tree

1 file changed

+19
-17
lines changed

1 file changed

+19
-17
lines changed

tensorflow_addons/activations/softshrink_test.py

+19-17
Original file line numberDiff line numberDiff line change
@@ -28,26 +28,28 @@
2828
from tensorflow_addons.utils import test_utils
2929

3030

31-
@test_utils.run_all_in_graph_and_eager_modes
32-
class SoftshrinkTest(tf.test.TestCase, parameterized.TestCase):
33-
def test_invalid(self):
34-
with self.assertRaisesOpError("lower must be less than or equal to upper."):
35-
y = _softshrink_custom_op(tf.ones(shape=(1, 2, 3)), lower=2.0, upper=-2.0)
36-
self.evaluate(y)
37-
38-
@parameterized.named_parameters(
39-
("float16", np.float16), ("float32", np.float32), ("float64", np.float64)
31+
def test_invalid():
32+
with pytest.raises(
33+
tf.errors.OpError, match="lower must be less than or equal to upper."
34+
):
35+
y = _softshrink_custom_op(tf.ones(shape=(1, 2, 3)), lower=2.0, upper=-2.0)
36+
y.numpy()
37+
38+
39+
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
40+
def test_softshrink(dtype):
41+
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
42+
expected_result = tf.constant([-1.5, -0.5, 0.0, 0.5, 1.5], dtype=dtype)
43+
test_utils.assert_allclose_according_to_type(softshrink(x), expected_result)
44+
45+
expected_result = tf.constant([-1.0, 0.0, 0.0, 0.0, 1.0], dtype=dtype)
46+
test_utils.assert_allclose_according_to_type(
47+
softshrink(x, lower=-1.0, upper=1.0), expected_result
4048
)
41-
def test_softshrink(self, dtype):
42-
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
43-
expected_result = tf.constant([-1.5, -0.5, 0.0, 0.5, 1.5], dtype=dtype)
44-
self.assertAllCloseAccordingToType(softshrink(x), expected_result)
4549

46-
expected_result = tf.constant([-1.0, 0.0, 0.0, 0.0, 1.0], dtype=dtype)
47-
self.assertAllCloseAccordingToType(
48-
softshrink(x, lower=-1.0, upper=1.0), expected_result
49-
)
5050

51+
@test_utils.run_all_in_graph_and_eager_modes
52+
class SoftshrinkTest(tf.test.TestCase, parameterized.TestCase):
5153
@parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
5254
def test_theoretical_gradients(self, dtype):
5355
# Only test theoretical gradients for float32 and float64

0 commit comments

Comments
 (0)