|
41 | 41 | )
|
42 | 42 |
|
43 | 43 |
|
44 |
| -@test_utils.run_all_in_graph_and_eager_modes |
45 |
| -class ConnectedComponentsTest(tf.test.TestCase): |
46 |
| - def testDisconnected(self): |
47 |
| - arr = tf.cast( |
48 |
| - [ |
49 |
| - [1, 0, 0, 1, 0, 0, 0, 0, 1], |
50 |
| - [0, 1, 0, 0, 0, 1, 0, 1, 0], |
51 |
| - [1, 0, 1, 0, 0, 0, 1, 0, 0], |
52 |
| - [0, 0, 0, 0, 1, 0, 0, 0, 0], |
53 |
| - [0, 0, 1, 0, 0, 0, 0, 0, 0], |
54 |
| - ], |
55 |
| - tf.bool, |
56 |
| - ) |
57 |
| - expected = [ |
58 |
| - [1, 0, 0, 2, 0, 0, 0, 0, 3], |
59 |
| - [0, 4, 0, 0, 0, 5, 0, 6, 0], |
60 |
| - [7, 0, 8, 0, 0, 0, 9, 0, 0], |
61 |
| - [0, 0, 0, 0, 10, 0, 0, 0, 0], |
62 |
| - [0, 0, 11, 0, 0, 0, 0, 0, 0], |
63 |
| - ] |
64 |
| - self.assertAllEqual(self.evaluate(connected_components(arr)), expected) |
| 44 | +@pytest.mark.usefixtures("maybe_run_functions_eagerly") |
| 45 | +def test_disconnected(): |
| 46 | + arr = tf.cast( |
| 47 | + [ |
| 48 | + [1, 0, 0, 1, 0, 0, 0, 0, 1], |
| 49 | + [0, 1, 0, 0, 0, 1, 0, 1, 0], |
| 50 | + [1, 0, 1, 0, 0, 0, 1, 0, 0], |
| 51 | + [0, 0, 0, 0, 1, 0, 0, 0, 0], |
| 52 | + [0, 0, 1, 0, 0, 0, 0, 0, 0], |
| 53 | + ], |
| 54 | + tf.bool, |
| 55 | + ) |
| 56 | + expected = [ |
| 57 | + [1, 0, 0, 2, 0, 0, 0, 0, 3], |
| 58 | + [0, 4, 0, 0, 0, 5, 0, 6, 0], |
| 59 | + [7, 0, 8, 0, 0, 0, 9, 0, 0], |
| 60 | + [0, 0, 0, 0, 10, 0, 0, 0, 0], |
| 61 | + [0, 0, 11, 0, 0, 0, 0, 0, 0], |
| 62 | + ] |
| 63 | + np.testing.assert_equal(connected_components(arr).numpy(), expected) |
65 | 64 |
|
66 |
| - def testSimple(self): |
67 |
| - arr = [[0, 1, 0], [1, 1, 1], [0, 1, 0]] |
68 | 65 |
|
69 |
| - # Single component with id 1. |
70 |
| - self.assertAllEqual( |
71 |
| - self.evaluate(connected_components(tf.cast(arr, tf.bool))), arr |
72 |
| - ) |
| 66 | +@pytest.mark.usefixtures("maybe_run_functions_eagerly") |
| 67 | +def test_simple(): |
| 68 | + arr = [[0, 1, 0], [1, 1, 1], [0, 1, 0]] |
73 | 69 |
|
| 70 | + # Single component with id 1. |
| 71 | + np.testing.assert_equal(connected_components(tf.cast(arr, tf.bool)).numpy(), arr) |
| 72 | + |
| 73 | + |
| 74 | +@test_utils.run_all_in_graph_and_eager_modes |
| 75 | +class ConnectedComponentsTest(tf.test.TestCase): |
74 | 76 | def testSnake(self):
|
75 | 77 |
|
76 | 78 | # Single component with id 1.
|
|
0 commit comments