Skip to content

Commit d4c2404

Browse files
Removed test_utils.run_all_in_graph_and_eager_modes in netvlad_test.py (#1350)
* Use pytest only for netvlad. * Used pytestmark.
1 parent 125d97d commit d4c2404

File tree

1 file changed

+29
-32
lines changed

1 file changed

+29
-32
lines changed

tensorflow_addons/layers/netvlad_test.py

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,50 +17,47 @@
1717
import sys
1818

1919
import pytest
20-
from absl.testing import parameterized
2120
import numpy as np
22-
import tensorflow as tf
2321
from tensorflow_addons.layers.netvlad import NetVLAD
2422
from tensorflow_addons.utils import test_utils
2523

2624

27-
@test_utils.run_all_in_graph_and_eager_modes
28-
class NetVLADTest(tf.test.TestCase, parameterized.TestCase):
29-
"""Tests for NetVLAD."""
25+
pytestmark = pytest.mark.usefixtures("maybe_run_functions_eagerly")
3026

31-
@parameterized.parameters(
32-
{"num_clusters": 1}, {"num_clusters": 4},
27+
28+
@pytest.mark.parametrize("num_clusters", [1, 4])
29+
def test_simple(num_clusters):
30+
test_utils.layer_test(
31+
NetVLAD,
32+
kwargs={"num_clusters": num_clusters},
33+
input_shape=(5, 4, 100),
34+
expected_output_shape=(None, num_clusters * 100),
35+
)
36+
37+
38+
def test_unknown():
39+
inputs = np.random.random((5, 4, 100)).astype("float32")
40+
test_utils.layer_test(
41+
NetVLAD,
42+
kwargs={"num_clusters": 3},
43+
input_shape=(None, None, 100),
44+
input_data=inputs,
45+
expected_output_shape=(None, 3 * 100),
3346
)
34-
def test_simple(self, num_clusters):
47+
48+
49+
def test_invalid_shape():
50+
with pytest.raises(ValueError) as exception_info:
3551
test_utils.layer_test(
36-
NetVLAD,
37-
kwargs={"num_clusters": num_clusters},
38-
input_shape=(5, 4, 100),
39-
expected_output_shape=(None, num_clusters * 100),
52+
NetVLAD, kwargs={"num_clusters": 0}, input_shape=(5, 4, 20)
4053
)
54+
assert "`num_clusters` must be greater than 1" in str(exception_info.value)
4155

42-
def test_unknown(self):
43-
inputs = np.random.random((5, 4, 100)).astype("float32")
56+
with pytest.raises(ValueError) as exception_info:
4457
test_utils.layer_test(
45-
NetVLAD,
46-
kwargs={"num_clusters": 3},
47-
input_shape=(None, None, 100),
48-
input_data=inputs,
49-
expected_output_shape=(None, 3 * 100),
58+
NetVLAD, kwargs={"num_clusters": 2}, input_shape=(5, 4, 4, 20)
5059
)
51-
52-
def test_invalid_shape(self):
53-
with self.assertRaisesRegexp(
54-
ValueError, r"`num_clusters` must be greater than 1"
55-
):
56-
test_utils.layer_test(
57-
NetVLAD, kwargs={"num_clusters": 0}, input_shape=(5, 4, 20)
58-
)
59-
60-
with self.assertRaisesRegexp(ValueError, r"must have rank 3"):
61-
test_utils.layer_test(
62-
NetVLAD, kwargs={"num_clusters": 2}, input_shape=(5, 4, 4, 20)
63-
)
60+
assert "must have rank 3" in str(exception_info.value)
6461

6562

6663
if __name__ == "__main__":

0 commit comments

Comments
 (0)