Skip to content

CRF layer v3.0 #1733

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorflow_addons/layers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ py_library(
deps = [
"//tensorflow_addons/activations",
"//tensorflow_addons/testing",
"//tensorflow_addons/text",
"//tensorflow_addons/utils",
],
)
Expand Down
244 changes: 244 additions & 0 deletions tensorflow_addons/layers/crf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Orginal implementation from keras_contrib/layers/crf
# ==============================================================================
"""Implementing Conditional Random Field layer."""

import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.text.crf import crf_decode
from tensorflow_addons.utils import types


@tf.keras.utils.register_keras_serializable(package="Addons")
class CRF(tf.keras.layers.Layer):
"""Linear chain conditional random field (CRF).

References:
- [Conditional Random Field](https://en.wikipedia.org/wiki/Conditional_random_field)
"""

@typechecked
def __init__(
self,
units: int,
chain_initializer: types.Initializer = "orthogonal",
use_boundary: bool = True,
boundary_initializer: types.Initializer = "zeros",
use_kernel: bool = True,
**kwargs
):
super().__init__(**kwargs)

# setup mask supporting flag, used by base class (the Layer)
# because base class's init method will set it to False unconditionally
# So this assigned must be executed after call base class's init method
self.supports_masking = True

self.units = units # numbers of tags

self.use_boundary = use_boundary
self.use_kernel = use_kernel
self.chain_initializer = tf.keras.initializers.get(chain_initializer)
self.boundary_initializer = tf.keras.initializers.get(boundary_initializer)

# weights that work as transfer probability of each tags
self.chain_kernel = self.add_weight(
shape=(self.units, self.units),
name="chain_kernel",
initializer=self.chain_initializer,
)

# weight of <START> to tag probability and tag to <END> probability
if self.use_boundary:
self.left_boundary = self.add_weight(
shape=(self.units,),
name="left_boundary",
initializer=self.boundary_initializer,
)
self.right_boundary = self.add_weight(
shape=(self.units,),
name="right_boundary",
initializer=self.boundary_initializer,
)

if self.use_kernel:
self._dense_layer = tf.keras.layers.Dense(
units=self.units, dtype=self.dtype,
)
else:
self._dense_layer = lambda x: tf.cast(x, dtype=self.dtype)

def call(self, inputs, mask=None):
# mask: Tensor(shape=(batch_size, sequence_length), dtype=bool) or None

if mask is not None:
if tf.keras.backend.ndim(mask) != 2:
raise ValueError("Input mask to CRF must have dim 2 if not None")

if mask is not None:
# left padding of mask is not supported, due the underline CRF function
# detect it and report it to user
left_boundary_mask = self._compute_mask_left_boundary(mask)
first_mask = left_boundary_mask[:, 0]
if first_mask is not None and tf.executing_eagerly():
no_left_padding = tf.math.reduce_all(first_mask)
left_padding = not no_left_padding
if left_padding:
raise NotImplementedError(
"Currently, CRF layer do not support left padding"
)

potentials = self._dense_layer(inputs)

# appending boundary probability info
if self.use_boundary:
potentials = self.add_boundary_energy(
potentials, mask, self.left_boundary, self.right_boundary
)

sequence_length = self._get_sequence_length(inputs, mask)

decoded_sequence, _ = self.get_viterbi_decoding(potentials, sequence_length)

return [decoded_sequence, potentials, sequence_length, self.chain_kernel]

def _get_sequence_length(self, input_, mask):
"""Currently underline CRF fucntion (provided by
tensorflow_addons.text.crf) do not support bi-direction masking (left
padding / right padding), it support right padding by tell it the
sequence length.

this function is compute the sequence length from input and
mask.
"""
if mask is not None:
sequence_length = self.mask_to_sequence_length(mask)
else:
# make a mask tensor from input, then used to generate sequence_length
input_energy_shape = tf.shape(input_)
raw_input_shape = tf.slice(input_energy_shape, [0], [2])
alt_mask = tf.ones(raw_input_shape)

sequence_length = self.mask_to_sequence_length(alt_mask)

return sequence_length

def mask_to_sequence_length(self, mask):
"""compute sequence length from mask."""
sequence_length = tf.cast(tf.reduce_sum(tf.cast(mask, tf.int8), 1), tf.int64)
Copy link

@ndrewl ndrewl Jun 2, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here sums are computed using tf.int8 which will overflow on sequences longer than 127, which will produce negative sequence lengths. Maybe we can cast the mask to tf.int64 right away, and then the outer cast will be unnecessary?

return sequence_length

@staticmethod
def _compute_mask_right_boundary(mask):
"""input mask: 0011100, output left_boundary: 0000100."""
# shift mask to left by 1: 0011100 => 0111000
offset = 1
left_shifted_mask = tf.concat(
[mask[:, offset:], tf.zeros_like(mask[:, :offset])], axis=1
)

# NOTE: below code is different from keras_contrib
# Original code in keras_contrib:
# end_mask = K.cast(
# K.greater(self.shift_left(mask), mask),
# K.floatx()
# )
# has a bug, confirmed
# by the original keras_contrib maintainer
# Luiz Felix (github: lzfelix),

# 0011100 > 0111000 => 0000100
right_boundary = tf.greater(mask, left_shifted_mask)

return right_boundary

@staticmethod
def _compute_mask_left_boundary(mask):
"""input mask: 0011100, output left_boundary: 0010000."""
# shift mask to right by 1: 0011100 => 0001110
offset = 1
right_shifted_mask = tf.concat(
[tf.zeros_like(mask[:, :offset]), mask[:, :-offset]], axis=1
)

# 0011100 > 0001110 => 0010000
left_boundary = tf.greater(
tf.cast(mask, tf.int32), tf.cast(right_shifted_mask, tf.int32)
)
# left_boundary = tf.greater(mask, right_shifted_mask)

return left_boundary

def add_boundary_energy(self, potentials, mask, start, end):
def expand_scalar_to_3d(x):
# expand tensor from shape (x, ) to (1, 1, x)
return tf.reshape(x, (1, 1, -1))

start = expand_scalar_to_3d(start)
end = expand_scalar_to_3d(end)
if mask is None:
potentials = tf.concat(
[potentials[:, :1, :] + start, potentials[:, 1:, :]], axis=1
)
potentials = tf.concat(
[potentials[:, :-1, :], potentials[:, -1:, :] + end], axis=1
)
else:
mask = tf.keras.backend.expand_dims(tf.cast(mask, start.dtype), axis=-1)
start_mask = tf.cast(self._compute_mask_left_boundary(mask), start.dtype)

end_mask = tf.cast(self._compute_mask_right_boundary(mask), end.dtype)
potentials = potentials + start_mask * start
potentials = potentials + end_mask * end
return potentials

def get_viterbi_decoding(self, potentials, sequence_length):
# decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`
decode_tags, best_score = crf_decode(
potentials, self.chain_kernel, sequence_length
)

return decode_tags, best_score

def get_config(self):
# used for loading model from disk
config = {
"units": self.units,
"chain_initializer": tf.keras.initializers.serialize(
self.chain_initializer
),
"use_boundary": self.use_boundary,
"boundary_initializer": tf.keras.initializers.serialize(
self.boundary_initializer
),
"use_kernel": self.use_kernel,
}
base_config = super().get_config()
return {**base_config, **config}

def compute_output_shape(self, input_shape):
output_shape = input_shape[:2]
return output_shape

def compute_mask(self, input_, mask=None):
"""keep mask shape [batch_size, max_seq_len]"""
return mask

@property
def _compute_dtype(self):
# fixed output dtype from underline CRF functions
return tf.int32
Loading