From 009df4a39f1fafcac77d7725be0b4756cd2c415b Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Sun, 6 Dec 2020 17:38:27 -0800 Subject: [PATCH 1/3] Change layer dtype --- tensorflow_addons/text/crf.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow_addons/text/crf.py b/tensorflow_addons/text/crf.py index 8bc9901128..404a81d9c4 100644 --- a/tensorflow_addons/text/crf.py +++ b/tensorflow_addons/text/crf.py @@ -442,9 +442,9 @@ def call(self, inputs, state): backpointers: A [batch_size, num_tags] matrix of backpointers. new_state: A [batch_size, num_tags] matrix of new score values. """ - state = tf.expand_dims(state[0], 2) - transition_scores = state + tf.cast(self._transition_params, state.dtype) - new_state = tf.cast(inputs, state.dtype) + tf.reduce_max(transition_scores, [1]) + state = tf.cast(tf.expand_dims(state[0], 2), inputs.dtype) + transition_scores = state + tf.cast(self._transition_params, inputs.dtype) + new_state = inputs + tf.reduce_max(transition_scores, [1]) backpointers = tf.argmax(transition_scores, 1) backpointers = tf.cast(backpointers, dtype=tf.int32) return backpointers, new_state @@ -485,9 +485,9 @@ def crf_decode_forward( """ sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32) mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1]) - crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) + crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params, dtype=inputs.dtype) crf_fwd_layer = tf.keras.layers.RNN( - crf_fwd_cell, return_sequences=True, return_state=True + crf_fwd_cell, return_sequences=True, return_state=True, dtype=inputs.dtype ) return crf_fwd_layer(inputs, state, mask=mask) From 8bfdfd93f64aab29b267851785abf4ef17e30b6f Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Sun, 6 Dec 2020 17:42:37 -0800 Subject: [PATCH 2/3] Cast to self._compute_dtype --- tensorflow_addons/text/crf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_addons/text/crf.py b/tensorflow_addons/text/crf.py index 404a81d9c4..0dc73de520 100644 --- a/tensorflow_addons/text/crf.py +++ b/tensorflow_addons/text/crf.py @@ -442,8 +442,8 @@ def call(self, inputs, state): backpointers: A [batch_size, num_tags] matrix of backpointers. new_state: A [batch_size, num_tags] matrix of new score values. """ - state = tf.cast(tf.expand_dims(state[0], 2), inputs.dtype) - transition_scores = state + tf.cast(self._transition_params, inputs.dtype) + state = tf.expand_dims(state[0], 2) + transition_scores = state + tf.cast(self._transition_params, self._compute_dtype) new_state = inputs + tf.reduce_max(transition_scores, [1]) backpointers = tf.argmax(transition_scores, 1) backpointers = tf.cast(backpointers, dtype=tf.int32) From ff1bd66eed0c00d227e467788cfb2bf2a098bb0e Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Sun, 6 Dec 2020 17:49:23 -0800 Subject: [PATCH 3/3] Format codes --- tensorflow_addons/text/crf.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow_addons/text/crf.py b/tensorflow_addons/text/crf.py index 0dc73de520..63f0e6ae6e 100644 --- a/tensorflow_addons/text/crf.py +++ b/tensorflow_addons/text/crf.py @@ -443,7 +443,9 @@ def call(self, inputs, state): new_state: A [batch_size, num_tags] matrix of new score values. """ state = tf.expand_dims(state[0], 2) - transition_scores = state + tf.cast(self._transition_params, self._compute_dtype) + transition_scores = state + tf.cast( + self._transition_params, self._compute_dtype + ) new_state = inputs + tf.reduce_max(transition_scores, [1]) backpointers = tf.argmax(transition_scores, 1) backpointers = tf.cast(backpointers, dtype=tf.int32)