CRF layer v3.0 #1733
Closed
gabrieldemarmiesse wants to merge 17 commits into tensorflow:master from gabrieldemarmiesse:crf_layer_again
Changes from all commits (17 commits)
8d754a1 (howl-anderson): Squash all.
a080aaa (gabrieldemarmiesse): Merge branch 'master' into trying_to_squash
037549c (gabrieldemarmiesse): Cleanup for easier review.
e4cdfcb (gabrieldemarmiesse): Calming the angry bazel.
76a4375 (gabrieldemarmiesse): Fix the strange bug.
bf691c8 (gabrieldemarmiesse): Replaced one bug by another bug.
413f242 (gabrieldemarmiesse): Minor simplification.
a6afeb9 (gabrieldemarmiesse): Fix unused parameter.
3c0f306 (gabrieldemarmiesse): Simplified the signature.
4517e98 (gabrieldemarmiesse): Merge branch 'master' into trying_to_squash
fa347ae (gabrieldemarmiesse): Removing boilerplate
4f820b4 (gabrieldemarmiesse): Unused import.
89111ff (gabrieldemarmiesse): CRF layer v3.0
35021a0 (gabrieldemarmiesse): Finish the conversion.
bb68d01 (gabrieldemarmiesse): Some renaming here and there.
152eb34 (gabrieldemarmiesse): Merge branch 'master' into crf_layer_again
cc74721 (gabrieldemarmiesse): Added a test where some training is done after reloading the model.
@@ -0,0 +1,244 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Original implementation from keras_contrib/layers/crf
# ==============================================================================
"""Implementing Conditional Random Field layer."""

import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.text.crf import crf_decode
from tensorflow_addons.utils import types


@tf.keras.utils.register_keras_serializable(package="Addons")
class CRF(tf.keras.layers.Layer):
    """Linear chain conditional random field (CRF).

    References:
        - [Conditional Random Field](https://en.wikipedia.org/wiki/Conditional_random_field)
    """

    @typechecked
    def __init__(
        self,
        units: int,
        chain_initializer: types.Initializer = "orthogonal",
        use_boundary: bool = True,
        boundary_initializer: types.Initializer = "zeros",
        use_kernel: bool = True,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Enable mask support. The base class (Layer) sets `supports_masking`
        # to False unconditionally in its __init__, so this assignment must
        # happen after the call to the base class's __init__.
        self.supports_masking = True

        self.units = units  # number of tags

        self.use_boundary = use_boundary
        self.use_kernel = use_kernel
        self.chain_initializer = tf.keras.initializers.get(chain_initializer)
        self.boundary_initializer = tf.keras.initializers.get(boundary_initializer)

        # weights that act as the transition potentials between tags
        self.chain_kernel = self.add_weight(
            shape=(self.units, self.units),
            name="chain_kernel",
            initializer=self.chain_initializer,
        )

        # weights for the <START> to tag and the tag to <END> potentials
        if self.use_boundary:
            self.left_boundary = self.add_weight(
                shape=(self.units,),
                name="left_boundary",
                initializer=self.boundary_initializer,
            )
            self.right_boundary = self.add_weight(
                shape=(self.units,),
                name="right_boundary",
                initializer=self.boundary_initializer,
            )

        if self.use_kernel:
            self._dense_layer = tf.keras.layers.Dense(
                units=self.units, dtype=self.dtype,
            )
        else:
            self._dense_layer = lambda x: tf.cast(x, dtype=self.dtype)

    def call(self, inputs, mask=None):
        # mask: Tensor(shape=(batch_size, sequence_length), dtype=bool) or None

        if mask is not None:
            if tf.keras.backend.ndim(mask) != 2:
                raise ValueError("Input mask to CRF must have dim 2 if not None")

        if mask is not None:
            # Left padding of the mask is not supported by the underlying CRF
            # functions, so detect it and report it to the user.
            left_boundary_mask = self._compute_mask_left_boundary(mask)
            first_mask = left_boundary_mask[:, 0]
            if first_mask is not None and tf.executing_eagerly():
                no_left_padding = tf.math.reduce_all(first_mask)
                left_padding = not no_left_padding
                if left_padding:
                    raise NotImplementedError(
                        "Currently, the CRF layer does not support left padding"
                    )

        potentials = self._dense_layer(inputs)

        # append the boundary potentials
        if self.use_boundary:
            potentials = self.add_boundary_energy(
                potentials, mask, self.left_boundary, self.right_boundary
            )

        sequence_length = self._get_sequence_length(inputs, mask)

        decoded_sequence, _ = self.get_viterbi_decoding(potentials, sequence_length)

        # Return the decoded tags along with the potentials, sequence lengths
        # and chain kernel, e.g. for computing a CRF loss outside of the layer.
        return [decoded_sequence, potentials, sequence_length, self.chain_kernel]

    def _get_sequence_length(self, input_, mask):
        """Compute the sequence length from the input and the mask.

        The underlying CRF functions (provided by tensorflow_addons.text.crf)
        do not support masking on both sides (left padding and right padding);
        they only support right padding, which is expressed through the
        sequence length.
        """
        if mask is not None:
            sequence_length = self.mask_to_sequence_length(mask)
        else:
            # make a mask tensor from the input, then use it to compute sequence_length
            input_energy_shape = tf.shape(input_)
            raw_input_shape = tf.slice(input_energy_shape, [0], [2])
            alt_mask = tf.ones(raw_input_shape)

            sequence_length = self.mask_to_sequence_length(alt_mask)

        return sequence_length

    def mask_to_sequence_length(self, mask):
        """Compute the sequence length from the mask."""
        sequence_length = tf.cast(tf.reduce_sum(tf.cast(mask, tf.int8), 1), tf.int64)
        return sequence_length

    @staticmethod
    def _compute_mask_right_boundary(mask):
        """input mask: 0011100, output right_boundary: 0000100."""
        # shift the mask to the left by 1: 0011100 => 0111000
        offset = 1
        left_shifted_mask = tf.concat(
            [mask[:, offset:], tf.zeros_like(mask[:, :offset])], axis=1
        )

        # NOTE: the code below differs from keras_contrib.
        # The original code in keras_contrib:
        # end_mask = K.cast(
        #   K.greater(self.shift_left(mask), mask),
        #   K.floatx()
        # )
        # has a bug, confirmed by the original keras_contrib maintainer
        # Luiz Felix (github: lzfelix).

        # 0011100 > 0111000 => 0000100
        right_boundary = tf.greater(mask, left_shifted_mask)

        return right_boundary

    @staticmethod
    def _compute_mask_left_boundary(mask):
        """input mask: 0011100, output left_boundary: 0010000."""
        # shift the mask to the right by 1: 0011100 => 0001110
        offset = 1
        right_shifted_mask = tf.concat(
            [tf.zeros_like(mask[:, :offset]), mask[:, :-offset]], axis=1
        )

        # 0011100 > 0001110 => 0010000
        left_boundary = tf.greater(
            tf.cast(mask, tf.int32), tf.cast(right_shifted_mask, tf.int32)
        )

        return left_boundary

    def add_boundary_energy(self, potentials, mask, start, end):
        def expand_scalar_to_3d(x):
            # expand a tensor of shape (x,) to shape (1, 1, x)
            return tf.reshape(x, (1, 1, -1))

        start = expand_scalar_to_3d(start)
        end = expand_scalar_to_3d(end)
        if mask is None:
            potentials = tf.concat(
                [potentials[:, :1, :] + start, potentials[:, 1:, :]], axis=1
            )
            potentials = tf.concat(
                [potentials[:, :-1, :], potentials[:, -1:, :] + end], axis=1
            )
        else:
            mask = tf.keras.backend.expand_dims(tf.cast(mask, start.dtype), axis=-1)
            start_mask = tf.cast(self._compute_mask_left_boundary(mask), start.dtype)

            end_mask = tf.cast(self._compute_mask_right_boundary(mask), end.dtype)
            potentials = potentials + start_mask * start
            potentials = potentials + end_mask * end
        return potentials

    def get_viterbi_decoding(self, potentials, sequence_length):
        # decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`
        decode_tags, best_score = crf_decode(
            potentials, self.chain_kernel, sequence_length
        )

        return decode_tags, best_score

    def get_config(self):
        # used when loading the model from disk
        config = {
            "units": self.units,
            "chain_initializer": tf.keras.initializers.serialize(
                self.chain_initializer
            ),
            "use_boundary": self.use_boundary,
            "boundary_initializer": tf.keras.initializers.serialize(
                self.boundary_initializer
            ),
            "use_kernel": self.use_kernel,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    def compute_output_shape(self, input_shape):
        output_shape = input_shape[:2]
        return output_shape

    def compute_mask(self, input_, mask=None):
        """Keep the mask shape [batch_size, max_seq_len]."""
        return mask

    @property
    def _compute_dtype(self):
        # fixed output dtype, required by the underlying CRF functions
        return tf.int32
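
For orientation, here is a minimal usage sketch based on the call method above. It is not part of the pull request; the batch size, sequence length, feature size and number of tags are made-up illustration values, and it assumes the CRF class defined above is already importable (the file path is not shown in this diff).

import tensorflow as tf

# Made-up sizes: a batch of 2 sequences of length 5, 10 features, 4 tags.
crf = CRF(units=4)
features = tf.random.uniform((2, 5, 10))

decoded_sequence, potentials, sequence_length, chain_kernel = crf(features)

# decoded_sequence: int32 tensor of shape (2, 5) with the Viterbi-decoded tag ids
# potentials:       tensor of shape (2, 5, 4) with the unary potentials
# sequence_length:  int64 tensor of shape (2,); [5, 5] here since no mask is passed
# chain_kernel:     tensor of shape (4, 4) with the transition weights

As the return statement of call indicates, the layer itself only decodes; the potentials, sequence_length and chain_kernel outputs are presumably meant to feed a CRF loss computed outside the layer.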
Review comment on mask_to_sequence_length: here the sums are computed in tf.int8, which will overflow on sequences longer than 127 and produce negative sequence lengths. Maybe we can cast the mask to tf.int64 right away, so that the outer cast becomes unnecessary?
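
If that suggestion were adopted, mask_to_sequence_length might be rewritten along these lines (a sketch of the reviewer's proposal, not code from this pull request):

    def mask_to_sequence_length(self, mask):
        """Compute the sequence length from the mask."""
        # Summing in int64 from the start avoids the int8 overflow on
        # sequences longer than 127 and makes the outer cast unnecessary.
        return tf.reduce_sum(tf.cast(mask, tf.int64), axis=1)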