Skip to content

Commit 369b8f2

Browse files
Moved tests out of run_in_graph_and_eager_mode in connected components. (#1412)
See #1328
1 parent 2cd85ae commit 369b8f2

File tree

1 file changed

+29
-27
lines changed

1 file changed

+29
-27
lines changed

tensorflow_addons/image/connected_components_test.py

+29-27
Original file line numberDiff line numberDiff line change
@@ -41,36 +41,38 @@
4141
)
4242

4343

44-
@test_utils.run_all_in_graph_and_eager_modes
45-
class ConnectedComponentsTest(tf.test.TestCase):
46-
def testDisconnected(self):
47-
arr = tf.cast(
48-
[
49-
[1, 0, 0, 1, 0, 0, 0, 0, 1],
50-
[0, 1, 0, 0, 0, 1, 0, 1, 0],
51-
[1, 0, 1, 0, 0, 0, 1, 0, 0],
52-
[0, 0, 0, 0, 1, 0, 0, 0, 0],
53-
[0, 0, 1, 0, 0, 0, 0, 0, 0],
54-
],
55-
tf.bool,
56-
)
57-
expected = [
58-
[1, 0, 0, 2, 0, 0, 0, 0, 3],
59-
[0, 4, 0, 0, 0, 5, 0, 6, 0],
60-
[7, 0, 8, 0, 0, 0, 9, 0, 0],
61-
[0, 0, 0, 0, 10, 0, 0, 0, 0],
62-
[0, 0, 11, 0, 0, 0, 0, 0, 0],
63-
]
64-
self.assertAllEqual(self.evaluate(connected_components(arr)), expected)
44+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
45+
def test_disconnected():
46+
arr = tf.cast(
47+
[
48+
[1, 0, 0, 1, 0, 0, 0, 0, 1],
49+
[0, 1, 0, 0, 0, 1, 0, 1, 0],
50+
[1, 0, 1, 0, 0, 0, 1, 0, 0],
51+
[0, 0, 0, 0, 1, 0, 0, 0, 0],
52+
[0, 0, 1, 0, 0, 0, 0, 0, 0],
53+
],
54+
tf.bool,
55+
)
56+
expected = [
57+
[1, 0, 0, 2, 0, 0, 0, 0, 3],
58+
[0, 4, 0, 0, 0, 5, 0, 6, 0],
59+
[7, 0, 8, 0, 0, 0, 9, 0, 0],
60+
[0, 0, 0, 0, 10, 0, 0, 0, 0],
61+
[0, 0, 11, 0, 0, 0, 0, 0, 0],
62+
]
63+
np.testing.assert_equal(connected_components(arr).numpy(), expected)
6564

66-
def testSimple(self):
67-
arr = [[0, 1, 0], [1, 1, 1], [0, 1, 0]]
6865

69-
# Single component with id 1.
70-
self.assertAllEqual(
71-
self.evaluate(connected_components(tf.cast(arr, tf.bool))), arr
72-
)
66+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
67+
def test_simple():
68+
arr = [[0, 1, 0], [1, 1, 1], [0, 1, 0]]
7369

70+
# Single component with id 1.
71+
np.testing.assert_equal(connected_components(tf.cast(arr, tf.bool)).numpy(), arr)
72+
73+
74+
@test_utils.run_all_in_graph_and_eager_modes
75+
class ConnectedComponentsTest(tf.test.TestCase):
7476
def testSnake(self):
7577

7678
# Single component with id 1.

0 commit comments

Comments
 (0)