CRF crf_fwd_cell causes dtype error with mixed precision training #2231
Comments
I think that CRF is currently orphaned. Do you want to help maintain it? /cc @seanpmorgan
If you are interested, please consider checking out #337. I've been trying to create a tutorial for CRF, however I'm fairly inexperienced with CRFs and haven't had any positive results so far.
Hi @xuxingya, can you provide a runnable example that reproduces the errors? Thank you!
This is the runnable example. It's also a CRF Layer implementation.
Yes, the tfa-nightly branch solved this by doing dtype conversion in CrfDecodeForwardRnnCell:
And when running the model, the dtypes are:
However, according to this document, it looks like the computations should be in f16 while the values are kept in f32. I wonder if this change is better:
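For reference, a minimal sketch (assuming TF ≥ 2.4 with the non-experimental mixed precision API; the Dense layer is only a stand-in, not the actual CRF cell) of the two dtypes a Keras layer exposes under mixed_float16, which is the distinction being drawn here:

```python
import tensorflow as tf

# Under mixed_float16 a Keras layer keeps its variables in float32 (layer.variable_dtype)
# but runs its computations in float16 (layer.compute_dtype).
tf.keras.mixed_precision.set_global_policy("mixed_float16")

layer = tf.keras.layers.Dense(4)
layer.build((None, 3))
print(layer.variable_dtype)  # float32 -> "values kept in f32"
print(layer.compute_dtype)   # float16 -> "computations in f16"

transitions = tf.random.uniform((4, 4), dtype=tf.float32)  # e.g. a CRF transition matrix
inputs = tf.random.uniform((2, 4), dtype=tf.float16)       # per-step inputs under the policy

# (a) "compute in f16": cast the float32 parameters down to the compute dtype
transitions_f16 = tf.cast(transitions, layer.compute_dtype)

# (b) "keep everything in f32": cast the float16 inputs up to the variable dtype instead
inputs_f32 = tf.cast(inputs, layer.variable_dtype)

tf.keras.mixed_precision.set_global_policy("float32")  # restore the default policy
```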
Thanks for the clarification! If the computation should be in f16, isn't it more suitable to cast all other tensors to f16? BTW, currently, the
When using mixed precision, the output of the tf.keras.layers.RNN here will still be in float32. I think this is because TF can't apply the mixed precision policy to layers that are not built inside a model.
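A quick way to check that (again assuming TF ≥ 2.4; the SimpleRNNCell here is only a stand-in for the CRF cell, not the actual crf.py code):

```python
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

# A standalone tf.keras.layers.RNN built outside of any Model, wrapping a simple cell.
cell = tf.keras.layers.SimpleRNNCell(4)
rnn = tf.keras.layers.RNN(cell, return_sequences=True)

out = rnn(tf.random.uniform((2, 5, 3), dtype=tf.float32))
print(out.dtype)  # shows which dtype a standalone RNN actually emits under the policy

tf.keras.mixed_precision.set_global_policy("float32")
```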
Got it! Would you like to submit a PR for it? https://github.com/tensorflow/addons/blob/master/tensorflow_addons/text/crf.py#L449
Never mind. I just figured out what
System information
Describe the bug
I am training a model with mixed precision, where I use
viterbi_sequence, _ = tfa.text.crf_decode(inputs, self.transitions, sequence_lengths)
The error below occurs even though I am sure my inputs are all of float32 type:
InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was expected to be a half tensor but is a float tensor [Op:AddV2]
Then I found that in text/crf.py, the function crf_decode_forward calls:
crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params,)
The dtype of this RNN cell is not specified, so it always uses the global precision policy, which changes the dtype of the inputs to float16. I think this should be changed to:
crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params, dtype=inputs.dtype)
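For context, a minimal sketch of the kind of setup being described (TF ≥ 2.4 and tensorflow_addons assumed; shapes and values are illustrative, not taken from the report):

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Global mixed precision policy, as in the training setup described above.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

batch, max_len, num_tags = 2, 8, 5
inputs = tf.random.uniform((batch, max_len, num_tags), dtype=tf.float32)  # unary potentials
transitions = tf.random.uniform((num_tags, num_tags), dtype=tf.float32)   # CRF transition matrix
sequence_lengths = tf.constant([8, 5], dtype=tf.int32)

# crf_decode builds a CrfDecodeForwardRnnCell without an explicit dtype, so the cell
# picks up the global mixed_float16 policy while the tensors above stay float32,
# which is where the AddV2 dtype mismatch can appear.
viterbi_sequence, _ = tfa.text.crf_decode(inputs, transitions, sequence_lengths)
print(viterbi_sequence.shape)

tf.keras.mixed_precision.set_global_policy("float32")
```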