Skip to content

Commit b170487

Browse files
Removed run_all_in_graph_and_eager_mode in r_square_test.py (#1356)
1 parent d4c2404 commit b170487

File tree

1 file changed

+97
-95
lines changed

1 file changed

+97
-95
lines changed

tensorflow_addons/metrics/r_square_test.py

+97-95
Original file line numberDiff line numberDiff line change
@@ -22,101 +22,103 @@
2222
from sklearn.metrics import r2_score as sklearn_r2_score
2323
from tensorflow_addons.metrics import RSquare
2424
from tensorflow_addons.metrics.r_square import VALID_MULTIOUTPUT
25-
from tensorflow_addons.utils import test_utils
26-
27-
28-
@test_utils.run_all_in_graph_and_eager_modes
29-
class RSquareTest(tf.test.TestCase):
30-
def test_config(self):
31-
r2_obj = RSquare(name="r_square")
32-
self.assertEqual(r2_obj.name, "r_square")
33-
self.assertEqual(r2_obj.dtype, tf.float32)
34-
# Check save and restore config
35-
r2_obj2 = RSquare.from_config(r2_obj.get_config())
36-
self.assertEqual(r2_obj2.name, "r_square")
37-
self.assertEqual(r2_obj2.dtype, tf.float32)
38-
39-
def initialize_vars(self, y_shape=(), multioutput: str = "uniform_average"):
40-
r2_obj = RSquare(y_shape=y_shape, multioutput=multioutput)
41-
self.evaluate(tf.compat.v1.variables_initializer(r2_obj.variables))
42-
return r2_obj
43-
44-
def update_obj_states(self, obj, actuals, preds, sample_weight=None):
45-
update_op = obj.update_state(actuals, preds, sample_weight=sample_weight)
46-
self.evaluate(update_op)
47-
48-
def check_results(self, obj, value):
49-
self.assertAllClose(value, self.evaluate(obj.result()), atol=1e-5)
50-
51-
def test_r2_perfect_score(self):
52-
actuals = tf.constant([100, 700, 40, 5.7], dtype=tf.float32)
53-
preds = tf.constant([100, 700, 40, 5.7], dtype=tf.float32)
54-
actuals = tf.cast(actuals, dtype=tf.float32)
55-
preds = tf.cast(preds, dtype=tf.float32)
56-
# Initialize
57-
r2_obj = self.initialize_vars()
58-
# Update
59-
self.update_obj_states(r2_obj, actuals, preds)
60-
# Check results
61-
self.check_results(r2_obj, 1.0)
62-
63-
def test_r2_worst_score(self):
64-
actuals = tf.constant([10, 600, 4, 9.77], dtype=tf.float32)
65-
preds = tf.constant([1, 70, 40, 5.7], dtype=tf.float32)
66-
actuals = tf.cast(actuals, dtype=tf.float32)
67-
preds = tf.cast(preds, dtype=tf.float32)
68-
# Initialize
69-
r2_obj = self.initialize_vars()
70-
# Update
71-
self.update_obj_states(r2_obj, actuals, preds)
72-
# Check results
73-
self.check_results(r2_obj, -0.073607)
74-
75-
def test_r2_random_score(self):
76-
actuals = tf.constant([10, 600, 3, 9.77], dtype=tf.float32)
77-
preds = tf.constant([1, 340, 40, 5.7], dtype=tf.float32)
78-
actuals = tf.cast(actuals, dtype=tf.float32)
79-
preds = tf.cast(preds, dtype=tf.float32)
80-
# Initialize
81-
r2_obj = self.initialize_vars()
82-
# Update
83-
self.update_obj_states(r2_obj, actuals, preds)
84-
# Check results
85-
self.check_results(r2_obj, 0.7376327)
86-
87-
def test_r2_sklearn_comparison(self):
88-
"""Test that RSquare behaves similarly to the scikit-learn
89-
implementation of the same metric, given random input.
90-
"""
91-
for multioutput in VALID_MULTIOUTPUT:
92-
for i in range(10):
93-
actuals = np.random.rand(64, 3)
94-
preds = np.random.rand(64, 3)
95-
sample_weight = np.random.rand(64, 1)
96-
tensor_actuals = tf.constant(actuals, dtype=tf.float32)
97-
tensor_preds = tf.constant(preds, dtype=tf.float32)
98-
tensor_sample_weight = tf.constant(sample_weight, dtype=tf.float32)
99-
tensor_actuals = tf.cast(tensor_actuals, dtype=tf.float32)
100-
tensor_preds = tf.cast(tensor_preds, dtype=tf.float32)
101-
tensor_sample_weight = tf.cast(tensor_sample_weight, dtype=tf.float32)
102-
# Initialize
103-
r2_obj = self.initialize_vars(y_shape=(3,), multioutput=multioutput)
104-
# Update
105-
self.update_obj_states(
106-
r2_obj,
107-
tensor_actuals,
108-
tensor_preds,
109-
sample_weight=tensor_sample_weight,
110-
)
111-
# Check results by comparing to results of scikit-learn r2 implementation
112-
sklearn_result = sklearn_r2_score(
113-
actuals, preds, sample_weight=sample_weight, multioutput=multioutput
114-
)
115-
self.check_results(r2_obj, sklearn_result)
116-
117-
def test_unrecognized_multioutput(self):
118-
with pytest.raises(ValueError):
119-
self.initialize_vars(multioutput="meadian")
25+
26+
27+
def test_config():
28+
r2_obj = RSquare(name="r_square")
29+
assert r2_obj.name == "r_square"
30+
assert r2_obj.dtype == tf.float32
31+
# Check save and restore config
32+
r2_obj2 = RSquare.from_config(r2_obj.get_config())
33+
assert r2_obj2.name == "r_square"
34+
assert r2_obj2.dtype == tf.float32
35+
36+
37+
def initialize_vars(y_shape=(), multioutput: str = "uniform_average"):
38+
return RSquare(y_shape=y_shape, multioutput=multioutput)
39+
40+
41+
def update_obj_states(obj, actuals, preds, sample_weight=None):
42+
obj.update_state(actuals, preds, sample_weight=sample_weight)
43+
44+
45+
def check_results(obj, value):
46+
np.testing.assert_allclose(value, obj.result(), atol=1e-5)
47+
48+
49+
def test_r2_perfect_score():
50+
actuals = tf.constant([100, 700, 40, 5.7], dtype=tf.float32)
51+
preds = tf.constant([100, 700, 40, 5.7], dtype=tf.float32)
52+
actuals = tf.cast(actuals, dtype=tf.float32)
53+
preds = tf.cast(preds, dtype=tf.float32)
54+
# Initialize
55+
r2_obj = initialize_vars()
56+
# Update
57+
update_obj_states(r2_obj, actuals, preds)
58+
# Check results
59+
check_results(r2_obj, 1.0)
60+
61+
62+
def test_r2_worst_score():
63+
actuals = tf.constant([10, 600, 4, 9.77], dtype=tf.float32)
64+
preds = tf.constant([1, 70, 40, 5.7], dtype=tf.float32)
65+
actuals = tf.cast(actuals, dtype=tf.float32)
66+
preds = tf.cast(preds, dtype=tf.float32)
67+
# Initialize
68+
r2_obj = initialize_vars()
69+
# Update
70+
update_obj_states(r2_obj, actuals, preds)
71+
# Check results
72+
check_results(r2_obj, -0.073607)
73+
74+
75+
def test_r2_random_score():
76+
actuals = tf.constant([10, 600, 3, 9.77], dtype=tf.float32)
77+
preds = tf.constant([1, 340, 40, 5.7], dtype=tf.float32)
78+
actuals = tf.cast(actuals, dtype=tf.float32)
79+
preds = tf.cast(preds, dtype=tf.float32)
80+
# Initialize
81+
r2_obj = initialize_vars()
82+
# Update
83+
update_obj_states(r2_obj, actuals, preds)
84+
# Check results
85+
check_results(r2_obj, 0.7376327)
86+
87+
88+
def test_r2_sklearn_comparison():
89+
"""Test that RSquare behaves similarly to the scikit-learn
90+
implementation of the same metric, given random input.
91+
"""
92+
for multioutput in VALID_MULTIOUTPUT:
93+
for i in range(10):
94+
actuals = np.random.rand(64, 3)
95+
preds = np.random.rand(64, 3)
96+
sample_weight = np.random.rand(64, 1)
97+
tensor_actuals = tf.constant(actuals, dtype=tf.float32)
98+
tensor_preds = tf.constant(preds, dtype=tf.float32)
99+
tensor_sample_weight = tf.constant(sample_weight, dtype=tf.float32)
100+
tensor_actuals = tf.cast(tensor_actuals, dtype=tf.float32)
101+
tensor_preds = tf.cast(tensor_preds, dtype=tf.float32)
102+
tensor_sample_weight = tf.cast(tensor_sample_weight, dtype=tf.float32)
103+
# Initialize
104+
r2_obj = initialize_vars(y_shape=(3,), multioutput=multioutput)
105+
# Update
106+
update_obj_states(
107+
r2_obj,
108+
tensor_actuals,
109+
tensor_preds,
110+
sample_weight=tensor_sample_weight,
111+
)
112+
# Check results by comparing to results of scikit-learn r2 implementation
113+
sklearn_result = sklearn_r2_score(
114+
actuals, preds, sample_weight=sample_weight, multioutput=multioutput
115+
)
116+
check_results(r2_obj, sklearn_result)
117+
118+
119+
def test_unrecognized_multioutput():
120+
with pytest.raises(ValueError):
121+
initialize_vars(multioutput="meadian")
120122

121123

122124
if __name__ == "__main__":

0 commit comments

Comments
 (0)