diff --git a/tensorflow_addons/text/crf.py b/tensorflow_addons/text/crf.py index 8bc9901128..63f0e6ae6e 100644 --- a/tensorflow_addons/text/crf.py +++ b/tensorflow_addons/text/crf.py @@ -443,8 +443,10 @@ def call(self, inputs, state): new_state: A [batch_size, num_tags] matrix of new score values. """ state = tf.expand_dims(state[0], 2) - transition_scores = state + tf.cast(self._transition_params, state.dtype) - new_state = tf.cast(inputs, state.dtype) + tf.reduce_max(transition_scores, [1]) + transition_scores = state + tf.cast( + self._transition_params, self._compute_dtype + ) + new_state = inputs + tf.reduce_max(transition_scores, [1]) backpointers = tf.argmax(transition_scores, 1) backpointers = tf.cast(backpointers, dtype=tf.int32) return backpointers, new_state @@ -485,9 +487,9 @@ def crf_decode_forward( """ sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1]) - crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) + crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params, dtype=inputs.dtype) crf_fwd_layer = tf.keras.layers.RNN( - crf_fwd_cell, return_sequences=True, return_state=True + crf_fwd_cell, return_sequences=True, return_state=True, dtype=inputs.dtype ) return crf_fwd_layer(inputs, state, mask=mask)