diff --git a/tensorflow_addons/layers/BUILD b/tensorflow_addons/layers/BUILD
index bc358b4035..b4b90fbfc5 100644
--- a/tensorflow_addons/layers/BUILD
+++ b/tensorflow_addons/layers/BUILD
@@ -11,6 +11,7 @@ py_library(
     deps = [
         "//tensorflow_addons/activations",
         "//tensorflow_addons/testing",
+        "//tensorflow_addons/text",
         "//tensorflow_addons/utils",
     ],
 )
diff --git a/tensorflow_addons/layers/crf.py b/tensorflow_addons/layers/crf.py
new file mode 100644
index 0000000000..a71fc80f9e
--- /dev/null
+++ b/tensorflow_addons/layers/crf.py
@@ -0,0 +1,244 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Original implementation from keras_contrib/layers/crf
+# ==============================================================================
+"""Implementing Conditional Random Field layer."""
+
+import tensorflow as tf
+from typeguard import typechecked
+
+from tensorflow_addons.text.crf import crf_decode
+from tensorflow_addons.utils import types
+
+
+@tf.keras.utils.register_keras_serializable(package="Addons")
+class CRF(tf.keras.layers.Layer):
+    """Linear chain conditional random field (CRF).
+
+    References:
+        - [Conditional Random Field](https://en.wikipedia.org/wiki/Conditional_random_field)
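+
+    Example (an illustrative sketch; the tag count and input shapes below
+    are arbitrary):
+
+        layer = CRF(units=5)
+        inputs = tf.random.uniform(shape=(2, 3, 5))  # (batch, time, features)
+        decoded, potentials, seq_lens, kernel = layer(inputs)
+        # decoded:    (2, 3)    int32 tag ids from Viterbi decoding
+        # potentials: (2, 3, 5) unary potentials fed to the decoder
+        # seq_lens:   (2,)      int64 per-example sequence lengths
+        # kernel:     (5, 5)    transition weights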
+    """
+
+    @typechecked
+    def __init__(
+        self,
+        units: int,
+        chain_initializer: types.Initializer = "orthogonal",
+        use_boundary: bool = True,
+        boundary_initializer: types.Initializer = "zeros",
+        use_kernel: bool = True,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+
+        # Enable mask support. The base class (Layer) sets
+        # supports_masking to False unconditionally in its __init__,
+        # so this assignment must happen after the base __init__ call.
+        self.supports_masking = True
+
+        self.units = units  # number of tags
+
+        self.use_boundary = use_boundary
+        self.use_kernel = use_kernel
+        self.chain_initializer = tf.keras.initializers.get(chain_initializer)
+        self.boundary_initializer = tf.keras.initializers.get(boundary_initializer)
+
+        # weights that model the transition energies between tags
+        self.chain_kernel = self.add_weight(
+            shape=(self.units, self.units),
+            name="chain_kernel",
+            initializer=self.chain_initializer,
+        )
+
+        # weights for the start-to-tag and tag-to-end boundary energies
+        if self.use_boundary:
+            self.left_boundary = self.add_weight(
+                shape=(self.units,),
+                name="left_boundary",
+                initializer=self.boundary_initializer,
+            )
+            self.right_boundary = self.add_weight(
+                shape=(self.units,),
+                name="right_boundary",
+                initializer=self.boundary_initializer,
+            )
+
+        if self.use_kernel:
+            self._dense_layer = tf.keras.layers.Dense(
+                units=self.units, dtype=self.dtype,
+            )
+        else:
+            self._dense_layer = lambda x: tf.cast(x, dtype=self.dtype)
+
+    def call(self, inputs, mask=None):
+        # mask: Tensor(shape=(batch_size, sequence_length), dtype=bool) or None
+
+        if mask is not None:
+            if tf.keras.backend.ndim(mask) != 2:
+                raise ValueError("Input mask to CRF must have dim 2 if not None")
+
+        if mask is not None:
+            # Left padding of the mask is not supported, due to the
+            # underlying CRF function; detect it and report it to the user.
+            left_boundary_mask = self._compute_mask_left_boundary(mask)
+            first_mask = left_boundary_mask[:, 0]
+            if first_mask is not None and tf.executing_eagerly():
+                no_left_padding = tf.math.reduce_all(first_mask)
+                left_padding = not no_left_padding
+                if left_padding:
+                    raise NotImplementedError(
+                        "Currently, the CRF layer does not support left padding"
+                    )
+
+        potentials = self._dense_layer(inputs)
+
+        # append the boundary energies
+        if self.use_boundary:
+            potentials = self.add_boundary_energy(
+                potentials, mask, self.left_boundary, self.right_boundary
+            )
+
+        sequence_length = self._get_sequence_length(inputs, mask)
+
+        decoded_sequence, _ = self.get_viterbi_decoding(potentials, sequence_length)
+
+        return [decoded_sequence, potentials, sequence_length, self.chain_kernel]
+
+    def _get_sequence_length(self, input_, mask):
+        """Compute the sequence length from the input and the mask.
+
+        The underlying CRF function (provided by
+        tensorflow_addons.text.crf) does not support bi-directional
+        masking (left padding / right padding); it supports right
+        padding only, by being told the sequence length.
+        """
+        if mask is not None:
+            sequence_length = self.mask_to_sequence_length(mask)
+        else:
+            # make a mask tensor from the input, then use it to compute sequence_length
+            input_energy_shape = tf.shape(input_)
+            raw_input_shape = tf.slice(input_energy_shape, [0], [2])
+            alt_mask = tf.ones(raw_input_shape)
+
+            sequence_length = self.mask_to_sequence_length(alt_mask)
+
+        return sequence_length
+
+    def mask_to_sequence_length(self, mask):
+        """Compute sequence length from mask."""
+        sequence_length = tf.cast(tf.reduce_sum(tf.cast(mask, tf.int8), 1), tf.int64)
+        return sequence_length
+
+    @staticmethod
+    def _compute_mask_right_boundary(mask):
+        """input mask: 0011100, output right_boundary: 0000100."""
+        # shift mask to the left by 1: 0011100 => 0111000
+        offset = 1
+        left_shifted_mask = tf.concat(
+            [mask[:, offset:], tf.zeros_like(mask[:, :offset])], axis=1
+        )
+
+        # NOTE: the code below differs from keras_contrib.
+        # The original keras_contrib code:
+        #     end_mask = K.cast(
+        #         K.greater(self.shift_left(mask), mask),
+        #         K.floatx()
+        #     )
+        # has a bug, confirmed by the original keras_contrib
+        # maintainer Luiz Felix (github: lzfelix).
+
+        # 0011100 > 0111000 => 0000100
+        right_boundary = tf.greater(mask, left_shifted_mask)
+
+        return right_boundary
+
+    @staticmethod
+    def _compute_mask_left_boundary(mask):
+        """input mask: 0011100, output left_boundary: 0010000."""
+        # shift mask to the right by 1: 0011100 => 0001110
+        offset = 1
+        right_shifted_mask = tf.concat(
+            [tf.zeros_like(mask[:, :offset]), mask[:, :-offset]], axis=1
+        )
+
+        # 0011100 > 0001110 => 0010000
+        left_boundary = tf.greater(
+            tf.cast(mask, tf.int32), tf.cast(right_shifted_mask, tf.int32)
+        )
+
+        return left_boundary
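+
+    # Shape sketch for add_boundary_energy below (assuming units=5 and
+    # potentials of shape (batch, time, 5)):
+    #   start, end:           (5,)  -> reshaped to (1, 1, 5)
+    #   start_mask, end_mask: (batch, time, 1), nonzero only at each
+    #                         sequence's first / last unmasked timestep
+    # so `potentials + start_mask * start` adds the left boundary energy
+    # only at the first real timestep (and likewise for `end`).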
+
+    def add_boundary_energy(self, potentials, mask, start, end):
+        def expand_scalar_to_3d(x):
+            # expand tensor from shape (x, ) to (1, 1, x)
+            return tf.reshape(x, (1, 1, -1))
+
+        start = expand_scalar_to_3d(start)
+        end = expand_scalar_to_3d(end)
+        if mask is None:
+            potentials = tf.concat(
+                [potentials[:, :1, :] + start, potentials[:, 1:, :]], axis=1
+            )
+            potentials = tf.concat(
+                [potentials[:, :-1, :], potentials[:, -1:, :] + end], axis=1
+            )
+        else:
+            mask = tf.keras.backend.expand_dims(tf.cast(mask, start.dtype), axis=-1)
+            start_mask = tf.cast(self._compute_mask_left_boundary(mask), start.dtype)
+            end_mask = tf.cast(self._compute_mask_right_boundary(mask), end.dtype)
+            potentials = potentials + start_mask * start
+            potentials = potentials + end_mask * end
+        return potentials
+
+    def get_viterbi_decoding(self, potentials, sequence_length):
+        # decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`
+        decode_tags, best_score = crf_decode(
+            potentials, self.chain_kernel, sequence_length
+        )
+
+        return decode_tags, best_score
+
+    def get_config(self):
+        # used for loading the model from disk
+        config = {
+            "units": self.units,
+            "chain_initializer": tf.keras.initializers.serialize(
+                self.chain_initializer
+            ),
+            "use_boundary": self.use_boundary,
+            "boundary_initializer": tf.keras.initializers.serialize(
+                self.boundary_initializer
+            ),
+            "use_kernel": self.use_kernel,
+        }
+        base_config = super().get_config()
+        return {**base_config, **config}
+
+    def compute_output_shape(self, input_shape):
+        output_shape = input_shape[:2]
+        return output_shape
+
+    def compute_mask(self, input_, mask=None):
+        """Keep the mask shape [batch_size, max_seq_len]."""
+        return mask
+
+    @property
+    def _compute_dtype(self):
+        # fixed output dtype from the underlying CRF functions
+        return tf.int32
diff --git a/tensorflow_addons/layers/tests/crf_test.py b/tensorflow_addons/layers/tests/crf_test.py
new file mode 100644
index 0000000000..ffb3e0442f
--- /dev/null
+++ b/tensorflow_addons/layers/tests/crf_test.py
@@ -0,0 +1,329 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Conditional Random Field layer."""
+
+import itertools
+import os
+import math
+import tempfile
+
+import pytest
+import numpy as np
+import tensorflow as tf
+
+from tensorflow_addons.layers.crf import CRF
+from tensorflow_addons.text.crf import crf_log_likelihood
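+
+
+# Tag vocabulary used throughout these tests: the CRF layer's 5 units
+# correspond, in order, to the tags O, B-X, I-X, B-Y, I-Y (see the
+# column comments in get_test_data below).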
+
+
+def get_test_data():
+    x = np.array(
+        [
+            [
+                # O   B-X  I-X  B-Y  I-Y
+                [0.0, 1.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 1.0, 0.0, 0.0],
+                [0.0, 0.0, 1.0, 0.0, 0.0],
+            ],
+            [
+                # O   B-X  I-X  B-Y  I-Y
+                [0.0, 1.0, 0.0, 0.0, 0.0],
+                [0.0, 1.0, 0.0, 0.0, 0.0],
+                [0.0, 1.0, 0.0, 0.0, 0.0],
+            ],
+        ]
+    )
+    # expected decodings: [B-X, I-X, I-X] and [B-X, B-X, B-X]
+    y = np.array([[1, 2, 2], [1, 1, 1]])
+    return x, y
+
+
+def get_test_data_extended():
+    logits = np.array(
+        [
+            [[0, 0, 0.5, 0.5, 0.2], [0, 0, 0.3, 0.3, 0.1], [0, 0, 0.9, 10, 1]],
+            [[0, 0, 0.2, 0.5, 0.2], [0, 0, 3, 0.3, 0.1], [0, 0, 0.9, 1, 1]],
+        ]
+    )
+    tags = np.array([[2, 3, 4], [3, 2, 2]])
+
+    transitions = np.array(
+        [
+            [0.1, 0.2, 0.3, 0.4, 0.5],
+            [0.8, 0.3, 0.1, 0.7, 0.9],
+            [-0.3, 2.1, -5.6, 3.4, 4.0],
+            [0.2, 0.4, 0.6, -0.3, -0.4],
+            [1.0, 1.0, 1.0, 1.0, 1.0],
+        ]
+    )
+
+    boundary_values = np.ones((5,))
+    crf_layer = CRF(
+        units=5,
+        use_kernel=False,  # disable kernel transform
+        chain_initializer=tf.keras.initializers.Constant(transitions),
+        use_boundary=True,
+        boundary_initializer=tf.keras.initializers.Constant(boundary_values),
+        name="crf_layer",
+    )
+    return logits, tags, transitions, boundary_values, crf_layer
+
+
+def test_keras_model_inference():
+    logits, _, _, _, crf_layer = get_test_data_extended()
+
+    input_tensor = tf.keras.layers.Input(shape=(3, 5))
+    decoded_sequence, _, _, _ = crf_layer(input_tensor)
+    model = tf.keras.Model(input_tensor, decoded_sequence)
+
+    model.predict(logits)
+    model(logits).numpy()
+
+
+def test_unmasked_viterbi_decode():
+    x_np, y_np = get_test_data()
+
+    transitions = np.ones([5, 5])
+    boundary_value = np.ones(5)
+
+    layer = CRF(
+        units=5,
+        use_kernel=False,  # disable kernel transform
+        chain_initializer=tf.keras.initializers.Constant(transitions),
+        use_boundary=True,
+        boundary_initializer=tf.keras.initializers.Constant(boundary_value),
+    )
+
+    decoded_sequence, _, _, _ = layer(x_np)
+    decoded_sequence = decoded_sequence.numpy()
+    np.testing.assert_equal(decoded_sequence, y_np)
+    assert decoded_sequence.dtype == np.int32
+
+
+def unpack_data(data):
+    if len(data) == 2:
+        return data[0], data[1], None
+    elif len(data) == 3:
+        return data
+    else:
+        raise TypeError("Expected data to be a tuple of size 2 or 3.")
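+
+
+# The CRF loss needs tensors that are produced inside the forward pass
+# (potentials, sequence lengths, and the chain kernel), which a plain
+# `model.compile(loss=...)` callback cannot see. The wrapper below
+# therefore overrides train_step/test_step to compute the loss itself.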
+class ModelWithCRFLoss(tf.keras.Model):
+    """Wrapper around the base model for custom training logic."""
+
+    def __init__(self, base_model):
+        super().__init__()
+        self.base_model = base_model
+
+    def call(self, inputs):
+        return self.base_model(inputs)
+
+    def compute_loss(self, x, y, sample_weights, training=False):
+        y_pred = self(x, training=training)
+        _, potentials, sequence_length, chain_kernel = y_pred
+
+        # we now add the CRF loss:
+        crf_loss = -crf_log_likelihood(potentials, y, sequence_length, chain_kernel)[0]
+
+        if sample_weights is not None:
+            crf_loss = crf_loss * sample_weights
+
+        return tf.reduce_mean(crf_loss), sum(self.losses)
+
+    def train_step(self, data):
+        x, y, sample_weight = unpack_data(data)
+
+        with tf.GradientTape() as tape:
+            crf_loss, internal_losses = self.compute_loss(
+                x, y, sample_weight, training=True
+            )
+            total_loss = crf_loss + internal_losses
+
+        gradients = tape.gradient(total_loss, self.trainable_variables)
+        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
+
+        return {"crf_loss": crf_loss, "internal_losses": internal_losses}
+
+    def test_step(self, data):
+        x, y, sample_weight = unpack_data(data)
+        crf_loss, internal_losses = self.compute_loss(x, y, sample_weight)
+        return {"crf_loss_val": crf_loss, "internal_losses_val": internal_losses}
+
+
+def test_training():
+    x_np, y_np = get_test_data()
+    get_some_model(x_np, y_np)
+
+
+def get_some_model(x_np, y_np, sanity_check=True):
+    x_input = tf.keras.layers.Input(shape=x_np.shape[1:])
+    crf_outputs = CRF(5, name="L")(x_input)
+    base_model = tf.keras.Model(x_input, crf_outputs)
+
+    model = ModelWithCRFLoss(base_model)
+
+    model.compile("adam")
+    if sanity_check:
+        model.fit(x=x_np, y=y_np)
+        model.evaluate(x_np, y_np)
+    model.predict(x_np)
+    return model
+
+
+def test_mask_right_padding():
+    x_np, y_np = get_test_data()
+    mask = np.array([[1, 1, 1], [1, 1, 0]])
+
+    x = tf.keras.layers.Input(shape=x_np.shape[1:])
+
+    crf_layer_outputs = CRF(5)(x, mask=tf.constant(mask))
+
+    base_model = tf.keras.Model(x, crf_layer_outputs)
+    model = ModelWithCRFLoss(base_model)
+
+    # check shape inference
+    model.compile("adam")
+    old_weights = model.get_weights()
+    model.fit(x_np, y_np)
+    new_weights = model.get_weights()
+
+    # we check that the weights were updated during the training phase
+    with pytest.raises(AssertionError):
+        assert_all_equal(old_weights, new_weights)
+
+    model.predict(x_np)
+
+
+def test_mask_left_padding():
+    x_np, y_np = get_test_data()
+    mask = np.array([[0, 1, 1], [1, 1, 1]])
+
+    x = tf.keras.layers.Input(shape=x_np.shape[1:])
+    crf_layer_outputs = CRF(5)(x, mask=tf.constant(mask))
+
+    base_model = tf.keras.Model(x, crf_layer_outputs)
+    model = ModelWithCRFLoss(base_model)
+
+    # we can only check the value of the mask if we run eagerly
+    # (it's kind of a debug mode); otherwise we'd be wasting computation
+    model.compile("adam", run_eagerly=True)
+
+    with pytest.raises(NotImplementedError) as context:
+        model(x_np).numpy()
+
+    assert "CRF layer does not support left padding" in str(context.value)
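+
+
+# clone() round-trips a model through the TensorFlow SavedModel format
+# (save to disk, then load back); the serialization tests below use it
+# to exercise get_config() and loading from disk.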
+
+
+def clone(model: ModelWithCRFLoss, inference_only=True):
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        file_path = os.path.join(tmpdir, "my_model.tf")
+        model.save(file_path)
+        new_model = tf.keras.models.load_model(file_path)
+
+    if not inference_only:
+        # since tf doesn't save the python code of train_step and test_step,
+        # we need to apply the wrapper again.
+        # This may change; maybe later on tf will save train_step and test_step.
+        new_model_with_wrapper = ModelWithCRFLoss(new_model.base_model)
+
+        # this works, but it may be cleaner to do a copy of the optimizer
+        new_model_with_wrapper.compile(optimizer=new_model.optimizer)
+        new_model = new_model_with_wrapper
+
+    return new_model
+
+
+def assert_all_equal(array_list1, array_list2):
+    for arr1, arr2 in zip(array_list1, array_list2):
+        np.testing.assert_equal(np.array(arr1), np.array(arr2))
+
+
+@pytest.mark.parametrize("inference_only", [True, False])
+def test_serialization(inference_only):
+    x_np, y_np = get_test_data()
+    model = get_some_model(x_np, y_np, sanity_check=False)
+
+    new_model = clone(model, inference_only)
+    if inference_only:
+        assert_all_equal(model.predict(x_np), new_model.predict(x_np))
+        assert_all_equal(model.get_weights(), new_model.get_weights())
+    else:
+        original_loss = model.train_on_batch(x_np, y_np, return_dict=True)["crf_loss"]
+        clone_loss = new_model.train_on_batch(x_np, y_np, return_dict=True)["crf_loss"]
+        assert_all_equal(model.get_weights(), new_model.get_weights())
+        assert original_loss == clone_loss
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+def test_numerical_accuracy():
+    logits, tags, transitions, boundary_values, crf_layer = get_test_data_extended()
+
+    x_input = tf.keras.layers.Input(shape=logits.shape[1:])
+    crf_outputs = crf_layer(x_input)
+    base_model = tf.keras.Model(x_input, crf_outputs)
+    model = ModelWithCRFLoss(base_model)
+
+    model.compile()
+    log_likelihood = model.train_on_batch(logits, tags, return_dict=True)["crf_loss"]
+
+    # The manually computed log likelihood should
+    # equal the result of crf_log_likelihood.
+    expected_log_likelihood = compute_log_likelihood(
+        logits, tags, transitions, boundary_values
+    )
+    # crf_loss is the mean negative log likelihood over the batch of 2,
+    # so -2 * loss recovers the summed log likelihood
+    unbatched_log_likelihood = -2 * log_likelihood
+
+    np.testing.assert_allclose(
+        expected_log_likelihood, unbatched_log_likelihood, rtol=2e-7
+    )
+
+
+def compute_log_likelihood(logits, tags, transitions, boundary_values):
+    # Now compute the log-likelihood manually
+    manual_log_likelihood = 0.0
+
+    # For each instance, manually compute the numerator
+    # (which is just the score for the logits and actual tags)
+    # and the denominator
+    # (which is the log-sum-exp of the scores
+    # for the logits across all possible tags)
+    for logits_i, tags_i in zip(logits, tags):
+        numerator = score_logits(logits_i, tags_i, transitions, boundary_values)
+        all_scores = [
+            score_logits(logits_i, tags_j, transitions, boundary_values)
+            for tags_j in itertools.product(range(5), repeat=3)
+        ]
+        denominator = math.log(sum(math.exp(score) for score in all_scores))
+        # And include them in the manual calculation.
+        manual_log_likelihood += numerator - denominator
+
+    return manual_log_likelihood
+
+
+def score_logits(logits, tags, transitions, boundary_values):
+    """Computes the likelihood score for the given sequence of tags, given
+    the provided logits (and the transition weights in the CRF model)."""
+    # Start with the boundary energies for the first and last tags
+    total = boundary_values[tags[0]] + boundary_values[tags[-1]]
+    # Add in all the intermediate transitions
+    for tag, next_tag in zip(tags, tags[1:]):
+        total += transitions[tag, next_tag]
+    # Add in the logits for the observed tags
+    for logit, tag in zip(logits, tags):
+        total += logit[tag]
+    return total
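
A minimal end-to-end sketch of how the layer and the loss wrapper fit
together (a sketch only: it assumes the `ModelWithCRFLoss` wrapper from
`crf_test.py` above is in scope, and the tag count and shapes are
arbitrary):

    import numpy as np
    import tensorflow as tf

    from tensorflow_addons.layers.crf import CRF

    # toy data: 2 sequences of length 3, 5 features per timestep
    x = np.random.uniform(size=(2, 3, 5)).astype("float32")
    y = np.array([[1, 2, 2], [1, 1, 1]])  # gold tag ids

    inputs = tf.keras.layers.Input(shape=(3, 5))
    base_model = tf.keras.Model(inputs, CRF(units=5)(inputs))

    # the wrapper defined in crf_test.py computes the CRF
    # log-likelihood loss inside train_step
    model = ModelWithCRFLoss(base_model)
    model.compile("adam")
    model.fit(x, y, epochs=1)

    decoded, potentials, seq_lens, kernel = base_model.predict(x)
    # decoded has shape (2, 3) and dtype int32: the Viterbi tag ids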