CRF layer v3.0 #1733
All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter. We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s), and set the ℹ️ Googlers: Go here for more info. |
    return tf.reduce_mean(crf_loss), sum(self.losses)

def train_step(self, data):
    x, y, sample_weight = unpack_data(data)
Inconsistent naming: it's `sample_weights`, plural, in `compute_loss()`.
def test_step(self, data):
    x, y, sample_weight = unpack_data(data)
    crf_loss, internal_losses = self.compute_loss(x, y, sample_weight)
    return {"crf_loss_val": crf_loss, "internal_losses_val": internal_losses}
The `val_` prefix is already added during validation, so the `_val` suffix on these keys is redundant.
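To make the point concrete, here is a minimal plain-Python sketch of what happens to the keys. `add_val_prefix` is a hypothetical stand-in for the prefixing step applied to validation logs, not the actual Keras code, and the numeric values are placeholders.

```python
def add_val_prefix(logs):
    """Simplified stand-in: prefix every validation log key with 'val_'."""
    return {"val_" + name: value for name, value in logs.items()}


# Keys as returned by test_step in the diff under review (values are placeholders).
suffixed = {"crf_loss_val": 0.42, "internal_losses_val": 0.01}

# The same keys without the manual suffix.
plain = {"crf_loss": 0.42, "internal_losses": 0.01}

# With the manual suffix, the validation marker ends up in the name twice.
print(sorted(add_val_prefix(suffixed)))

# Without it, the names carry the marker once.
print(sorted(add_val_prefix(plain)))
```

Under that assumption, returning `crf_loss` and `internal_losses` from `test_step` would be enough.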
Hi, Thanks!
def mask_to_sequence_length(self, mask):
    """compute sequence length from mask."""
    sequence_length = tf.cast(tf.reduce_sum(tf.cast(mask, tf.int8), 1), tf.int64)
Here the sums are computed using `tf.int8`, which will overflow on sequences longer than 127 and produce negative sequence lengths. Maybe we can cast the mask to `tf.int64` right away, so that the outer cast becomes unnecessary?
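The wraparound can be illustrated without TensorFlow. The sketch below is plain Python and assumes the 8-bit sum wraps in two's complement, as the comment describes; `wrap_int8`, `sequence_length_int8`, and `sequence_length_wide` are hypothetical helper names, not part of the PR.

```python
def wrap_int8(value):
    """Wrap an integer into the signed 8-bit range [-128, 127]."""
    return (value + 128) % 256 - 128


def sequence_length_int8(mask_row):
    """Sum a boolean mask row with an accumulator that wraps like int8."""
    total = 0
    for flag in mask_row:
        total = wrap_int8(total + int(flag))
    return total


def sequence_length_wide(mask_row):
    """Sum the same row with Python's unbounded int, standing in for int64."""
    return sum(int(flag) for flag in mask_row)


# One sequence with 200 real tokens followed by 56 padding positions.
mask_row = [True] * 200 + [False] * 56

narrow = sequence_length_int8(mask_row)
wide = sequence_length_wide(mask_row)
print(narrow, wide)  # the narrow accumulator goes negative past 127
```

In the layer itself, the suggestion would amount to something like `tf.reduce_sum(tf.cast(mask, tf.int64), 1)`, which keeps the accumulator wide from the start.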
Any updates?
@gabrieldemarmiesse, are you still going to work on this PR?
Sorry, I'm very busy nowadays, somebody else is more than welcome to take this branch and open a new pull request with it :)
@gabrieldemarmiesse Hey! I've been working on a model with a CRF layer recently and your solution here is very helpful. I'd gladly help finish this PR if you are not available currently ;)
I'm glad it helped you! Feel free to pull this branch into your fork and open a new PR :)
Closing in favor of #1999
Hi, how can I use this code to build a BiLSTM-CRF model? Like this?
I built a CRF tool for TF2; you can refer to tf2crf (https://github.com/xuxingya/tf2crf) and install it with `pip install tf2crf`.
What's the difference between your CRF layer and tensorflow_addons.layers.crf?
Because for a long time tfa.layers.crf has had problems with mixed precision, along with other bugs, and it returns 4 outputs when predicting. So I implemented a CRF layer using the CRF functions in tfa.text.crf, which is closer to the old CRF of keras_contrib.
With a subclassing approach, we have a nicer API and it's very flexible.
Works only with TF 2.2+
@howl-anderson for the review and the CLA
The plan is to show users how to do the subclassing for the CRF. We shouldn't provide an API to save them some code there, because designing a good API and maintaining it later on would become very complex.
So the CRF layer is a public API and for the CRF loss, we give a good tutorial about subclassing.
Quick tutorial right now:
If some users want to try this feature before it's merged, we have some wheels available