Skip to content

Moved test_sequence_loss outside run_all_in_graph_and_eager_mode. #1384

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 25, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 34 additions & 50 deletions tensorflow_addons/seq2seq/loss_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,40 @@ def test_sequence_loss(average_across_timesteps, average_across_batch, zero_weig
np.testing.assert_allclose(computed, expected, rtol=1e-6, atol=1e-6)


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("average_across_timesteps", [True, False])
@pytest.mark.parametrize("average_across_batch", [True, False])
def test_sequence_loss_class(average_across_timesteps, average_across_batch):

(
batch_size,
sequence_length,
_,
logits,
targets,
weights,
expected_loss,
) = get_test_data()
seq_loss = loss.SequenceLoss(
average_across_timesteps=average_across_timesteps,
average_across_batch=average_across_batch,
sum_over_timesteps=False,
sum_over_batch=False,
)
average_loss_per_example = seq_loss(targets, logits, weights)
res = average_loss_per_example.numpy()
if average_across_timesteps and average_across_batch:
expected = expected_loss
elif not average_across_timesteps and average_across_batch:
expected = np.full(sequence_length, expected_loss)
elif average_across_timesteps and not average_across_batch:
expected = np.full(batch_size, expected_loss)
elif not average_across_timesteps and not average_across_batch:
expected = np.full((batch_size, sequence_length), expected_loss)

np.testing.assert_allclose(res, expected, atol=1e-6, rtol=1e-6)


@test_utils.run_all_in_graph_and_eager_modes
class LossTest(tf.test.TestCase):
def setup(self):
Expand All @@ -128,56 +162,6 @@ def setup(self):
# and logits = [[0.5] * 5, [1.5] * 5, [2.5] * 5]
self.expected_loss = 1.60944

def testSequenceLossClass(self):
with self.cached_session(use_gpu=True):
self.setup()
seq_loss = loss.SequenceLoss(
average_across_timesteps=True,
average_across_batch=True,
sum_over_timesteps=False,
sum_over_batch=False,
)
average_loss_per_example = seq_loss(self.targets, self.logits, self.weights)
res = self.evaluate(average_loss_per_example)
self.assertAllClose(self.expected_loss, res)

seq_loss = loss.SequenceLoss(
average_across_timesteps=False,
average_across_batch=True,
sum_over_timesteps=False,
sum_over_batch=False,
)
average_loss_per_sequence = seq_loss(
self.targets, self.logits, self.weights
)
res = self.evaluate(average_loss_per_sequence)
compare_per_sequence = np.full((self.sequence_length), self.expected_loss)
self.assertAllClose(compare_per_sequence, res)

seq_loss = loss.SequenceLoss(
average_across_timesteps=True,
average_across_batch=False,
sum_over_timesteps=False,
sum_over_batch=False,
)
average_loss_per_batch = seq_loss(self.targets, self.logits, self.weights)
res = self.evaluate(average_loss_per_batch)
compare_per_batch = np.full((self.batch_size), self.expected_loss)
self.assertAllClose(compare_per_batch, res)

seq_loss = loss.SequenceLoss(
average_across_timesteps=False,
average_across_batch=False,
sum_over_timesteps=False,
sum_over_batch=False,
)
total_loss = seq_loss(self.targets, self.logits, self.weights)
res = self.evaluate(total_loss)
compare_total = np.full(
(self.batch_size, self.sequence_length), self.expected_loss
)
self.assertAllClose(compare_total, res)

def testSumReduction(self):
with self.cached_session(use_gpu=True):
self.setup()
Expand Down