
CRF crf_fwd_cell caused dtype error with mixed precision training #2231


Closed · xuxingya opened this issue Nov 6, 2020 · 9 comments

xuxingya commented Nov 6, 2020

System information

  • OS: Linux Ubuntu 18.04
  • TensorFlow version: 2.2.0 or 2.4.1
  • TensorFlow-Addons version: 0.11.2
  • Python version: 3.7.6
  • Is GPU used? (yes/no): yes

Describe the bug
I am training a model with mixed precision, using viterbi_sequence, _ = tfa.text.crf_decode(inputs, self.transitions, sequence_lengths).
The error below occurs even though I am sure my inputs are all of float32 type:
InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was expected to be a half tensor but is a float tensor [Op:AddV2]
Then I found in the function crf_decode_forward in text/crf.py:
crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params,)
The dtype of this RNN cell is not specified, so it always uses the global precision policy, which casts the inputs to float16. I think this should be changed to:
crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params, dtype=inputs.dtype)
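To illustrate the mechanism (a minimal sketch of the general Keras behavior, not TFA code): a layer created without an explicit dtype inherits the mixed_float16 global policy, so its outputs come back as float16 and clash with float32 tensors exactly as in the error above.

import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as mixed_precision

mixed_precision.set_policy(mixed_precision.Policy('mixed_float16'))

x = tf.ones([2, 3], dtype=tf.float32)
# Dense has no explicit dtype, so it follows the global policy and
# computes in (and returns) float16, even for float32 inputs.
y = tf.keras.layers.Dense(3)(x)
print(y.dtype)  # float16
# Adding the float16 output to a float32 tensor then raises the same
# kind of InvalidArgumentError as quoted above.
try:
    _ = y + x
except tf.errors.InvalidArgumentError as e:
    print(e)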


bhack commented Nov 14, 2020

I think that CRF is currently orphaned. Do you want to help maintain it? /cc @seanpmorgan

Harsh188 commented

If you are interested, please consider checking out #337. I've been trying to create a tutorial for CRF; however, I'm fairly inexperienced with CRFs and haven't had any positive results so far.


WindQAQ commented Nov 23, 2020

Hi @xuxingya, can you provide a runnable example that reproduces the error? Thank you!

xuxingya commented

Here is a runnable example; it is also a CRF layer implementation.

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow_addons.text.crf import crf_log_likelihood
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Input, Embedding, Bidirectional, GRU, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.mixed_precision import experimental as mixed_precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)


class CRF(tf.keras.layers.Layer):
    def __init__(self, chain_initializer="orthogonal", **kwargs):
        super(CRF, self).__init__(**kwargs)
        self.chain_initializer = tf.keras.initializers.get(chain_initializer)
        self.transitions = None
        self.supports_masking = True

    def get_config(self):
        config = super(CRF, self).get_config()
        config.update({
            # Serialize the actual initializer instead of a hardcoded name.
            "chain_initializer": tf.keras.initializers.serialize(self.chain_initializer)
        })
        return config

    def build(self, input_shape):
        assert len(input_shape) == 3
        units = input_shape[-1]
        self.transitions = self.add_weight(
            name="transitions",
            shape=[units, units],
            initializer=self.chain_initializer,
        )

    def call(self, inputs, mask=None, training=None):
        if mask is None:
            raw_input_shape = tf.slice(tf.shape(inputs), [0], [2])
            mask = tf.ones(raw_input_shape)
        sequence_lengths = K.sum(K.cast(mask, 'int32'), axis=-1)

        viterbi_sequence, _ = tfa.text.crf_decode(
            inputs, self.transitions, sequence_lengths
        )
        return viterbi_sequence, inputs, sequence_lengths, self.transitions

class ModelWithCRFLoss(tf.keras.Model):
    """Wrapper around the base model for custom training logic."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        self.accuracy_fn = tf.keras.metrics.Accuracy(name='accuracy')

    def call(self, inputs, training=False):
        return self.base_model(inputs)

    def compute_loss(self, x, y, sample_weight, training=False):
        y_pred = self(x, training=training)
        viterbi_sequence, potentials, sequence_length, chain_kernel = y_pred
        # we now add the CRF loss:
        crf_loss = -crf_log_likelihood(potentials, y, sequence_length, chain_kernel)[0]
        if sample_weight is not None:
            crf_loss = crf_loss * sample_weight
        return viterbi_sequence, sequence_length, tf.reduce_mean(crf_loss)

    def accuracy(self, y_true, y_pred):
        viterbi_sequence, potentials, sequence_length, chain_kernel = y_pred
        sample_weights = tf.sequence_mask(sequence_length, y_true.shape[1])
        return self.accuracy_fn(y_true, viterbi_sequence, sample_weights)

    def train_step(self, data):
        x, y, sample_weight = unpack_data(data)

        with tf.GradientTape() as tape:
            viterbi_sequence, sequence_length, crf_loss = self.compute_loss(
                x, y, sample_weight, training=True
            )
        gradients = tape.gradient(crf_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self.accuracy_fn.update_state(y, viterbi_sequence, tf.sequence_mask(sequence_length, y.shape[1]))

        return {"crf_loss": crf_loss, 'accuracy': self.accuracy_fn.result()}

    def test_step(self, data):
        x, y, sample_weight = unpack_data(data)
        viterbi_sequence, sequence_length, crf_loss = self.compute_loss(x, y, sample_weight)
        self.accuracy_fn.update_state(y, viterbi_sequence, tf.sequence_mask(sequence_length, y.shape[1]))
        return {"crf_loss_val": crf_loss, 'val_accuracy': self.accuracy_fn.result()}

def unpack_data(data):
    if len(data) == 2:
        return data[0], data[1], None
    elif len(data) == 3:
        return data
    else:
        raise TypeError("Expected data to be a tuple of size 2 or 3.")

def test():
    inputs = Input(shape=(None,), dtype='int32')
    output = Embedding(100, 40, trainable=True, mask_zero=True)(inputs)
    output = Bidirectional(GRU(64, return_sequences=True))(output)
    output = Dense(9, activation=None)(output)
    crf = CRF(dtype='float32')
    output = crf(output)
    base_model = Model(inputs, output)
    model = ModelWithCRFLoss(base_model)
    model.compile(optimizer='adam')

    x = np.array([[5, 2, 3] * 3] * 100)
    y = np.array([[1, 2, 3] * 3] * 100)

    model.fit(x=x, y=y, epochs=10, batch_size=4, validation_split=0.1)


if __name__ == '__main__':
    test()


WindQAQ commented Nov 24, 2020

Hi @xuxingya, I ran the code on tfa-nightly without errors. The error was probably fixed in f429133. Can you try again with tfa-nightly? Thank you!

xuxingya commented

Yes, the tfa-nightly branch fixes this by casting dtypes in CrfDecodeForwardRnnCell:

transition_scores = state + tf.cast(self._transition_params, state.dtype)
new_state = tf.cast(inputs, state.dtype) + tf.reduce_max(transition_scores, [1])

And when running the model, the dtypes are:

state: <dtype: 'float32'>, inputs: <dtype: 'float16'>
new_state: <dtype: 'float32'>

However, according to this document, it looks like the computations should be done in float16 while the values are kept in float32. I wonder if this change would be better:
In crf_decode_forward:

crf_fwd_layer = tf.keras.layers.RNN(
        crf_fwd_cell, return_sequences=True, return_state=True, dtype=inputs.dtype
    ) 
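For illustration (an assumed minimal example, not TFA code): under the mixed_float16 global policy, an explicit dtype argument overrides the policy for that one layer, which is what the change above relies on.

import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as mixed_precision

mixed_precision.set_policy(mixed_precision.Policy('mixed_float16'))

x = tf.ones([2, 5, 3], dtype=tf.float32)
# Without an explicit dtype, the RNN follows the global policy and
# returns float16 outputs.
rnn_policy = tf.keras.layers.RNN(
    tf.keras.layers.SimpleRNNCell(4), return_sequences=True)
print(rnn_policy(x).dtype)  # float16
# With dtype='float32' on the cell and the RNN wrapper, the layer
# ignores the policy and computes in float32.
rnn_fp32 = tf.keras.layers.RNN(
    tf.keras.layers.SimpleRNNCell(4, dtype='float32'),
    return_sequences=True, dtype='float32')
print(rnn_fp32(x).dtype)  # float32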


WindQAQ commented Nov 25, 2020

Thanks for the clarification! If the computation should be in float16, isn't it more suitable to cast all the other tensors to inputs.dtype in the call function?
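For concreteness, a hypothetical variant of the two lines quoted above that casts toward the inputs' dtype instead of the state's (just a sketch of the suggestion, not actual TFA code):

# Hypothetical: run the recurrence in the inputs' (compute) dtype.
transition_scores = tf.cast(state, inputs.dtype) + tf.cast(
    self._transition_params, inputs.dtype)
new_state = inputs + tf.reduce_max(transition_scores, [1])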

BTW, currently the _transition_params is a large Tensor in CrfDecodeForwardRnnCell. Do you think it's better to mark it as a non-trainable variable via self.add_weight? That way, _transition_params would automatically be cast according to the global policy when we create it, and it would also serialize better. We currently convert it to a list of values in get_config, which I think is quite an anti-pattern. It might cost some performance when switching from Tensor to Variable, though.

xuxingya commented

When using mixed precision, the output of the tf.keras.layers.RNN here will still be in float32. I think this is because TF can't apply the mixed policy to layers that are not created inside a model.
What do you mean by marking the _transition_params as non-trainable variables? Isn't the CRF layer meant to train them?


WindQAQ commented Nov 26, 2020

When using mixed precision, the output of the tf.keras.layers.RNN here will still be in float32. I think this is because TF can't apply the mixed policy to layers that are not created inside a model.

Got it! Would you like to submit a PR for it?

https://github.com/tensorflow/addons/blob/master/tensorflow_addons/text/crf.py#L449

What do you mean by marking the _transition_params as non-trainable variables? Isn't the CRF layer meant to train them?

Never mind. I just figured out what CrfDecodeForwardRnnCell does.
