From 942c41115ee47ec5d9b99bc6e9345361441386eb Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Wed, 25 Mar 2020 10:41:28 +0000 Subject: [PATCH] Moved test_sequence_loss outside run_all_in_graph_and_eager_mode. --- tensorflow_addons/seq2seq/loss_test.py | 84 +++++++++++--------------- 1 file changed, 34 insertions(+), 50 deletions(-) diff --git a/tensorflow_addons/seq2seq/loss_test.py b/tensorflow_addons/seq2seq/loss_test.py index e2ef558e9e..b422bf02a5 100644 --- a/tensorflow_addons/seq2seq/loss_test.py +++ b/tensorflow_addons/seq2seq/loss_test.py @@ -102,6 +102,40 @@ def test_sequence_loss(average_across_timesteps, average_across_batch, zero_weig np.testing.assert_allclose(computed, expected, rtol=1e-6, atol=1e-6) +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +@pytest.mark.parametrize("average_across_timesteps", [True, False]) +@pytest.mark.parametrize("average_across_batch", [True, False]) +def test_sequence_loss_class(average_across_timesteps, average_across_batch): + + ( + batch_size, + sequence_length, + _, + logits, + targets, + weights, + expected_loss, + ) = get_test_data() + seq_loss = loss.SequenceLoss( + average_across_timesteps=average_across_timesteps, + average_across_batch=average_across_batch, + sum_over_timesteps=False, + sum_over_batch=False, + ) + average_loss_per_example = seq_loss(targets, logits, weights) + res = average_loss_per_example.numpy() + if average_across_timesteps and average_across_batch: + expected = expected_loss + elif not average_across_timesteps and average_across_batch: + expected = np.full(sequence_length, expected_loss) + elif average_across_timesteps and not average_across_batch: + expected = np.full(batch_size, expected_loss) + elif not average_across_timesteps and not average_across_batch: + expected = np.full((batch_size, sequence_length), expected_loss) + + np.testing.assert_allclose(res, expected, atol=1e-6, rtol=1e-6) + + @test_utils.run_all_in_graph_and_eager_modes class LossTest(tf.test.TestCase): def setup(self): @@ -128,56 +162,6 @@ def setup(self): # and logits = [[0.5] * 5, [1.5] * 5, [2.5] * 5] self.expected_loss = 1.60944 - def testSequenceLossClass(self): - with self.cached_session(use_gpu=True): - self.setup() - seq_loss = loss.SequenceLoss( - average_across_timesteps=True, - average_across_batch=True, - sum_over_timesteps=False, - sum_over_batch=False, - ) - average_loss_per_example = seq_loss(self.targets, self.logits, self.weights) - res = self.evaluate(average_loss_per_example) - self.assertAllClose(self.expected_loss, res) - - seq_loss = loss.SequenceLoss( - average_across_timesteps=False, - average_across_batch=True, - sum_over_timesteps=False, - sum_over_batch=False, - ) - average_loss_per_sequence = seq_loss( - self.targets, self.logits, self.weights - ) - res = self.evaluate(average_loss_per_sequence) - compare_per_sequence = np.full((self.sequence_length), self.expected_loss) - self.assertAllClose(compare_per_sequence, res) - - seq_loss = loss.SequenceLoss( - average_across_timesteps=True, - average_across_batch=False, - sum_over_timesteps=False, - sum_over_batch=False, - ) - average_loss_per_batch = seq_loss(self.targets, self.logits, self.weights) - res = self.evaluate(average_loss_per_batch) - compare_per_batch = np.full((self.batch_size), self.expected_loss) - self.assertAllClose(compare_per_batch, res) - - seq_loss = loss.SequenceLoss( - average_across_timesteps=False, - average_across_batch=False, - sum_over_timesteps=False, - sum_over_batch=False, - ) - total_loss = seq_loss(self.targets, self.logits, self.weights) - res = self.evaluate(total_loss) - compare_total = np.full( - (self.batch_size, self.sequence_length), self.expected_loss - ) - self.assertAllClose(compare_total, res) - def testSumReduction(self): with self.cached_session(use_gpu=True): self.setup()