Skip to content

Commit 04e1edf

Browse files
Removed run_all_in_graph_and_eager_mode in polynomial_test.py (#1337)
* Removed run_all_in_eager_and_graph_mode. * Add the decorator to run in eager mode too.
1 parent 91f67b8 commit 04e1edf

File tree

1 file changed

+37
-35
lines changed

1 file changed

+37
-35
lines changed

tensorflow_addons/layers/polynomial_test.py

+37-35
Original file line numberDiff line numberDiff line change
@@ -21,41 +21,43 @@
2121
import tensorflow as tf
2222

2323
from tensorflow_addons.layers.polynomial import PolynomialCrossing
24-
from tensorflow_addons.utils import test_utils
25-
26-
27-
@test_utils.run_all_in_graph_and_eager_modes
28-
class PolynomialCrossingTest(tf.test.TestCase):
29-
# Do not use layer_test due to multiple inputs.
30-
31-
def test_full_matrix(self):
32-
x0 = np.asarray([[0.1, 0.2, 0.3]]).astype(np.float32)
33-
x = np.asarray([[0.4, 0.5, 0.6]]).astype(np.float32)
34-
layer = PolynomialCrossing(projection_dim=None, kernel_initializer="ones")
35-
output = layer([x0, x])
36-
self.evaluate(tf.compat.v1.global_variables_initializer())
37-
self.assertAllClose(np.asarray([[0.55, 0.8, 1.05]]), output)
38-
39-
def test_invalid_proj_dim(self):
40-
with self.assertRaisesRegexp(ValueError, r"is not supported yet"):
41-
x0 = np.random.random((12, 5))
42-
x = np.random.random((12, 5))
43-
layer = PolynomialCrossing(projection_dim=6)
44-
layer([x0, x])
45-
46-
def test_invalid_inputs(self):
47-
with self.assertRaisesRegexp(ValueError, r"must be a tuple or list of size 2"):
48-
x0 = np.random.random((12, 5))
49-
x = np.random.random((12, 5))
50-
x1 = np.random.random((12, 5))
51-
layer = PolynomialCrossing(projection_dim=6)
52-
layer([x0, x, x1])
53-
54-
def test_serialization(self):
55-
layer = PolynomialCrossing(projection_dim=None)
56-
serialized_layer = tf.keras.layers.serialize(layer)
57-
new_layer = tf.keras.layers.deserialize(serialized_layer)
58-
self.assertEqual(layer.get_config(), new_layer.get_config())
24+
25+
26+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
27+
def test_full_matrix():
28+
x0 = np.asarray([[0.1, 0.2, 0.3]]).astype(np.float32)
29+
x = np.asarray([[0.4, 0.5, 0.6]]).astype(np.float32)
30+
layer = PolynomialCrossing(projection_dim=None, kernel_initializer="ones")
31+
output = layer([x0, x])
32+
np.testing.assert_allclose([[0.55, 0.8, 1.05]], output)
33+
34+
35+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
36+
def test_invalid_proj_dim():
37+
with pytest.raises(ValueError) as exception_info:
38+
x0 = np.random.random((12, 5))
39+
x = np.random.random((12, 5))
40+
layer = PolynomialCrossing(projection_dim=6)
41+
layer([x0, x])
42+
assert "is not supported yet" in str(exception_info.value)
43+
44+
45+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
46+
def test_invalid_inputs():
47+
with pytest.raises(ValueError) as exception_info:
48+
x0 = np.random.random((12, 5))
49+
x = np.random.random((12, 5))
50+
x1 = np.random.random((12, 5))
51+
layer = PolynomialCrossing(projection_dim=6)
52+
layer([x0, x, x1])
53+
assert "must be a tuple or list of size 2" in str(exception_info.value)
54+
55+
56+
def test_serialization():
57+
layer = PolynomialCrossing(projection_dim=None)
58+
serialized_layer = tf.keras.layers.serialize(layer)
59+
new_layer = tf.keras.layers.deserialize(serialized_layer)
60+
assert layer.get_config() == new_layer.get_config()
5961

6062

6163
if __name__ == "__main__":

0 commit comments

Comments
 (0)