Skip to content

Commit b661c98

Browse files
authored
Change CRF layer dtype (#2270)
* Change layer dtype * Cast to self._compute_dtype
1 parent a21a32a commit b661c98

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

tensorflow_addons/text/crf.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -443,8 +443,10 @@ def call(self, inputs, state):
443443
new_state: A [batch_size, num_tags] matrix of new score values.
444444
"""
445445
state = tf.expand_dims(state[0], 2)
446-
transition_scores = state + tf.cast(self._transition_params, state.dtype)
447-
new_state = tf.cast(inputs, state.dtype) + tf.reduce_max(transition_scores, [1])
446+
transition_scores = state + tf.cast(
447+
self._transition_params, self._compute_dtype
448+
)
449+
new_state = inputs + tf.reduce_max(transition_scores, [1])
448450
backpointers = tf.argmax(transition_scores, 1)
449451
backpointers = tf.cast(backpointers, dtype=tf.int32)
450452
return backpointers, new_state
@@ -485,9 +487,9 @@ def crf_decode_forward(
485487
"""
486488
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
487489
mask = tf.sequence_mask(sequence_lengths, tf.shape(inputs)[1])
488-
crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
490+
crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params, dtype=inputs.dtype)
489491
crf_fwd_layer = tf.keras.layers.RNN(
490-
crf_fwd_cell, return_sequences=True, return_state=True
492+
crf_fwd_cell, return_sequences=True, return_state=True, dtype=inputs.dtype
491493
)
492494
return crf_fwd_layer(inputs, state, mask=mask)
493495

0 commit comments

Comments
 (0)