Skip to content

Commit 591a597

Browse files
Moved test_sequence_loss outside run_all_in_graph_and_eager_mode. (#1384)
1 parent f0cef01 commit 591a597

File tree

1 file changed

+34
-50
lines changed

1 file changed

+34
-50
lines changed

tensorflow_addons/seq2seq/loss_test.py

+34-50
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,40 @@ def test_sequence_loss(average_across_timesteps, average_across_batch, zero_weig
102102
np.testing.assert_allclose(computed, expected, rtol=1e-6, atol=1e-6)
103103

104104

105+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
106+
@pytest.mark.parametrize("average_across_timesteps", [True, False])
107+
@pytest.mark.parametrize("average_across_batch", [True, False])
108+
def test_sequence_loss_class(average_across_timesteps, average_across_batch):
109+
110+
(
111+
batch_size,
112+
sequence_length,
113+
_,
114+
logits,
115+
targets,
116+
weights,
117+
expected_loss,
118+
) = get_test_data()
119+
seq_loss = loss.SequenceLoss(
120+
average_across_timesteps=average_across_timesteps,
121+
average_across_batch=average_across_batch,
122+
sum_over_timesteps=False,
123+
sum_over_batch=False,
124+
)
125+
average_loss_per_example = seq_loss(targets, logits, weights)
126+
res = average_loss_per_example.numpy()
127+
if average_across_timesteps and average_across_batch:
128+
expected = expected_loss
129+
elif not average_across_timesteps and average_across_batch:
130+
expected = np.full(sequence_length, expected_loss)
131+
elif average_across_timesteps and not average_across_batch:
132+
expected = np.full(batch_size, expected_loss)
133+
elif not average_across_timesteps and not average_across_batch:
134+
expected = np.full((batch_size, sequence_length), expected_loss)
135+
136+
np.testing.assert_allclose(res, expected, atol=1e-6, rtol=1e-6)
137+
138+
105139
@test_utils.run_all_in_graph_and_eager_modes
106140
class LossTest(tf.test.TestCase):
107141
def setup(self):
@@ -128,56 +162,6 @@ def setup(self):
128162
# and logits = [[0.5] * 5, [1.5] * 5, [2.5] * 5]
129163
self.expected_loss = 1.60944
130164

131-
def testSequenceLossClass(self):
132-
with self.cached_session(use_gpu=True):
133-
self.setup()
134-
seq_loss = loss.SequenceLoss(
135-
average_across_timesteps=True,
136-
average_across_batch=True,
137-
sum_over_timesteps=False,
138-
sum_over_batch=False,
139-
)
140-
average_loss_per_example = seq_loss(self.targets, self.logits, self.weights)
141-
res = self.evaluate(average_loss_per_example)
142-
self.assertAllClose(self.expected_loss, res)
143-
144-
seq_loss = loss.SequenceLoss(
145-
average_across_timesteps=False,
146-
average_across_batch=True,
147-
sum_over_timesteps=False,
148-
sum_over_batch=False,
149-
)
150-
average_loss_per_sequence = seq_loss(
151-
self.targets, self.logits, self.weights
152-
)
153-
res = self.evaluate(average_loss_per_sequence)
154-
compare_per_sequence = np.full((self.sequence_length), self.expected_loss)
155-
self.assertAllClose(compare_per_sequence, res)
156-
157-
seq_loss = loss.SequenceLoss(
158-
average_across_timesteps=True,
159-
average_across_batch=False,
160-
sum_over_timesteps=False,
161-
sum_over_batch=False,
162-
)
163-
average_loss_per_batch = seq_loss(self.targets, self.logits, self.weights)
164-
res = self.evaluate(average_loss_per_batch)
165-
compare_per_batch = np.full((self.batch_size), self.expected_loss)
166-
self.assertAllClose(compare_per_batch, res)
167-
168-
seq_loss = loss.SequenceLoss(
169-
average_across_timesteps=False,
170-
average_across_batch=False,
171-
sum_over_timesteps=False,
172-
sum_over_batch=False,
173-
)
174-
total_loss = seq_loss(self.targets, self.logits, self.weights)
175-
res = self.evaluate(total_loss)
176-
compare_total = np.full(
177-
(self.batch_size, self.sequence_length), self.expected_loss
178-
)
179-
self.assertAllClose(compare_total, res)
180-
181165
def testSumReduction(self):
182166
with self.cached_session(use_gpu=True):
183167
self.setup()

0 commit comments

Comments
 (0)