diff --git a/tensorflow_addons/text/BUILD b/tensorflow_addons/text/BUILD index 4787cd8c0c..21306ef3f9 100644 --- a/tensorflow_addons/text/BUILD +++ b/tensorflow_addons/text/BUILD @@ -6,6 +6,7 @@ py_library( name = "text", srcs = ([ "__init__.py", + "crf.py", "skip_gram_ops.py", ]), data = [ @@ -15,6 +16,19 @@ py_library( srcs_version = "PY2AND3", ) +py_test( + name = "crf_test", + size = "small", + srcs = [ + "crf_test.py", + ], + main = "crf_test.py", + srcs_version = "PY2AND3", + deps = [ + ":text", + ], +) + py_test( name = "skip_gram_ops_test", size = "small", diff --git a/tensorflow_addons/text/README.md b/tensorflow_addons/text/README.md index 4b4d948363..d6e60a07b9 100644 --- a/tensorflow_addons/text/README.md +++ b/tensorflow_addons/text/README.md @@ -4,6 +4,7 @@ | Submodule | Maintainers | Contact Info | |:---------- |:----------- |:------------- | | skip_gram_ops | | | +| crf | Dheeraj R. Reddy | dheeraj98reddy@gmail.com | ## Components | Submodule | Text Processing Function | Reference | diff --git a/tensorflow_addons/text/__init__.py b/tensorflow_addons/text/__init__.py index 05c758e26d..11f8f9fecb 100644 --- a/tensorflow_addons/text/__init__.py +++ b/tensorflow_addons/text/__init__.py @@ -17,6 +17,19 @@ from __future__ import division from __future__ import print_function +# Conditional Random Field +from tensorflow_addons.text.crf import crf_binary_score +from tensorflow_addons.text.crf import crf_decode +from tensorflow_addons.text.crf import crf_decode_backward +from tensorflow_addons.text.crf import crf_decode_forward +from tensorflow_addons.text.crf import crf_forward +from tensorflow_addons.text.crf import crf_log_likelihood +from tensorflow_addons.text.crf import crf_log_norm +from tensorflow_addons.text.crf import crf_multitag_sequence_score +from tensorflow_addons.text.crf import crf_sequence_score +from tensorflow_addons.text.crf import crf_unary_score +from tensorflow_addons.text.crf import viterbi_decode + # Skip Gram Sampling from tensorflow_addons.text.skip_gram_ops import skip_gram_sample -from tensorflow_addons.text.skip_gram_ops import skip_gram_sample_with_text_vocab +from tensorflow_addons.text.skip_gram_ops import skip_gram_sample_with_text_vocab \ No newline at end of file diff --git a/tensorflow_addons/text/crf.py b/tensorflow_addons/text/crf.py new file mode 100644 index 0000000000..d8d97bf216 --- /dev/null +++ b/tensorflow_addons/text/crf.py @@ -0,0 +1,488 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +# TODO: Wrap functions in @tf.function once +# https://github.com/tensorflow/tensorflow/issues/29075 is resolved + + +def crf_sequence_score(inputs, tag_indices, sequence_lengths, + transition_params): + """Computes the unnormalized score for a tag sequence. + + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which + we compute the unnormalized score. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] transition matrix. + Returns: + sequence_scores: A [batch_size] vector of unnormalized sequence scores. + """ + + # If max_seq_len is 1, we skip the score calculation and simply gather the + # unary potentials of the single tag. + def _single_seq_fn(): + batch_size = tf.shape(inputs, out_type=tag_indices.dtype)[0] + + example_inds = tf.reshape( + tf.range(batch_size, dtype=tag_indices.dtype), [-1, 1]) + sequence_scores = tf.gather_nd( + tf.squeeze(inputs, [1]), + tf.concat([example_inds, tag_indices], axis=1)) + sequence_scores = tf.where( + tf.less_equal(sequence_lengths, 0), tf.zeros_like(sequence_scores), + sequence_scores) + return sequence_scores + + def _multi_seq_fn(): + # Compute the scores of the given tag sequence. + unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs) + binary_scores = crf_binary_score(tag_indices, sequence_lengths, + transition_params) + sequence_scores = unary_scores + binary_scores + return sequence_scores + + if inputs.shape[1] == 1: + return _single_seq_fn() + else: + return _multi_seq_fn() + + +def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths, + transition_params): + """Computes the unnormalized score of all tag sequences matching + tag_bitmap. + + tag_bitmap enables more than one tag to be considered correct at each time + step. This is useful when an observed output at a given time step is + consistent with more than one tag, and thus the log likelihood of that + observation must take into account all possible consistent tags. + + Using one-hot vectors in tag_bitmap gives results identical to + crf_sequence_score. + + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor + representing all active tags at each index for which to calculate the + unnormalized score. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] transition matrix. + Returns: + sequence_scores: A [batch_size] vector of unnormalized sequence scores. + """ + + # If max_seq_len is 1, we skip the score calculation and simply gather the + # unary potentials of all active tags. + def _single_seq_fn(): + filtered_inputs = tf.where(tag_bitmap, inputs, + tf.fill(tf.shape(inputs), float("-inf"))) + return tf.reduce_logsumexp( + filtered_inputs, axis=[1, 2], keepdims=False) + + def _multi_seq_fn(): + # Compute the logsumexp of all scores of sequences matching the given tags. 
+ filtered_inputs = tf.where(tag_bitmap, inputs, + tf.fill(tf.shape(inputs), float("-inf"))) + return crf_log_norm( + inputs=filtered_inputs, + sequence_lengths=sequence_lengths, + transition_params=transition_params) + + if inputs.shape[1] == 1: + return _single_seq_fn() + else: + return _multi_seq_fn() + + +def crf_log_norm(inputs, sequence_lengths, transition_params): + """Computes the normalization for a CRF. + + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] transition matrix. + Returns: + log_norm: A [batch_size] vector of normalizers for a CRF. + """ + # Split up the first and rest of the inputs in preparation for the forward + # algorithm. + first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1]) + first_input = tf.squeeze(first_input, [1]) + + # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over + # the "initial state" (the unary potentials). + def _single_seq_fn(): + log_norm = tf.reduce_logsumexp(first_input, [1]) + # Mask `log_norm` of the sequences with length <= zero. + log_norm = tf.where( + tf.less_equal(sequence_lengths, 0), tf.zeros_like(log_norm), + log_norm) + return log_norm + + def _multi_seq_fn(): + """Forward computation of alpha values.""" + rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1]) + # Compute the alpha values in the forward algorithm in order to get the + # partition function. + + alphas = crf_forward(rest_of_input, first_input, transition_params, + sequence_lengths) + log_norm = tf.reduce_logsumexp(alphas, [1]) + # Mask `log_norm` of the sequences with length <= zero. + log_norm = tf.where( + tf.less_equal(sequence_lengths, 0), tf.zeros_like(log_norm), + log_norm) + return log_norm + + if inputs.shape[1] == 1: + return _single_seq_fn() + else: + return _multi_seq_fn() + + +def crf_log_likelihood(inputs, + tag_indices, + sequence_lengths, + transition_params=None): + """Computes the log-likelihood of tag sequences in a CRF. + + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which + we compute the log-likelihood. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] transition matrix, + if available. + Returns: + log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of + each example, given the sequence of tag indices. + transition_params: A [num_tags, num_tags] transition matrix. This is + either provided by the caller or created in this function. + """ + # Get shape information. + num_tags = inputs.shape[2] + + # Get the transition matrix if not provided. + if transition_params is None: + transition_params = tf.get_variable("transitions", + [num_tags, num_tags]) + + sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths, + transition_params) + log_norm = crf_log_norm(inputs, sequence_lengths, transition_params) + + # Normalize the scores to get the log-likelihood per example. + log_likelihood = sequence_scores - log_norm + return log_likelihood, transition_params + + +def crf_unary_score(tag_indices, sequence_lengths, inputs): + """Computes the unary scores of tag sequences. + + Args: + tag_indices: A [batch_size, max_seq_len] matrix of tag indices. + sequence_lengths: A [batch_size] vector of true sequence lengths. 
+ inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials. + Returns: + unary_scores: A [batch_size] vector of unary scores. + """ + batch_size = tf.shape(inputs)[0] + max_seq_len = tf.shape(inputs)[1] + num_tags = tf.shape(inputs)[2] + + flattened_inputs = tf.reshape(inputs, [-1]) + + offsets = tf.expand_dims(tf.range(batch_size) * max_seq_len * num_tags, 1) + offsets += tf.expand_dims(tf.range(max_seq_len) * num_tags, 0) + # Use int32 or int64 based on tag_indices' dtype. + if tag_indices.dtype == tf.int64: + offsets = tf.cast(offsets, tf.int64) + flattened_tag_indices = tf.reshape(offsets + tag_indices, [-1]) + + unary_scores = tf.reshape( + tf.gather(flattened_inputs, flattened_tag_indices), + [batch_size, max_seq_len]) + + masks = tf.sequence_mask( + sequence_lengths, maxlen=tf.shape(tag_indices)[1], dtype=tf.float32) + + unary_scores = tf.reduce_sum(unary_scores * masks, 1) + return unary_scores + + +def crf_binary_score(tag_indices, sequence_lengths, transition_params): + """Computes the binary scores of tag sequences. + + Args: + tag_indices: A [batch_size, max_seq_len] matrix of tag indices. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] matrix of binary potentials. + Returns: + binary_scores: A [batch_size] vector of binary scores. + """ + # Get shape information. + num_tags = tf.shape(transition_params)[0] + num_transitions = tf.shape(tag_indices)[1] - 1 + + # Truncate by one on each side of the sequence to get the start and end + # indices of each transition. + start_tag_indices = tf.slice(tag_indices, [0, 0], [-1, num_transitions]) + end_tag_indices = tf.slice(tag_indices, [0, 1], [-1, num_transitions]) + + # Encode the indices in a flattened representation. + flattened_transition_indices = start_tag_indices * \ + num_tags + end_tag_indices + flattened_transition_params = tf.reshape(transition_params, [-1]) + + # Get the binary scores based on the flattened representation. + binary_scores = tf.gather(flattened_transition_params, + flattened_transition_indices) + + masks = tf.sequence_mask( + sequence_lengths, maxlen=tf.shape(tag_indices)[1], dtype=tf.float32) + truncated_masks = tf.slice(masks, [0, 1], [-1, -1]) + binary_scores = tf.reduce_sum(binary_scores * truncated_masks, 1) + return binary_scores + + +def crf_forward(inputs, state, transition_params, sequence_lengths): + """Computes the alpha values in a linear-chain CRF. + + See http://www.cs.columbia.edu/~mcollins/fb.pdf for reference. + + Args: + inputs: A [batch_size, num_tags] matrix of unary potentials. + state: A [batch_size, num_tags] matrix containing the previous alpha + values. + transition_params: A [num_tags, num_tags] matrix of binary potentials. + This matrix is expanded into a [1, num_tags, num_tags] in preparation + for the broadcast summation occurring within the cell. + sequence_lengths: A [batch_size] vector of true sequence lengths. + + Returns: + new_alphas: A [batch_size, num_tags] matrix containing the + new alpha values. 
+ """ + + sequence_lengths = tf.maximum( + tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 2) + inputs = tf.transpose(inputs, [1, 0, 2]) + transition_params = tf.expand_dims(transition_params, 0) + + def _scan_fn(state, inputs): + state = tf.expand_dims(state, 2) + transition_scores = state + transition_params + new_alphas = inputs + tf.reduce_logsumexp(transition_scores, [1]) + return new_alphas + + all_alphas = tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2]) + idxs = tf.stack( + [tf.range(tf.shape(sequence_lengths)[0]), sequence_lengths], axis=1) + return tf.gather_nd(all_alphas, idxs) + + +def viterbi_decode(score, transition_params): + """Decode the highest scoring sequence of tags outside of TensorFlow. + + This should only be used at test time. + + Args: + score: A [seq_len, num_tags] matrix of unary potentials. + transition_params: A [num_tags, num_tags] matrix of binary potentials. + + Returns: + viterbi: A [seq_len] list of integers containing the highest scoring tag + indices. + viterbi_score: A float containing the score for the Viterbi sequence. + """ + trellis = np.zeros_like(score) + backpointers = np.zeros_like(score, dtype=np.int32) + trellis[0] = score[0] + + for t in range(1, score.shape[0]): + v = np.expand_dims(trellis[t - 1], 1) + transition_params + trellis[t] = score[t] + np.max(v, 0) + backpointers[t] = np.argmax(v, 0) + + viterbi = [np.argmax(trellis[-1])] + for bp in reversed(backpointers[1:]): + viterbi.append(bp[viterbi[-1]]) + viterbi.reverse() + + viterbi_score = np.max(trellis[-1]) + return viterbi, viterbi_score + + +class CrfDecodeForwardRnnCell(tf.keras.layers.AbstractRNNCell): + """Computes the forward decoding in a linear-chain CRF.""" + + def __init__(self, transition_params, **kwargs): + """Initialize the CrfDecodeForwardRnnCell. + + Args: + transition_params: A [num_tags, num_tags] matrix of binary + potentials. This matrix is expanded into a + [1, num_tags, num_tags] in preparation for the broadcast + summation occurring within the cell. + """ + super(CrfDecodeForwardRnnCell, self).__init__(**kwargs) + self._transition_params = tf.expand_dims(transition_params, 0) + self._num_tags = transition_params.shape[0] + + @property + def state_size(self): + return self._num_tags + + @property + def output_size(self): + return self._num_tags + + def build(self, input_shape): + super(CrfDecodeForwardRnnCell, self).build(input_shape) + + def call(self, inputs, state): + """Build the CrfDecodeForwardRnnCell. + + Args: + inputs: A [batch_size, num_tags] matrix of unary potentials. + state: A [batch_size, num_tags] matrix containing the previous step's + score values. + + Returns: + backpointers: A [batch_size, num_tags] matrix of backpointers. + new_state: A [batch_size, num_tags] matrix of new score values. + """ + state = tf.expand_dims(state[0], 2) + transition_scores = state + self._transition_params + new_state = inputs + tf.reduce_max(transition_scores, [1]) + backpointers = tf.argmax(transition_scores, 1) + backpointers = tf.cast(backpointers, dtype=tf.int32) + return backpointers, new_state + + +def crf_decode_forward(inputs, state, transition_params, sequence_lengths): + """Computes forward decoding in a linear-chain CRF. + + Args: + inputs: A [batch_size, num_tags] matrix of unary potentials. + state: A [batch_size, num_tags] matrix containing the previous step's + score values. + transition_params: A [num_tags, num_tags] matrix of binary potentials. + sequence_lengths: A [batch_size] vector of true sequence lengths. 
+ + Returns: + backpointers: A [batch_size, num_tags] matrix of backpointers. + new_state: A [batch_size, num_tags] matrix of new score values. + """ + mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1]) + crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) + crf_fwd_layer = tf.keras.layers.RNN( + crf_fwd_cell, return_sequences=True, return_state=True) + return crf_fwd_layer(inputs, state, mask=mask) + + +def crf_decode_backward(inputs, state): + """Computes backward decoding in a linear-chain CRF. + + Args: + inputs: A [batch_size, num_tags] matrix of + backpointer of next step (in time order). + state: A [batch_size, 1] matrix of tag index of next step. + + Returns: + new_tags: A [batch_size, num_tags] + tensor containing the new tag indices. + """ + inputs = tf.transpose(inputs, [1, 0, 2]) + + def _scan_fn(state, inputs): + state = tf.squeeze(state, axis=[1]) + idxs = tf.stack([tf.range(tf.shape(inputs)[0]), state], axis=1) + new_tags = tf.expand_dims(tf.gather_nd(inputs, idxs), axis=-1) + return new_tags + + return tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2]) + + +def crf_decode(potentials, transition_params, sequence_length): + """Decode the highest scoring sequence of tags in TensorFlow. + + This is a function for tensor. + + Args: + potentials: A [batch_size, max_seq_len, num_tags] tensor of + unary potentials. + transition_params: A [num_tags, num_tags] matrix of + binary potentials. + sequence_length: A [batch_size] vector of true sequence lengths. + + Returns: + decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`. + Contains the highest scoring tag indices. + best_score: A [batch_size] vector, containing the score of `decode_tags`. + """ + + # If max_seq_len is 1, we skip the algorithm and simply return the argmax tag + # and the max activation. + def _single_seq_fn(): + squeezed_potentials = tf.squeeze(potentials, [1]) + decode_tags = tf.expand_dims(tf.argmax(squeezed_potentials, axis=1), 1) + best_score = tf.reduce_max(squeezed_potentials, axis=1) + return tf.cast(decode_tags, dtype=tf.int32), best_score + + def _multi_seq_fn(): + """Decoding of highest scoring sequence.""" + # Computes forward decoding. Get last score and backpointers. 
+ initial_state = tf.slice(potentials, [0, 0, 0], [-1, 1, -1]) + initial_state = tf.squeeze(initial_state, axis=[1]) + inputs = tf.slice(potentials, [0, 1, 0], [-1, -1, -1]) + + sequence_length_less_one = tf.maximum( + tf.constant(0, dtype=sequence_length.dtype), sequence_length - 1) + + backpointers, last_score = crf_decode_forward( + inputs, initial_state, transition_params, sequence_length_less_one) + + backpointers = tf.reverse_sequence( + backpointers, sequence_length_less_one, seq_axis=1) + + initial_state = tf.cast(tf.argmax(last_score, axis=1), dtype=tf.int32) + initial_state = tf.expand_dims(initial_state, axis=-1) + + decode_tags = crf_decode_backward(backpointers, initial_state) + decode_tags = tf.squeeze(decode_tags, axis=[2]) + decode_tags = tf.concat([initial_state, decode_tags], axis=1) + decode_tags = tf.reverse_sequence( + decode_tags, sequence_length, seq_axis=1) + + best_score = tf.reduce_max(last_score, axis=1) + return decode_tags, best_score + + if potentials.shape[1] == 1: + return _single_seq_fn() + else: + return _multi_seq_fn() diff --git a/tensorflow_addons/text/crf_test.py b/tensorflow_addons/text/crf_test.py new file mode 100644 index 0000000000..84c09b539b --- /dev/null +++ b/tensorflow_addons/text/crf_test.py @@ -0,0 +1,346 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for CRF.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import numpy as np +import tensorflow as tf + +from tensorflow_addons import text +from tensorflow_addons.utils import test_utils + + +@test_utils.run_all_in_graph_and_eager_modes +class CrfTest(tf.test.TestCase): + def calculateSequenceScore(self, inputs, transition_params, tag_indices, + sequence_lengths): + expected_unary_score = sum( + inputs[i][tag_indices[i]] for i in range(sequence_lengths)) + expected_binary_score = sum( + transition_params[tag_indices[i], tag_indices[i + 1]] + for i in range(sequence_lengths - 1)) + return expected_unary_score + expected_binary_score + + def testCrfSequenceScore(self): + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], + dtype=np.float32) + # Test both the length-1 and regular cases. 
+ sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[4, 5, -3]], dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([1], dtype=np.int32) + ] + for sequence_lengths, inputs, tag_indices in zip( + sequence_lengths_list, inputs_list, tag_indices_list): + sequence_score = text.crf_sequence_score( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + sequence_score = tf.squeeze(sequence_score, [0]) + + tf_sequence_score = self.evaluate(sequence_score) + + expected_sequence_score = self.calculateSequenceScore( + inputs, transition_params, tag_indices, sequence_lengths) + self.assertAllClose(tf_sequence_score, expected_sequence_score) + + def testCrfMultiTagSequenceScore(self): + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], + dtype=np.float32) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[4, 5, -3]], dtype=np.float32), + ] + tag_bitmap_list = [ + np.array([[True, True, False], [True, False, True], + [False, True, True], [True, False, True]], + dtype=np.bool), + np.array([[True, True, False]], dtype=np.bool) + ] + for sequence_lengths, inputs, tag_bitmap in zip( + sequence_lengths_list, inputs_list, tag_bitmap_list): + sequence_score = text.crf_multitag_sequence_score( + inputs=tf.expand_dims(inputs, 0), + tag_bitmap=tf.expand_dims(tag_bitmap, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + sequence_score = tf.squeeze(sequence_score, [0]) + tf_sum_sequence_score = self.evaluate(sequence_score) + all_indices_list = [ + single_index_bitmap.nonzero()[0] + for single_index_bitmap in tag_bitmap[:sequence_lengths] + ] + expected_sequence_scores = [ + self.calculateSequenceScore(inputs, transition_params, indices, + sequence_lengths) + for indices in itertools.product(*all_indices_list) + ] + expected_log_sum_exp_sequence_scores = np.logaddexp.reduce( + expected_sequence_scores) + self.assertAllClose(tf_sum_sequence_score, + expected_log_sum_exp_sequence_scores) + + def testCrfUnaryScore(self): + inputs = np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32) + for dtype in (np.int32, np.int64): + tag_indices = np.array([1, 2, 1, 0], dtype=dtype) + sequence_lengths = np.array(3, dtype=np.int32) + unary_score = text.crf_unary_score( + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + inputs=tf.expand_dims(inputs, 0)) + unary_score = tf.squeeze(unary_score, [0]) + tf_unary_score = self.evaluate(unary_score) + expected_unary_score = sum( + inputs[i][tag_indices[i]] for i in range(sequence_lengths)) + self.assertAllClose(tf_unary_score, expected_unary_score) + + def testCrfBinaryScore(self): + tag_indices = np.array([1, 2, 1, 0], dtype=np.int32) + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], + dtype=np.float32) + sequence_lengths = np.array(3, dtype=np.int32) + binary_score = text.crf_binary_score( + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 
0), + transition_params=tf.constant(transition_params)) + binary_score = tf.squeeze(binary_score, [0]) + tf_binary_score = self.evaluate(binary_score) + expected_binary_score = sum( + transition_params[tag_indices[i], tag_indices[i + 1]] + for i in range(sequence_lengths - 1)) + self.assertAllClose(tf_binary_score, expected_binary_score) + + def testCrfLogNorm(self): + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], + dtype=np.float32) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int64) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[3, -1, 3]], dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([2], dtype=np.int32) + ] + + for sequence_lengths, inputs, tag_indices in zip( + sequence_lengths_list, inputs_list, tag_indices_list): + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + all_sequence_scores = [] + + # Compare the dynamic program with brute force computation. + for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequence_scores.append( + text.crf_sequence_score( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params))) + + brute_force_log_norm = tf.reduce_logsumexp(all_sequence_scores) + log_norm = text.crf_log_norm( + inputs=tf.expand_dims(inputs, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + log_norm = tf.squeeze(log_norm, [0]) + tf_brute_force_log_norm, tf_log_norm = self.evaluate( + [brute_force_log_norm, log_norm]) + + self.assertAllClose(tf_log_norm, tf_brute_force_log_norm) + + def testCrfLogNormZeroSeqLength(self): + """Test `crf_log_norm` when `sequence_lengths` contains one or more + zeros.""" + inputs = tf.constant(np.ones([2, 10, 5], dtype=np.float32)) + transition_params = tf.constant(np.ones([5, 5], dtype=np.float32)) + sequence_lengths = tf.constant(np.zeros([2], dtype=np.int32)) + expected_log_norm = np.zeros([2], dtype=np.float32) + log_norm = text.crf_log_norm(inputs, sequence_lengths, + transition_params) + tf_log_norm = self.evaluate(log_norm) + self.assertAllClose(tf_log_norm, expected_log_norm) + + def testCrfLogLikelihood(self): + inputs = np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32) + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], + dtype=np.float32) + sequence_lengths = np.array(3, dtype=np.int32) + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + all_sequence_log_likelihoods = [] + + # Make sure all probabilities sum to 1. 
+ for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + sequence_log_likelihood, _ = text.crf_log_likelihood( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + all_sequence_log_likelihoods.append(sequence_log_likelihood) + total_log_likelihood = tf.reduce_logsumexp( + all_sequence_log_likelihoods) + tf_total_log_likelihood = self.evaluate(total_log_likelihood) + self.assertAllClose(tf_total_log_likelihood, 0.0) + + def testViterbiDecode(self): + inputs = np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32) + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], + dtype=np.float32) + sequence_lengths = np.array(3, dtype=np.int32) + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + + all_sequence_scores = [] + all_sequences = [] + + # Compare the dynamic program with brute force computation. + for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequences.append(tag_indices) + sequence_score = text.crf_sequence_score( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + sequence_score = tf.squeeze(sequence_score, [0]) + all_sequence_scores.append(sequence_score) + + tf_all_sequence_scores = self.evaluate(all_sequence_scores) + + expected_max_sequence_index = np.argmax(tf_all_sequence_scores) + expected_max_sequence = all_sequences[expected_max_sequence_index] + expected_max_score = tf_all_sequence_scores[ + expected_max_sequence_index] + + actual_max_sequence, actual_max_score = text.viterbi_decode( + inputs[:sequence_lengths], transition_params) + + self.assertAllClose(actual_max_score, expected_max_score) + self.assertEqual(actual_max_sequence, + expected_max_sequence[:sequence_lengths]) + + def testCrfDecode(self): + transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], + dtype=np.float32) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int64) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[-1, 2, 1]], dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([2], dtype=np.int32) + ] + + for sequence_lengths, inputs, tag_indices in zip( + sequence_lengths_list, inputs_list, tag_indices_list): + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + + all_sequence_scores = [] + all_sequences = [] + + # Compare the dynamic program with brute force computation. 
+ for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequences.append(tag_indices) + sequence_score = text.crf_sequence_score( + inputs=tf.expand_dims(inputs, 0), + tag_indices=tf.expand_dims(tag_indices, 0), + sequence_lengths=tf.expand_dims(sequence_lengths, 0), + transition_params=tf.constant(transition_params)) + sequence_score = tf.squeeze(sequence_score, [0]) + all_sequence_scores.append(sequence_score) + + tf_all_sequence_scores = self.evaluate(all_sequence_scores) + + expected_max_sequence_index = np.argmax(tf_all_sequence_scores) + expected_max_sequence = all_sequences[expected_max_sequence_index] + expected_max_score = tf_all_sequence_scores[ + expected_max_sequence_index] + + actual_max_sequence, actual_max_score = text.crf_decode( + tf.expand_dims(inputs, 0), tf.constant(transition_params), + tf.expand_dims(sequence_lengths, 0)) + actual_max_sequence = tf.squeeze(actual_max_sequence, [0]) + actual_max_score = tf.squeeze(actual_max_score, [0]) + tf_actual_max_sequence, tf_actual_max_score = self.evaluate( + [actual_max_sequence, actual_max_score]) + + self.assertAllClose(tf_actual_max_score, expected_max_score) + self.assertEqual( + list(tf_actual_max_sequence[:sequence_lengths]), + expected_max_sequence[:sequence_lengths]) + + def testCrfDecodeZeroSeqLength(self): + """Test that crf_decode works when sequence_length contains one or more + zeros.""" + inputs = tf.constant(np.ones([2, 10, 5], dtype=np.float32)) + transition_params = tf.constant(np.ones([5, 5], dtype=np.float32)) + sequence_lengths = tf.constant(np.zeros([2], dtype=np.int32)) + tags, scores = text.crf_decode(inputs, transition_params, + sequence_lengths) + tf_tags, tf_scores = self.evaluate([tags, scores]) + self.assertEqual(len(tf_tags.shape), 2) + self.assertEqual(len(tf_scores.shape), 1) + + +if __name__ == "__main__": + tf.test.main()
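Usage sketch (not part of the patch): a minimal end-to-end example of how the ops added here could be wired together, assuming the package is importable as `tensorflow_addons.text` and TF 2.x eager execution. The shapes, the toy data, and the externally managed `tf.Variable` for the transition matrix are illustrative assumptions, not something this diff prescribes.

```python
import tensorflow as tf
from tensorflow_addons import text

batch_size, max_seq_len, num_tags = 2, 4, 3

# Toy unary potentials (e.g. the output of a tagging network), gold tags,
# true sequence lengths, and an externally managed transition matrix.
unary = tf.random.normal([batch_size, max_seq_len, num_tags])
tags = tf.constant([[1, 2, 1, 0], [0, 1, 0, 0]], dtype=tf.int32)
lengths = tf.constant([4, 2], dtype=tf.int32)
transitions = tf.Variable(tf.random.normal([num_tags, num_tags]))

with tf.GradientTape() as tape:
    # crf_log_likelihood returns the per-example log-likelihood and the
    # transition matrix it used (here, the one passed in).
    log_likelihood, _ = text.crf_log_likelihood(
        unary, tags, lengths, transition_params=transitions)
    loss = -tf.reduce_mean(log_likelihood)

# Gradients flow through both the sequence score and the log-normalizer,
# so the transition matrix can be trained jointly with the tagger.
grads = tape.gradient(loss, [transitions])

# Inference: batched Viterbi decoding with the tensor-based crf_decode.
decoded_tags, best_scores = text.crf_decode(unary, transitions, lengths)
```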
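A second sketch, again illustrative rather than part of the patch, contrasting the NumPy-based `viterbi_decode` (which the docstring restricts to test time and which handles one sequence at a time) with the batched, in-graph `crf_decode`. The constants are borrowed from `crf_test.py`; agreement between the two paths is what `testViterbiDecode` and `testCrfDecode` verify against brute-force enumeration.

```python
import numpy as np
import tensorflow as tf
from tensorflow_addons import text

# Constants borrowed from crf_test.py (sequence length 3, 3 tags).
unary = np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1]], dtype=np.float32)
transitions = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)

# NumPy path: one [seq_len, num_tags] matrix, outside the TensorFlow graph.
viterbi, viterbi_score = text.viterbi_decode(unary, transitions)

# Tensor path: batched potentials plus explicit sequence lengths.
tags, scores = text.crf_decode(
    tf.constant(unary[None]),
    tf.constant(transitions),
    tf.constant([3], dtype=tf.int32))

print(viterbi, viterbi_score)               # highest scoring tags and score
print(tags.numpy()[0], scores.numpy()[0])   # should agree with the line above
```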