CRF layer v3.0 #1733

Closed
wants to merge 17 commits into from
Conversation

@gabrieldemarmiesse (Member) commented Apr 26, 2020

With a subclassing approach, we have a nicer API and it's very flexible.

Works only with TF 2.2+

@howl-anderson for the review and the CLA

The plan is to show users how to do the subclassing for the CRF. We shouldn't provide an API to save them some code there, because it would become very complex to design a good API and to maintain it later on.

So the CRF layer is a public API, and for the CRF loss we give a good tutorial about subclassing.

Quick tutorial right now:

import tensorflow as tf
from tensorflow_addons.layers.crf import CRF
from tensorflow_addons.text.crf import crf_log_likelihood

def unpack_data(data):
    if len(data) == 2:
        return data[0], data[1], None
    elif len(data) == 3:
        return data
    else:
        raise TypeError("Expected data to be a tuple of size 2 or 3.")


class ModelWithCRFLoss(tf.keras.Model):
    """Wrapper around the base model for custom training logic."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def call(self, inputs):
        return self.base_model(inputs)

    def compute_loss(self, x, y, sample_weights, training=False):
        y_pred = self(x, training=training)
        _, potentials, sequence_length, chain_kernel = y_pred

        crf_loss = -crf_log_likelihood(potentials, y, sequence_length, chain_kernel)[0]

        if sample_weights is not None:
            crf_loss = crf_loss * sample_weights

        return tf.reduce_mean(crf_loss), sum(self.losses)

    def train_step(self, data):
        x, y, sample_weight = unpack_data(data)

        with tf.GradientTape() as tape:
            crf_loss, internal_losses = self.compute_loss(
                x, y, sample_weight, training=True
            )
            total_loss = crf_loss + internal_losses

        gradients = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return {"crf_loss": crf_loss, "internal_losses": internal_losses}

    def test_step(self, data):
        x, y, sample_weight = unpack_data(data)
        crf_loss, internal_losses = self.compute_loss(x, y, sample_weight)
        return {"crf_loss_val": crf_loss, "internal_losses_val": internal_losses}


x_np, y_np = get_test_data()

x_input = tf.keras.layers.Input(shape=x_np.shape[1:])
crf_outputs = CRF(5)(x_input)
base_model = tf.keras.Model(x_input, crf_outputs)
model = ModelWithCRFLoss(base_model)

model.compile("adam")
model.fit(x=x_np, y=y_np)
model.evaluate(x_np, y_np)
model.predict(x_np)
model.save("my_model.tf")

If some users want to try this feature before it's merged, we have some wheels available.

@googlebot

All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter.

We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only @googlebot I consent. in this pull request.

Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s), and set the cla label to yes (if enabled on your project).

ℹ️ Googlers: Go here for more info.

@seanpmorgan added and then removed the blocked (Pending something else's completion) label Apr 26, 2020
        return tf.reduce_mean(crf_loss), sum(self.losses)

    def train_step(self, data):
        x, y, sample_weight = unpack_data(data)

Inconsistent naming - it's sample_weights, plural, in compute_loss()

    def test_step(self, data):
        x, y, sample_weight = unpack_data(data)
        crf_loss, internal_losses = self.compute_loss(x, y, sample_weight)
        return {"crf_loss_val": crf_loss, "internal_losses_val": internal_losses}

The val_ prefix is already added automatically by Keras, so the _val suffix on these keys is redundant.
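
A sketch of the suggested change (relying on Keras automatically prefixing metrics returned from test_step):

    def test_step(self, data):
        x, y, sample_weight = unpack_data(data)
        crf_loss, internal_losses = self.compute_loss(x, y, sample_weight)
        # Keras reports these as val_crf_loss / val_internal_losses during fit().
        return {"crf_loss": crf_loss, "internal_losses": internal_losses}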

@YanZhu1105

Hi,
Is it possible to include an example with sample weights?
Also an example where the input of model.fit is a generator which yields (x, y, sample_weight) for each batch?

Thanks!
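
For reference, a minimal sketch of such a generator (the names batch_generator and w_np are hypothetical; it assumes the ModelWithCRFLoss wrapper above, whose unpack_data accepts 3-tuples):

import numpy as np

def batch_generator(x_np, y_np, w_np, batch_size=32):
    # Yields (x, y, sample_weight) tuples, which unpack_data() forwards
    # to compute_loss() so the per-sample weights scale the CRF loss.
    while True:
        idx = np.random.choice(len(x_np), size=batch_size, replace=False)
        yield x_np[idx], y_np[idx], w_np[idx]

# model.fit consumes the generator directly; steps_per_epoch is required
# because the generator is infinite.
model.fit(batch_generator(x_np, y_np, w_np), steps_per_epoch=10)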


    def mask_to_sequence_length(self, mask):
        """compute sequence length from mask."""
        sequence_length = tf.cast(tf.reduce_sum(tf.cast(mask, tf.int8), 1), tf.int64)
@ndrewl commented Jun 2, 2020

Here the sums are computed in tf.int8, which will overflow on sequences longer than 127 and produce negative sequence lengths. Maybe we can cast the mask to tf.int64 right away, so the outer cast becomes unnecessary?
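
A sketch of what that suggestion could look like:

    def mask_to_sequence_length(self, mask):
        """compute sequence length from mask."""
        # Casting to tf.int64 before the sum avoids int8 overflow on long
        # sequences and makes the outer cast unnecessary.
        sequence_length = tf.reduce_sum(tf.cast(mask, tf.int64), axis=1)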

@luozhouyang

Any updates?

@ndrewl commented Jul 12, 2020

@gabrieldemarmiesse, are you still going to work on this PR?

@gabrieldemarmiesse (Member Author)

Sorry, I'm very busy nowadays; somebody else is more than welcome to take this branch and open a new pull request with it :)

@jaspersjsun (Contributor)

@gabrieldemarmiesse Hey! I've been working on a model with a CRF layer recently, and your solution here is very helpful. I'd gladly help finish this PR if you're not available currently ;)

@gabrieldemarmiesse (Member Author)

I'm glad it helped you! Feel free to pull this branch into your fork and open a new PR :)

@gabrieldemarmiesse (Member Author)

Closing in favor of #1999

@DachuanZhao

[quotes the tutorial from the original comment above]

Hi ~ How would I use this code to build a BiLSTM-CRF model? Like this?

x_input = tf.keras.layers.Input(shape=x_np.shape[1:])
bilstm_output = bi_lstm_model(x_input)
crf_outputs = CRF(5)(bilstm_output)
base_model = tf.keras.Model(x_input, crf_outputs)
model = ModelWithCRFLoss(base_model)
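
For illustration, one way bi_lstm_model could be defined (a sketch; the layer size is a placeholder):

bi_lstm_model = tf.keras.layers.Bidirectional(
    # return_sequences=True keeps per-timestep outputs, which the CRF needs.
    tf.keras.layers.LSTM(64, return_sequences=True)
)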

@xuxingya

I built a CRF tool for TF2; you can refer to tf2crf, and pip install tf2crf.

@DachuanZhao

What's the difference between your CRF layer and tensorflow_addons.layers.crf?

@xuxingya

xuxingya commented Mar 15, 2021 via email
