@@ -443,8 +443,10 @@ def call(self, inputs, state):
443
443
new_state: A [batch_size, num_tags] matrix of new score values.
444
444
"""
445
445
state = tf .expand_dims (state [0 ], 2 )
446
- transition_scores = state + tf .cast (self ._transition_params , state .dtype )
447
- new_state = tf .cast (inputs , state .dtype ) + tf .reduce_max (transition_scores , [1 ])
446
+ transition_scores = state + tf .cast (
447
+ self ._transition_params , self ._compute_dtype
448
+ )
449
+ new_state = inputs + tf .reduce_max (transition_scores , [1 ])
448
450
backpointers = tf .argmax (transition_scores , 1 )
449
451
backpointers = tf .cast (backpointers , dtype = tf .int32 )
450
452
return backpointers , new_state
@@ -485,9 +487,9 @@ def crf_decode_forward(
485
487
"""
486
488
sequence_lengths = tf .cast (sequence_lengths , dtype = tf .int32 )
487
489
mask = tf .sequence_mask (sequence_lengths , tf .shape (inputs )[1 ])
488
- crf_fwd_cell = CrfDecodeForwardRnnCell (transition_params )
490
+ crf_fwd_cell = CrfDecodeForwardRnnCell (transition_params , dtype = inputs . dtype )
489
491
crf_fwd_layer = tf .keras .layers .RNN (
490
- crf_fwd_cell , return_sequences = True , return_state = True
492
+ crf_fwd_cell , return_sequences = True , return_state = True , dtype = inputs . dtype
491
493
)
492
494
return crf_fwd_layer (inputs , state , mask = mask )
493
495
0 commit comments