|
17 | 17 | import sys
|
18 | 18 |
|
19 | 19 | import pytest
|
20 |
| -from absl.testing import parameterized |
21 | 20 | import numpy as np
|
22 |
| -import tensorflow as tf |
23 | 21 | from tensorflow_addons.layers.netvlad import NetVLAD
|
24 | 22 | from tensorflow_addons.utils import test_utils
|
25 | 23 |
|
26 | 24 |
|
27 |
| -@test_utils.run_all_in_graph_and_eager_modes |
28 |
| -class NetVLADTest(tf.test.TestCase, parameterized.TestCase): |
29 |
| - """Tests for NetVLAD.""" |
| 25 | +pytestmark = pytest.mark.usefixtures("maybe_run_functions_eagerly") |
30 | 26 |
|
31 |
| - @parameterized.parameters( |
32 |
| - {"num_clusters": 1}, {"num_clusters": 4}, |
| 27 | + |
| 28 | +@pytest.mark.parametrize("num_clusters", [1, 4]) |
| 29 | +def test_simple(num_clusters): |
| 30 | + test_utils.layer_test( |
| 31 | + NetVLAD, |
| 32 | + kwargs={"num_clusters": num_clusters}, |
| 33 | + input_shape=(5, 4, 100), |
| 34 | + expected_output_shape=(None, num_clusters * 100), |
| 35 | + ) |
| 36 | + |
| 37 | + |
| 38 | +def test_unknown(): |
| 39 | + inputs = np.random.random((5, 4, 100)).astype("float32") |
| 40 | + test_utils.layer_test( |
| 41 | + NetVLAD, |
| 42 | + kwargs={"num_clusters": 3}, |
| 43 | + input_shape=(None, None, 100), |
| 44 | + input_data=inputs, |
| 45 | + expected_output_shape=(None, 3 * 100), |
33 | 46 | )
|
34 |
| - def test_simple(self, num_clusters): |
| 47 | + |
| 48 | + |
| 49 | +def test_invalid_shape(): |
| 50 | + with pytest.raises(ValueError) as exception_info: |
35 | 51 | test_utils.layer_test(
|
36 |
| - NetVLAD, |
37 |
| - kwargs={"num_clusters": num_clusters}, |
38 |
| - input_shape=(5, 4, 100), |
39 |
| - expected_output_shape=(None, num_clusters * 100), |
| 52 | + NetVLAD, kwargs={"num_clusters": 0}, input_shape=(5, 4, 20) |
40 | 53 | )
|
| 54 | + assert "`num_clusters` must be greater than 1" in str(exception_info.value) |
41 | 55 |
|
42 |
| - def test_unknown(self): |
43 |
| - inputs = np.random.random((5, 4, 100)).astype("float32") |
| 56 | + with pytest.raises(ValueError) as exception_info: |
44 | 57 | test_utils.layer_test(
|
45 |
| - NetVLAD, |
46 |
| - kwargs={"num_clusters": 3}, |
47 |
| - input_shape=(None, None, 100), |
48 |
| - input_data=inputs, |
49 |
| - expected_output_shape=(None, 3 * 100), |
| 58 | + NetVLAD, kwargs={"num_clusters": 2}, input_shape=(5, 4, 4, 20) |
50 | 59 | )
|
51 |
| - |
52 |
| - def test_invalid_shape(self): |
53 |
| - with self.assertRaisesRegexp( |
54 |
| - ValueError, r"`num_clusters` must be greater than 1" |
55 |
| - ): |
56 |
| - test_utils.layer_test( |
57 |
| - NetVLAD, kwargs={"num_clusters": 0}, input_shape=(5, 4, 20) |
58 |
| - ) |
59 |
| - |
60 |
| - with self.assertRaisesRegexp(ValueError, r"must have rank 3"): |
61 |
| - test_utils.layer_test( |
62 |
| - NetVLAD, kwargs={"num_clusters": 2}, input_shape=(5, 4, 4, 20) |
63 |
| - ) |
| 60 | + assert "must have rank 3" in str(exception_info.value) |
64 | 61 |
|
65 | 62 |
|
66 | 63 | if __name__ == "__main__":
|
|
0 commit comments