From de94eab7e93a796ada15f5f2e19fb5abd5ea3097 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Fri, 5 Aug 2022 22:10:48 +0000 Subject: [PATCH 01/45] first draft --- examples/bert/bert_train.py | 177 +--------------- keras_nlp/__init__.py | 1 + keras_nlp/applications/__init__.py | 17 ++ .../applications/bert.py | 199 +++++++++++++++++- 4 files changed, 221 insertions(+), 173 deletions(-) create mode 100644 keras_nlp/applications/__init__.py rename examples/bert/bert_model.py => keras_nlp/applications/bert.py (51%) diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index fc787cfd2e..7d941b2fa9 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -21,6 +21,10 @@ from absl import logging from tensorflow import keras +from keras_nlp.applications.bert import ( + BertLanguageModel, + BertEncoder, +) from examples.bert.bert_config import MODEL_CONFIGS from examples.bert.bert_config import PREPROCESSING_CONFIG from examples.bert.bert_config import TRAINING_CONFIG @@ -90,175 +94,6 @@ ) -class MaskedLMHead(keras.layers.Layer): - """Masked language model network head for BERT. - - This layer implements a masked language model based on the provided - transformer based encoder. It assumes that the encoder network being passed - has a "get_embedding_table()" method. - - Example: - ```python - encoder=modeling.networks.BertEncoder(...) - lm_layer=MaskedLMHead(embedding_table=encoder.get_embedding_table()) - ``` - - Args: - embedding_table: The embedding table from encoder network. - inner_activation: The activation, if any, for the inner dense layer. - initializer: The initializer for the dense layer. Defaults to a Glorot - uniform initializer. - output: The output style for this layer. Can be either 'logits' or - 'predictions'. - """ - - def __init__( - self, - embedding_table, - inner_activation="gelu", - initializer="glorot_uniform", - **kwargs, - ): - super().__init__(**kwargs) - self.embedding_table = embedding_table - self.inner_activation = keras.activations.get(inner_activation) - self.initializer = initializer - - def build(self, input_shape): - self._vocab_size, hidden_size = self.embedding_table.shape - self.dense = keras.layers.Dense( - hidden_size, - activation=self.inner_activation, - kernel_initializer=self.initializer, - name="transform/dense", - ) - self.layer_norm = keras.layers.LayerNormalization( - axis=-1, epsilon=1e-12, name="transform/LayerNorm" - ) - self.bias = self.add_weight( - "output_bias/bias", - shape=(self._vocab_size,), - initializer="zeros", - trainable=True, - ) - - super().build(input_shape) - - def call(self, sequence_data, masked_positions): - masked_lm_input = self._gather_indexes(sequence_data, masked_positions) - lm_data = self.dense(masked_lm_input) - lm_data = self.layer_norm(lm_data) - lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True) - logits = tf.nn.bias_add(lm_data, self.bias) - masked_positions_length = ( - masked_positions.shape.as_list()[1] or tf.shape(masked_positions)[1] - ) - return tf.reshape( - logits, [-1, masked_positions_length, self._vocab_size] - ) - - def _gather_indexes(self, sequence_tensor, positions): - """Gathers the vectors at the specific positions, for performance. - - Args: - sequence_tensor: Sequence output of shape - (`batch_size`, `seq_length`, `hidden_size`) where `hidden_size` - is number of hidden units. 
- positions: Positions ids of tokens in sequence to mask for - pretraining of with dimension (batch_size, num_predictions) - where `num_predictions` is maximum number of tokens to mask out - and predict per each sequence. - - Returns: - Masked out sequence tensor of shape (batch_size * num_predictions, - `hidden_size`). - """ - sequence_shape = tf.shape(sequence_tensor) - batch_size, seq_length = sequence_shape[0], sequence_shape[1] - width = sequence_tensor.shape.as_list()[2] or sequence_shape[2] - - flat_offsets = tf.reshape( - tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1] - ) - flat_positions = tf.reshape(positions + flat_offsets, [-1]) - flat_sequence_tensor = tf.reshape( - sequence_tensor, [batch_size * seq_length, width] - ) - output_tensor = tf.gather(flat_sequence_tensor, flat_positions) - - return output_tensor - - -class BertPretrainer(keras.Model): - def __init__(self, bert_model, **kwargs): - super().__init__(**kwargs) - self.bert_model = bert_model - self.masked_lm_head = MaskedLMHead( - bert_model.get_embedding_table(), - initializer=bert_model.initializer, - ) - self.next_sentence_head = keras.layers.Dense( - 2, - kernel_initializer=bert_model.initializer, - ) - self.loss_tracker = keras.metrics.Mean(name="loss") - self.lm_loss_tracker = keras.metrics.Mean(name="lm_loss") - self.nsp_loss_tracker = keras.metrics.Mean(name="nsp_loss") - self.lm_accuracy = keras.metrics.SparseCategoricalAccuracy( - name="lm_accuracy" - ) - self.nsp_accuracy = keras.metrics.SparseCategoricalAccuracy( - name="nsp_accuracy" - ) - - def call(self, data): - sequence_output, pooled_output = self.bert_model( - { - "input_ids": data["input_ids"], - "input_mask": data["input_mask"], - "segment_ids": data["segment_ids"], - } - ) - lm_preds = self.masked_lm_head( - sequence_output, data["masked_lm_positions"] - ) - nsp_preds = self.next_sentence_head(pooled_output) - return lm_preds, nsp_preds - - def train_step(self, data): - with tf.GradientTape() as tape: - lm_preds, nsp_preds = self(data, training=True) - lm_labels = data["masked_lm_ids"] - lm_weights = data["masked_lm_weights"] - nsp_labels = data["next_sentence_labels"] - - lm_loss = keras.losses.sparse_categorical_crossentropy( - lm_labels, lm_preds, from_logits=True - ) - lm_weights_summed = tf.reduce_sum(lm_weights, -1) - lm_loss = tf.reduce_sum(lm_loss * lm_weights, -1) - lm_loss = tf.math.divide_no_nan(lm_loss, lm_weights_summed) - nsp_loss = keras.losses.sparse_categorical_crossentropy( - nsp_labels, nsp_preds, from_logits=True - ) - nsp_loss = tf.reduce_mean(nsp_loss) - loss = lm_loss + nsp_loss - - # Compute gradients - trainable_vars = self.trainable_variables - gradients = tape.gradient(loss, trainable_vars) - # Update weights - self.optimizer.apply_gradients(zip(gradients, trainable_vars)) - - # Update metrics - self.loss_tracker.update_state(loss) - self.lm_loss_tracker.update_state(lm_loss) - self.nsp_loss_tracker.update_state(nsp_loss) - self.lm_accuracy.update_state(lm_labels, lm_preds, lm_weights) - self.nsp_accuracy.update_state(nsp_labels, nsp_preds) - return {m.name: m.result() for m in self.metrics} - - class LinearDecayWithWarmup(keras.optimizers.schedules.LearningRateSchedule): """ A learning rate schedule with linear warmup and decay. @@ -397,7 +232,7 @@ def main(_): with strategy.scope(): # Create a BERT model the input config. 
- model = BertModel( + model = BertEncoder( vocab_size=len(vocab), **model_config, ) @@ -420,7 +255,7 @@ def main(_): ) optimizer = keras.optimizers.Adam(learning_rate=learning_rate_schedule) - pretraining_model = BertPretrainer(model) + pretraining_model = BertLanguageModel(model) pretraining_model.compile( optimizer=optimizer, ) diff --git a/keras_nlp/__init__.py b/keras_nlp/__init__.py index 506d071cc8..182436118f 100644 --- a/keras_nlp/__init__.py +++ b/keras_nlp/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from keras_nlp import applications from keras_nlp import layers from keras_nlp import metrics from keras_nlp import tokenizers diff --git a/keras_nlp/applications/__init__.py b/keras_nlp/applications/__init__.py new file mode 100644 index 0000000000..581c4f28e0 --- /dev/null +++ b/keras_nlp/applications/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_nlp.applications.bert import BertEncoder +from keras_nlp.applications.bert import BertLanguageModel +from keras_nlp.applications.bert import BertClassifier \ No newline at end of file diff --git a/examples/bert/bert_model.py b/keras_nlp/applications/bert.py similarity index 51% rename from examples/bert/bert_model.py rename to keras_nlp/applications/bert.py index d6dd12561c..2df9907631 100644 --- a/examples/bert/bert_model.py +++ b/keras_nlp/applications/bert.py @@ -23,10 +23,11 @@ import tensorflow as tf from tensorflow import keras -import keras_nlp +# isort: off +from tensorflow.python.util.tf_export import keras_export -class BertModel(keras.Model): +class BertEncoder(keras.Model): """Bi-directional Transformer-based encoder network. This network implements a bi-directional Transformer-based encoder as @@ -198,3 +199,197 @@ def get_config(self): } ) return config + + +class MaskedLMHead(keras.layers.Layer): + """Masked language model network head for BERT. + + This layer implements a masked language model based on the provided + transformer based encoder. It assumes that the encoder network being passed + has a "get_embedding_table()" method. + + Example: + ```python + encoder=modeling.networks.BertEncoder(...) + lm_layer=MaskedLMHead(embedding_table=encoder.get_embedding_table()) + ``` + + Args: + embedding_table: The embedding table from encoder network. + inner_activation: The activation, if any, for the inner dense layer. + initializer: The initializer for the dense layer. Defaults to a Glorot + uniform initializer. + output: The output style for this layer. Can be either 'logits' or + 'predictions'. 
+ """ + + def __init__( + self, + embedding_table, + inner_activation="gelu", + initializer="glorot_uniform", + **kwargs, + ): + super().__init__(**kwargs) + self.embedding_table = embedding_table + self.inner_activation = keras.activations.get(inner_activation) + self.initializer = initializer + + def build(self, input_shape): + self._vocab_size, hidden_size = self.embedding_table.shape + self.dense = keras.layers.Dense( + hidden_size, + activation=self.inner_activation, + kernel_initializer=self.initializer, + name="transform/dense", + ) + self.layer_norm = keras.layers.LayerNormalization( + axis=-1, epsilon=1e-12, name="transform/LayerNorm" + ) + self.bias = self.add_weight( + "output_bias/bias", + shape=(self._vocab_size,), + initializer="zeros", + trainable=True, + ) + + super().build(input_shape) + + def call(self, sequence_data, masked_positions): + masked_lm_input = self._gather_indexes(sequence_data, masked_positions) + lm_data = self.dense(masked_lm_input) + lm_data = self.layer_norm(lm_data) + lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True) + logits = tf.nn.bias_add(lm_data, self.bias) + masked_positions_length = ( + masked_positions.shape.as_list()[1] or tf.shape(masked_positions)[1] + ) + return tf.reshape( + logits, [-1, masked_positions_length, self._vocab_size] + ) + + def _gather_indexes(self, sequence_tensor, positions): + """Gathers the vectors at the specific positions, for performance. + + Args: + sequence_tensor: Sequence output of shape + (`batch_size`, `seq_length`, `hidden_size`) where `hidden_size` + is number of hidden units. + positions: Positions ids of tokens in sequence to mask for + pretraining of with dimension (batch_size, num_predictions) + where `num_predictions` is maximum number of tokens to mask out + and predict per each sequence. + + Returns: + Masked out sequence tensor of shape (batch_size * num_predictions, + `hidden_size`). + """ + sequence_shape = tf.shape(sequence_tensor) + batch_size, seq_length = sequence_shape[0], sequence_shape[1] + width = sequence_tensor.shape.as_list()[2] or sequence_shape[2] + + flat_offsets = tf.reshape( + tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1] + ) + flat_positions = tf.reshape(positions + flat_offsets, [-1]) + flat_sequence_tensor = tf.reshape( + sequence_tensor, [batch_size * seq_length, width] + ) + output_tensor = tf.gather(flat_sequence_tensor, flat_positions) + + return output_tensor + + +class BertLanguageModel(keras.Model): + """ + MLM + NSP model with BertEncoder. 
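    Example (a rough sketch rather than tested code: the feature names below are
    the ones read in `call` and `train_step`, while the encoder sizes, batch
    contents, and optimizer are illustrative placeholders, assuming the encoder
    and head wiring in this file behaves as intended):

    ```python
    import tensorflow as tf
    from tensorflow import keras

    encoder = BertEncoder(vocab_size=1000, num_layers=2, hidden_size=64,
                          num_attention_heads=2, inner_size=128)
    pretrainer = BertLanguageModel(encoder)
    pretrainer.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-4))

    # One synthetic batch with the features consumed by `train_step`.
    batch = {
        "input_ids": tf.random.uniform((2, 16), maxval=1000, dtype=tf.int32),
        "input_mask": tf.ones((2, 16), dtype=tf.int32),
        "segment_ids": tf.zeros((2, 16), dtype=tf.int32),
        "masked_lm_positions": tf.constant([[1, 3], [2, 4]]),
        "masked_lm_ids": tf.constant([[7, 8], [9, 10]]),
        "masked_lm_weights": tf.ones((2, 2), dtype=tf.float32),
        "next_sentence_labels": tf.constant([0, 1]),
    }
    pretrainer.fit(tf.data.Dataset.from_tensors(batch), epochs=1)
    ```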
+ """ + + def __init__(self, encoder, **kwargs): + super().__init__(**kwargs) + self.encoder = encoder + # TODO(jbischof): replace with keras_nlp.layers.MLMHead + self.masked_lm_head = MaskedLMHead( + embedding_weights=encoder.get_embedding_table(), + initializer=encoder.initializer, + ) + self.next_sentence_head = keras.layers.Dense( + 2, + kernel_initializer=encoder.initializer, + ) + self.loss_tracker = keras.metrics.Mean(name="loss") + self.lm_loss_tracker = keras.metrics.Mean(name="lm_loss") + self.nsp_loss_tracker = keras.metrics.Mean(name="nsp_loss") + self.lm_accuracy = keras.metrics.SparseCategoricalAccuracy( + name="lm_accuracy" + ) + self.nsp_accuracy = keras.metrics.SparseCategoricalAccuracy( + name="nsp_accuracy" + ) + + def call(self, data): + sequence_output, pooled_output = self.encoder( + { + "input_ids": data["input_ids"], + "input_mask": data["input_mask"], + "segment_ids": data["segment_ids"], + } + ) + lm_preds = self.masked_lm_head( + sequence_output, data["masked_lm_positions"] + ) + nsp_preds = self.next_sentence_head(pooled_output) + return lm_preds, nsp_preds + + def train_step(self, data): + with tf.GradientTape() as tape: + lm_preds, nsp_preds = self(data, training=True) + lm_labels = data["masked_lm_ids"] + lm_weights = data["masked_lm_weights"] + nsp_labels = data["next_sentence_labels"] + + lm_loss = keras.losses.sparse_categorical_crossentropy( + lm_labels, lm_preds, from_logits=True + ) + lm_weights_summed = tf.reduce_sum(lm_weights, -1) + lm_loss = tf.reduce_sum(lm_loss * lm_weights, -1) + lm_loss = tf.math.divide_no_nan(lm_loss, lm_weights_summed) + nsp_loss = keras.losses.sparse_categorical_crossentropy( + nsp_labels, nsp_preds, from_logits=True + ) + nsp_loss = tf.reduce_mean(nsp_loss) + loss = lm_loss + nsp_loss + + # Compute gradients + trainable_vars = self.trainable_variables + gradients = tape.gradient(loss, trainable_vars) + # Update weights + self.optimizer.apply_gradients(zip(gradients, trainable_vars)) + + # Update metrics + self.loss_tracker.update_state(loss) + self.lm_loss_tracker.update_state(lm_loss) + self.nsp_loss_tracker.update_state(nsp_loss) + self.lm_accuracy.update_state(lm_labels, lm_preds, lm_weights) + self.nsp_accuracy.update_state(nsp_labels, nsp_preds) + return {m.name: m.result() for m in self.metrics} + + +class BertClassifier(keras.Model): + """Classifier model with BertEncoder + """ + + def __init__(self, encoder, num_classes, **kwargs): + super().__init__(**kwargs) + self.encoder = encoder + self.num_classes = num_classes + self._logit_layer = keras.layers.Dense( + num_classes, + kernel_initializer=encoder.initializer, + name="logits", + ) + + def call(self, inputs): + # Ignore the sequence output, use the pooled output. 
+ _, pooled_output = self.bert_model(inputs) + return self._logit_layer(pooled_output) From d3934b7d3c6b63e64604946fc23050db5c78234f Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Fri, 5 Aug 2022 23:20:19 +0000 Subject: [PATCH 02/45] first working version --- examples/bert/bert_train.py | 15 +++++++-------- keras_nlp/applications/bert.py | 4 +++- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index 7d941b2fa9..73f595d068 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -28,7 +28,6 @@ from examples.bert.bert_config import MODEL_CONFIGS from examples.bert.bert_config import PREPROCESSING_CONFIG from examples.bert.bert_config import TRAINING_CONFIG -from examples.bert.bert_model import BertModel FLAGS = flags.FLAGS @@ -232,13 +231,13 @@ def main(_): with strategy.scope(): # Create a BERT model the input config. - model = BertEncoder( + encoder = BertEncoder( vocab_size=len(vocab), **model_config, ) # Make sure model has been called. - model(model.inputs) - model.summary() + encoder(encoder.inputs) + encoder.summary() # Allow overriding train steps from the command line for quick testing. if FLAGS.num_train_steps is not None: @@ -255,8 +254,8 @@ def main(_): ) optimizer = keras.optimizers.Adam(learning_rate=learning_rate_schedule) - pretraining_model = BertLanguageModel(model) - pretraining_model.compile( + language_model = BertLanguageModel(encoder) + language_model.compile( optimizer=optimizer, ) @@ -269,7 +268,7 @@ def main(_): if FLAGS.tensorboard_log_path: callbacks.append(get_tensorboard_callback()) - pretraining_model.fit( + language_model.fit( dataset, epochs=epochs, steps_per_epoch=steps_per_epoch, @@ -278,7 +277,7 @@ def main(_): model_path = FLAGS.saved_model_output logging.info(f"Saving to {FLAGS.saved_model_output}") - model.save(model_path) + encoder.save(model_path) if __name__ == "__main__": diff --git a/keras_nlp/applications/bert.py b/keras_nlp/applications/bert.py index 2df9907631..00ebb28e6e 100644 --- a/keras_nlp/applications/bert.py +++ b/keras_nlp/applications/bert.py @@ -23,6 +23,8 @@ import tensorflow as tf from tensorflow import keras +import keras_nlp.layers + # isort: off from tensorflow.python.util.tf_export import keras_export @@ -310,7 +312,7 @@ def __init__(self, encoder, **kwargs): self.encoder = encoder # TODO(jbischof): replace with keras_nlp.layers.MLMHead self.masked_lm_head = MaskedLMHead( - embedding_weights=encoder.get_embedding_table(), + embedding_table=encoder.get_embedding_table(), initializer=encoder.initializer, ) self.next_sentence_head = keras.layers.Dense( From 5b8572a96e42cf323016a480d6f9880170db1359 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Tue, 9 Aug 2022 00:45:07 +0000 Subject: [PATCH 03/45] Partial draft of functional API + BertBase --- examples/bert/bert_preprocess.py | 2 + examples/bert/bert_train.py | 6 +- keras_nlp/applications/__init__.py | 2 +- keras_nlp/applications/bert.py | 281 +++++++++++++---------------- 4 files changed, 133 insertions(+), 158 deletions(-) diff --git a/examples/bert/bert_preprocess.py b/examples/bert/bert_preprocess.py index 36b6bc0e27..8436bfa432 100644 --- a/examples/bert/bert_preprocess.py +++ b/examples/bert/bert_preprocess.py @@ -375,6 +375,8 @@ def create_masked_lm_predictions( ): """Creates the predictions for the masked LM objective.""" + # TODO(bischof): replace with keras_nlp.layers.MLMMaskGenerator + cand_indexes = [] for (i, token) in enumerate(tokens): if token == "[CLS]" or 
token == "[SEP]": diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index 73f595d068..114f6647c8 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -21,13 +21,11 @@ from absl import logging from tensorflow import keras -from keras_nlp.applications.bert import ( - BertLanguageModel, - BertEncoder, -) from examples.bert.bert_config import MODEL_CONFIGS from examples.bert.bert_config import PREPROCESSING_CONFIG from examples.bert.bert_config import TRAINING_CONFIG +from keras_nlp.applications.bert import BertEncoder +from keras_nlp.applications.bert import BertLanguageModel FLAGS = flags.FLAGS diff --git a/keras_nlp/applications/__init__.py b/keras_nlp/applications/__init__.py index 581c4f28e0..b8392e5090 100644 --- a/keras_nlp/applications/__init__.py +++ b/keras_nlp/applications/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from keras_nlp.applications.bert import BertClassifier from keras_nlp.applications.bert import BertEncoder from keras_nlp.applications.bert import BertLanguageModel -from keras_nlp.applications.bert import BertClassifier \ No newline at end of file diff --git a/keras_nlp/applications/bert.py b/keras_nlp/applications/bert.py index 00ebb28e6e..4fe3820366 100644 --- a/keras_nlp/applications/bert.py +++ b/keras_nlp/applications/bert.py @@ -11,14 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Bert model and layer implementations. - -These components come from the tensorflow official model repository for BERT: -https://github.com/tensorflow/models/tree/master/official/nlp/modeling - -This is to get us into a testable state. We should work to replace all of these -components with components from the keras-nlp library. -""" +"""Bert model and layer implementations.""" import tensorflow as tf from tensorflow import keras @@ -26,10 +19,25 @@ import keras_nlp.layers # isort: off +# TODO(bischof): decide what to export or whether we are using these decorators from tensorflow.python.util.tf_export import keras_export - -class BertEncoder(keras.Model): +CLS_INDEX = 0 +TOKEN_EMBEDDING_LAYER_NAME = "token_embedding" + + +def BertEncoder( + vocab_size, + num_layers=12, + hidden_size=768, + dropout=0.1, + num_attention_heads=12, + inner_size=3072, + inner_activation="gelu", + initializer_range=0.02, + max_sequence_length=512, + type_vocab_size=2, +): """Bi-directional Transformer-based encoder network. This network implements a bi-directional Transformer-based encoder as @@ -66,141 +74,86 @@ class BertEncoder(keras.Model): dense layers is normalized. 
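    Example (an illustrative sketch; the tiny sizes are made up, and the input
    and output names mirror the `keras.Input`s and the output dict constructed
    below):

    ```python
    import tensorflow as tf

    encoder = BertEncoder(vocab_size=1000, num_layers=2, hidden_size=64,
                          num_attention_heads=2, inner_size=128)
    outputs = encoder({
        "input_ids": tf.zeros((2, 16), dtype=tf.int32),
        "segment_ids": tf.zeros((2, 16), dtype=tf.int32),
        "input_mask": tf.ones((2, 16), dtype=tf.int32),
    })
    outputs["sequence_output"].shape  # (2, 16, 64)
    outputs["pooled_output"].shape    # (2, 64)
    ```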
""" - def __init__( - self, - vocab_size, - num_layers=12, - hidden_size=768, - dropout=0.1, - num_attention_heads=12, - inner_size=3072, - inner_activation="gelu", - initializer_range=0.02, - max_sequence_length=512, - type_vocab_size=2, - **kwargs, - ): - super().__init__(**kwargs) - - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_layers = num_layers - self.num_attention_heads = num_attention_heads - self.max_sequence_length = max_sequence_length - self.type_vocab_size = type_vocab_size - self.inner_size = inner_size - self.inner_activation = keras.activations.get(inner_activation) - self.initializer_range = initializer_range - self.initializer = keras.initializers.TruncatedNormal( - stddev=initializer_range - ) - self.dropout = dropout - - self._embedding_layer = keras.layers.Embedding( - input_dim=vocab_size, - output_dim=hidden_size, - embeddings_initializer=self.initializer, - name="word_embeddings", - ) - - self._position_embedding_layer = keras_nlp.layers.PositionEmbedding( - initializer=self.initializer, - sequence_length=max_sequence_length, - name="position_embedding", - ) - - self._type_embedding_layer = keras.layers.Embedding( - input_dim=type_vocab_size, - output_dim=hidden_size, - embeddings_initializer=self.initializer, - name="type_embeddings", - ) - - self._embedding_norm_layer = keras.layers.LayerNormalization( - name="embeddings/layer_norm", - axis=-1, - epsilon=1e-12, - dtype=tf.float32, - ) - - self._embedding_dropout = keras.layers.Dropout( - rate=dropout, name="embedding_dropout" - ) - - self._transformer_layers = [] - for i in range(num_layers): - layer = keras_nlp.layers.TransformerEncoder( - num_heads=num_attention_heads, - intermediate_dim=inner_size, - activation=self.inner_activation, - dropout=dropout, - kernel_initializer=self.initializer, - name="transformer/layer_%d" % i, - ) - self._transformer_layers.append(layer) - - # This is used as the intermediate output for the NSP prediction head. - # It is important we include this in the mode, as we want to preserve - # these weights for fine-tuning tasks. - self._pooler_layer = keras.layers.Dense( - units=hidden_size, - activation="tanh", - kernel_initializer=self.initializer, - name="pooler_dense", - ) - - self.inputs = dict( - input_ids=keras.Input(shape=(None,), dtype=tf.int32), - input_mask=keras.Input(shape=(None,), dtype=tf.int32), - segment_ids=keras.Input(shape=(None,), dtype=tf.int32), - ) - - def call(self, inputs): - if isinstance(inputs, dict): - input_ids = inputs.get("input_ids") - input_mask = inputs.get("input_mask") - segment_ids = inputs.get("segment_ids") - else: - raise ValueError(f"Inputs should be a dict. Received: {inputs}.") - - word_embeddings = None - word_embeddings = self._embedding_layer(input_ids) - position_embeddings = self._position_embedding_layer(word_embeddings) - type_embeddings = self._type_embedding_layer(segment_ids) - - embeddings = word_embeddings + position_embeddings + type_embeddings - embeddings = self._embedding_norm_layer(embeddings) - embeddings = self._embedding_dropout(embeddings) - - x = embeddings - for layer in self._transformer_layers: - x = layer(x, padding_mask=input_mask) - sequence_output = x - pooled_output = self._pooler_layer(x[:, 0, :]) # 0 is the [CLS] token. 
- return sequence_output, pooled_output - - def get_embedding_table(self): - return self._embedding_layer.embeddings - - def get_config(self): - config = super().get_config() - config.update( - { - "vocab_size": self.vocab_size, - "hidden_size": self.hidden_size, - "num_layers": self.num_layers, - "num_attention_heads": self.num_attention_heads, - "max_sequence_length": self.max_sequence_length, - "type_vocab_size": self.type_vocab_size, - "inner_size": self.inner_size, - "inner_activation": keras.activations.serialize( - self.inner_activation - ), - "dropout": self.dropout, - "initializer_range": self.initializer_range, - } - ) - return config + # Create lambda functions from input params + inner_activation_fn = keras.activations.get(inner_activation) + initializer_fn = keras.initializers.TruncatedNormal( + stddev=initializer_range + ) + + # Functional version of model + token_id_input = keras.Input(shape=(None,), dtype="int32", name="input_ids") + segment_id_input = keras.Input( + shape=(None,), dtype="int32", name="segment_ids" + ) + input_mask = keras.Input(shape=(None,), dtype="int32", name="input_mask") + + # Embed tokens, positions, and segment ids. + token_embedding = keras.layers.Embedding( + input_dim=vocab_size, + output_dim=hidden_size, + name=TOKEN_EMBEDDING_LAYER_NAME, + )(token_id_input) + position_embedding = keras_nlp.layers.PositionEmbedding( + initializer=initializer_fn, + sequence_length=max_sequence_length, + name="position_embedding", + )(token_embedding) + segment_embedding = keras.layers.Embedding( + input_dim=type_vocab_size, + output_dim=hidden_size, + name="segment_embedding", + )(segment_id_input) + + # Sum, normailze and apply dropout to embeddings. + x = keras.layers.Add( + name="embedding_sum", + )((token_embedding, position_embedding, segment_embedding)) + x = keras.layers.LayerNormalization( + name="embeddings/layer_norm", + axis=-1, + epsilon=1e-12, + dtype=tf.float32, + )(x) + x = keras.layers.Dropout( + dropout, + name="embedding_dropout", + )(x) + + # Apply successive transformer encoder blocks. + for i in range(num_layers): + x = keras_nlp.layers.TransformerEncoder( + num_heads=num_attention_heads, + intermediate_dim=inner_size, + activation=inner_activation_fn, + dropout=dropout, + kernel_initializer=initializer_fn, + name="transformer/layer_%d" % i, + )(x, padding_mask=input_mask) + + # Construct the two BERT outputs, and apply a dense to the pooled output. 
+ sequence_output = x + pooled_output = keras.layers.Dense( + hidden_size, + activation="tanh", + name="pooled_dense", + )(x[:, CLS_INDEX, :]) + + model = keras.Model( + inputs={ + "input_ids": token_id_input, + "segment_ids": segment_id_input, + "input_mask": input_mask, + }, + outputs={ + "sequence_output": sequence_output, + "pooled_output": pooled_output, + }, + ) + # Save some metadata for downstream usage + model.initializer_fn = initializer_fn + model.type_vocab_size = type_vocab_size + model.max_sequence_length = max_sequence_length + return model class MaskedLMHead(keras.layers.Layer): @@ -312,12 +265,14 @@ def __init__(self, encoder, **kwargs): self.encoder = encoder # TODO(jbischof): replace with keras_nlp.layers.MLMHead self.masked_lm_head = MaskedLMHead( - embedding_table=encoder.get_embedding_table(), - initializer=encoder.initializer, + embedding_table=encoder.get_layer( + TOKEN_EMBEDDING_LAYER_NAME + ).embeddings, + initializer=encoder.initializer_fn, ) self.next_sentence_head = keras.layers.Dense( - 2, - kernel_initializer=encoder.initializer, + encoder.type_vocab_size, + kernel_initializer=encoder.initializer_fn, ) self.loss_tracker = keras.metrics.Mean(name="loss") self.lm_loss_tracker = keras.metrics.Mean(name="lm_loss") @@ -378,8 +333,7 @@ def train_step(self, data): class BertClassifier(keras.Model): - """Classifier model with BertEncoder - """ + """Classifier model with BertEncoder.""" def __init__(self, encoder, num_classes, **kwargs): super().__init__(**kwargs) @@ -387,11 +341,32 @@ def __init__(self, encoder, num_classes, **kwargs): self.num_classes = num_classes self._logit_layer = keras.layers.Dense( num_classes, - kernel_initializer=encoder.initializer, + kernel_initializer=encoder.initializer_fn, name="logits", ) def call(self, inputs): # Ignore the sequence output, use the pooled output. _, pooled_output = self.bert_model(inputs) - return self._logit_layer(pooled_output) + return self._logit_layer(pooled_output) + + +def BertBaseEncoder(weights=None): + """Factory for BertEncoder using "Base" architecture.""" + + model = BertEncoder( + vocab_size=30522, + num_layers=12, + hidden_size=768, + dropout=0.1, + num_attention_heads=12, + inner_size=3072, + inner_activation="gelu", + initializer_range=0.02, + ) + + if weights is not None: + model.load_weights(weights) + + # TODO(bischof): attach the tokenizer + return model From 7add3a7c2cd34db4f83873544ebc690a855d3b22 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Tue, 9 Aug 2022 23:24:23 +0000 Subject: [PATCH 04/45] Change Bert to Model subclass API --- examples/bert/bert_train.py | 7 +- keras_nlp/applications/__init__.py | 3 +- keras_nlp/applications/bert.py | 234 +++++++++++++++++------------ 3 files changed, 139 insertions(+), 105 deletions(-) diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index 114f6647c8..35b1c7d0bc 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -24,7 +24,7 @@ from examples.bert.bert_config import MODEL_CONFIGS from examples.bert.bert_config import PREPROCESSING_CONFIG from examples.bert.bert_config import TRAINING_CONFIG -from keras_nlp.applications.bert import BertEncoder +from keras_nlp.applications.bert import Bert from keras_nlp.applications.bert import BertLanguageModel FLAGS = flags.FLAGS @@ -229,10 +229,7 @@ def main(_): with strategy.scope(): # Create a BERT model the input config. 
- encoder = BertEncoder( - vocab_size=len(vocab), - **model_config, - ) + encoder = Bert(vocab_size=len(vocab), **model_config) # Make sure model has been called. encoder(encoder.inputs) encoder.summary() diff --git a/keras_nlp/applications/__init__.py b/keras_nlp/applications/__init__.py index b8392e5090..91901dd5a6 100644 --- a/keras_nlp/applications/__init__.py +++ b/keras_nlp/applications/__init__.py @@ -13,5 +13,6 @@ # limitations under the License. from keras_nlp.applications.bert import BertClassifier -from keras_nlp.applications.bert import BertEncoder +from keras_nlp.applications.bert import Bert +from keras_nlp.applications.bert import BertBase from keras_nlp.applications.bert import BertLanguageModel diff --git a/keras_nlp/applications/bert.py b/keras_nlp/applications/bert.py index 4fe3820366..3b3fea8d65 100644 --- a/keras_nlp/applications/bert.py +++ b/keras_nlp/applications/bert.py @@ -26,18 +26,7 @@ TOKEN_EMBEDDING_LAYER_NAME = "token_embedding" -def BertEncoder( - vocab_size, - num_layers=12, - hidden_size=768, - dropout=0.1, - num_attention_heads=12, - inner_size=3072, - inner_activation="gelu", - initializer_range=0.02, - max_sequence_length=512, - type_vocab_size=2, -): +class Bert(keras.Model): """Bi-directional Transformer-based encoder network. This network implements a bi-directional Transformer-based encoder as @@ -54,7 +43,6 @@ def BertEncoder( vocab_size: The size of the token vocabulary. num_layers: The number of transformer layers. hidden_size: The size of the transformer hidden layers. - dropout: Dropout probability for the Transformer encoder. num_attention_heads: The number of attention heads for each transformer. The hidden size must be divisible by the number of attention heads. inner_size: The output dimension of the first Dense layer in a two-layer @@ -63,6 +51,7 @@ def BertEncoder( two-layer feedforward network for each transformer. initializer_range: The initialzer range to use for a truncated normal initializer. + dropout: Dropout probability for the Transformer encoder. max_sequence_length: The maximum sequence length that this encoder can consume. If None, max_sequence_length uses the value from sequence length. This determines the variable shape for positional @@ -73,87 +62,133 @@ def BertEncoder( dense layers. If set False, output of attention and intermediate dense layers is normalized. """ + def __init__( + self, + vocab_size, + num_layers, + hidden_size, + num_attention_heads, + inner_size, + inner_activation="gelu", + initializer_range=0.02, + dropout=0.1, + max_sequence_length=512, + type_vocab_size=2, + **kwargs + ): - # Create lambda functions from input params - inner_activation_fn = keras.activations.get(inner_activation) - initializer_fn = keras.initializers.TruncatedNormal( - stddev=initializer_range - ) + # Create lambda functions from input params + inner_activation_fn = keras.activations.get(inner_activation) + initializer_fn = keras.initializers.TruncatedNormal( + stddev=initializer_range + ) - # Functional version of model - token_id_input = keras.Input(shape=(None,), dtype="int32", name="input_ids") - segment_id_input = keras.Input( - shape=(None,), dtype="int32", name="segment_ids" - ) - input_mask = keras.Input(shape=(None,), dtype="int32", name="input_mask") - - # Embed tokens, positions, and segment ids. 
- token_embedding = keras.layers.Embedding( - input_dim=vocab_size, - output_dim=hidden_size, - name=TOKEN_EMBEDDING_LAYER_NAME, - )(token_id_input) - position_embedding = keras_nlp.layers.PositionEmbedding( - initializer=initializer_fn, - sequence_length=max_sequence_length, - name="position_embedding", - )(token_embedding) - segment_embedding = keras.layers.Embedding( - input_dim=type_vocab_size, - output_dim=hidden_size, - name="segment_embedding", - )(segment_id_input) - - # Sum, normailze and apply dropout to embeddings. - x = keras.layers.Add( - name="embedding_sum", - )((token_embedding, position_embedding, segment_embedding)) - x = keras.layers.LayerNormalization( - name="embeddings/layer_norm", - axis=-1, - epsilon=1e-12, - dtype=tf.float32, - )(x) - x = keras.layers.Dropout( - dropout, - name="embedding_dropout", - )(x) - - # Apply successive transformer encoder blocks. - for i in range(num_layers): - x = keras_nlp.layers.TransformerEncoder( - num_heads=num_attention_heads, - intermediate_dim=inner_size, - activation=inner_activation_fn, - dropout=dropout, - kernel_initializer=initializer_fn, - name="transformer/layer_%d" % i, - )(x, padding_mask=input_mask) - - # Construct the two BERT outputs, and apply a dense to the pooled output. - sequence_output = x - pooled_output = keras.layers.Dense( - hidden_size, - activation="tanh", - name="pooled_dense", - )(x[:, CLS_INDEX, :]) - - model = keras.Model( - inputs={ - "input_ids": token_id_input, - "segment_ids": segment_id_input, - "input_mask": input_mask, - }, - outputs={ - "sequence_output": sequence_output, - "pooled_output": pooled_output, - }, - ) - # Save some metadata for downstream usage - model.initializer_fn = initializer_fn - model.type_vocab_size = type_vocab_size - model.max_sequence_length = max_sequence_length - return model + # Functional version of model + token_id_input = keras.Input(shape=(None,), dtype="int32", name="input_ids") + segment_id_input = keras.Input( + shape=(None,), dtype="int32", name="segment_ids" + ) + input_mask = keras.Input(shape=(None,), dtype="int32", name="input_mask") + + # Embed tokens, positions, and segment ids. + token_embedding = keras.layers.Embedding( + input_dim=vocab_size, + output_dim=hidden_size, + name=TOKEN_EMBEDDING_LAYER_NAME, + )(token_id_input) + position_embedding = keras_nlp.layers.PositionEmbedding( + initializer=initializer_fn, + sequence_length=max_sequence_length, + name="position_embedding", + )(token_embedding) + segment_embedding = keras.layers.Embedding( + input_dim=type_vocab_size, + output_dim=hidden_size, + name="segment_embedding", + )(segment_id_input) + + # Sum, normailze and apply dropout to embeddings. + x = keras.layers.Add( + name="embedding_sum", + )((token_embedding, position_embedding, segment_embedding)) + x = keras.layers.LayerNormalization( + name="embeddings/layer_norm", + axis=-1, + epsilon=1e-12, + dtype=tf.float32, + )(x) + x = keras.layers.Dropout( + dropout, + name="embedding_dropout", + )(x) + + # Apply successive transformer encoder blocks. + for i in range(num_layers): + x = keras_nlp.layers.TransformerEncoder( + num_heads=num_attention_heads, + intermediate_dim=inner_size, + activation=inner_activation_fn, + dropout=dropout, + kernel_initializer=initializer_fn, + name="transformer/layer_%d" % i, + )(x, padding_mask=input_mask) + + # Construct the two BERT outputs, and apply a dense to the pooled output. 
+ sequence_output = x + pooled_output = keras.layers.Dense( + hidden_size, + activation="tanh", + name="pooled_dense", + )(x[:, CLS_INDEX, :]) + + # Instantiate using Functional API Model constructor + super(Bert, self).__init__( + inputs={ + "input_ids": token_id_input, + "segment_ids": segment_id_input, + "input_mask": input_mask, + }, + outputs={ + "sequence_output": sequence_output, + "pooled_output": pooled_output, + }, + **kwargs) + # All references to `self` below this line + self.inner_activation_fn = inner_activation_fn + self.initializer_fn = initializer_fn + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.max_sequence_length = max_sequence_length + self.type_vocab_size = type_vocab_size + self.inner_size = inner_size + self.inner_activation = keras.activations.get(inner_activation) + self.initializer_range = initializer_range + self.dropout = dropout + + def get_embedding_table(self): + return self.get_layer(TOKEN_EMBEDDING_LAYER_NAME).embeddings + + def get_config(self): + config = super().get_config() + config.update( + { + "vocab_size": self.vocab_size, + "hidden_size": self.hidden_size, + "num_layers": self.num_layers, + "num_attention_heads": self.num_attention_heads, + "max_sequence_length": self.max_sequence_length, + "type_vocab_size": self.type_vocab_size, + "inner_size": self.inner_size, + "inner_activation": keras.activations.serialize( + self.inner_activation + ), + "dropout": self.dropout, + "initializer_range": self.initializer_range, + } + ) + return config class MaskedLMHead(keras.layers.Layer): @@ -265,9 +300,7 @@ def __init__(self, encoder, **kwargs): self.encoder = encoder # TODO(jbischof): replace with keras_nlp.layers.MLMHead self.masked_lm_head = MaskedLMHead( - embedding_table=encoder.get_layer( - TOKEN_EMBEDDING_LAYER_NAME - ).embeddings, + embedding_table=encoder.get_embedding_table(), initializer=encoder.initializer_fn, ) self.next_sentence_head = keras.layers.Dense( @@ -285,13 +318,14 @@ def __init__(self, encoder, **kwargs): ) def call(self, data): - sequence_output, pooled_output = self.encoder( + encoder_output = self.encoder( { "input_ids": data["input_ids"], "input_mask": data["input_mask"], "segment_ids": data["segment_ids"], } ) + sequence_output, pooled_output = encoder_output["sequence_output"], encoder_output["pooled_output"] lm_preds = self.masked_lm_head( sequence_output, data["masked_lm_positions"] ) @@ -351,20 +385,22 @@ def call(self, inputs): return self._logit_layer(pooled_output) -def BertBaseEncoder(weights=None): +def BertBase(weights=None): """Factory for BertEncoder using "Base" architecture.""" - model = BertEncoder( + model = Bert( vocab_size=30522, num_layers=12, hidden_size=768, - dropout=0.1, num_attention_heads=12, inner_size=3072, inner_activation="gelu", initializer_range=0.02, + dropout=0.1, ) + # TODO(bischof): add some documentation or magic to load our checkpoints + # Note: This is pure Keras and also intended to work with user checkpoints if weights is not None: model.load_weights(weights) From 2cfa58cbe2bb4f30b5607ddff1b612600aa2ed94 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Wed, 10 Aug 2022 02:32:40 +0000 Subject: [PATCH 05/45] Get fine-tuning script working --- examples/bert/bert_finetune_glue.py | 23 +++-------------------- keras_nlp/applications/bert.py | 8 ++++---- 2 files changed, 7 insertions(+), 24 deletions(-) diff --git a/examples/bert/bert_finetune_glue.py b/examples/bert/bert_finetune_glue.py 
index 8f649bee81..f73c7e0e4b 100644 --- a/examples/bert/bert_finetune_glue.py +++ b/examples/bert/bert_finetune_glue.py @@ -26,6 +26,7 @@ from examples.bert.bert_config import FINETUNING_CONFIG from examples.bert.bert_config import MODEL_CONFIGS from examples.bert.bert_config import PREPROCESSING_CONFIG +from keras_nlp.applications import BertClassifier FLAGS = flags.FLAGS @@ -109,24 +110,6 @@ def to_tf_dataset(split): return train_ds, test_ds, validation_ds -class BertClassificationFinetuner(keras.Model): - """Adds a classification head to a pre-trained BERT model for finetuning""" - - def __init__(self, bert_model, num_classes, initializer, **kwargs): - super().__init__(**kwargs) - self.bert_model = bert_model - self._logit_layer = keras.layers.Dense( - num_classes, - kernel_initializer=initializer, - name="logits", - ) - - def call(self, inputs): - # Ignore the sequence output, use the pooled output. - _, pooled_output = self.bert_model(inputs) - return self._logit_layer(pooled_output) - - class BertHyperModel(keras_tuner.HyperModel): """Creates a hypermodel to help with the search space for finetuning.""" @@ -136,8 +119,8 @@ def __init__(self, model_config): def build(self, hp): model = keras.models.load_model(FLAGS.saved_model_input, compile=False) model_config = self.model_config - finetuning_model = BertClassificationFinetuner( - bert_model=model, + finetuning_model = BertClassifier( + encoder=model, num_classes=3 if FLAGS.task_name in ("mnli", "ax") else 2, initializer=keras.initializers.TruncatedNormal( stddev=model_config["initializer_range"] diff --git a/keras_nlp/applications/bert.py b/keras_nlp/applications/bert.py index 3b3fea8d65..8d35302ff6 100644 --- a/keras_nlp/applications/bert.py +++ b/keras_nlp/applications/bert.py @@ -20,7 +20,7 @@ # isort: off # TODO(bischof): decide what to export or whether we are using these decorators -from tensorflow.python.util.tf_export import keras_export +#from tensorflow.python.util.tf_export import keras_export CLS_INDEX = 0 TOKEN_EMBEDDING_LAYER_NAME = "token_embedding" @@ -369,19 +369,19 @@ def train_step(self, data): class BertClassifier(keras.Model): """Classifier model with BertEncoder.""" - def __init__(self, encoder, num_classes, **kwargs): + def __init__(self, encoder, num_classes, initializer, **kwargs): super().__init__(**kwargs) self.encoder = encoder self.num_classes = num_classes self._logit_layer = keras.layers.Dense( num_classes, - kernel_initializer=encoder.initializer_fn, + kernel_initializer=initializer, name="logits", ) def call(self, inputs): # Ignore the sequence output, use the pooled output. 
- _, pooled_output = self.bert_model(inputs) + pooled_output = self.encoder(inputs)["pooled_output"] return self._logit_layer(pooled_output) From 9eeff1a9bd78e70445a30587761bae4bea813867 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Wed, 10 Aug 2022 02:37:20 +0000 Subject: [PATCH 06/45] Move pretraining head back to `examples/` --- examples/bert/bert_train.py | 174 +++++++++++++++++++++++++++- keras_nlp/applications/__init__.py | 1 - keras_nlp/applications/bert.py | 175 ----------------------------- 3 files changed, 173 insertions(+), 177 deletions(-) diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index 35b1c7d0bc..668bd84576 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -25,7 +25,6 @@ from examples.bert.bert_config import PREPROCESSING_CONFIG from examples.bert.bert_config import TRAINING_CONFIG from keras_nlp.applications.bert import Bert -from keras_nlp.applications.bert import BertLanguageModel FLAGS = flags.FLAGS @@ -90,6 +89,179 @@ "Override the pre-configured number of train steps..", ) +class MaskedLMHead(keras.layers.Layer): + """Masked language model network head for BERT. + + This layer implements a masked language model based on the provided + transformer based encoder. It assumes that the encoder network being passed + has a "get_embedding_table()" method. + + Example: + ```python + encoder=modeling.networks.BertEncoder(...) + lm_layer=MaskedLMHead(embedding_table=encoder.get_embedding_table()) + ``` + + Args: + embedding_table: The embedding table from encoder network. + inner_activation: The activation, if any, for the inner dense layer. + initializer: The initializer for the dense layer. Defaults to a Glorot + uniform initializer. + output: The output style for this layer. Can be either 'logits' or + 'predictions'. + """ + + def __init__( + self, + embedding_table, + inner_activation="gelu", + initializer="glorot_uniform", + **kwargs, + ): + super().__init__(**kwargs) + self.embedding_table = embedding_table + self.inner_activation = keras.activations.get(inner_activation) + self.initializer = initializer + + def build(self, input_shape): + self._vocab_size, hidden_size = self.embedding_table.shape + self.dense = keras.layers.Dense( + hidden_size, + activation=self.inner_activation, + kernel_initializer=self.initializer, + name="transform/dense", + ) + self.layer_norm = keras.layers.LayerNormalization( + axis=-1, epsilon=1e-12, name="transform/LayerNorm" + ) + self.bias = self.add_weight( + "output_bias/bias", + shape=(self._vocab_size,), + initializer="zeros", + trainable=True, + ) + + super().build(input_shape) + + def call(self, sequence_data, masked_positions): + masked_lm_input = self._gather_indexes(sequence_data, masked_positions) + lm_data = self.dense(masked_lm_input) + lm_data = self.layer_norm(lm_data) + lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True) + logits = tf.nn.bias_add(lm_data, self.bias) + masked_positions_length = ( + masked_positions.shape.as_list()[1] or tf.shape(masked_positions)[1] + ) + return tf.reshape( + logits, [-1, masked_positions_length, self._vocab_size] + ) + + def _gather_indexes(self, sequence_tensor, positions): + """Gathers the vectors at the specific positions, for performance. + + Args: + sequence_tensor: Sequence output of shape + (`batch_size`, `seq_length`, `hidden_size`) where `hidden_size` + is number of hidden units. 
+ positions: Positions ids of tokens in sequence to mask for + pretraining of with dimension (batch_size, num_predictions) + where `num_predictions` is maximum number of tokens to mask out + and predict per each sequence. + + Returns: + Masked out sequence tensor of shape (batch_size * num_predictions, + `hidden_size`). + """ + sequence_shape = tf.shape(sequence_tensor) + batch_size, seq_length = sequence_shape[0], sequence_shape[1] + width = sequence_tensor.shape.as_list()[2] or sequence_shape[2] + + flat_offsets = tf.reshape( + tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1] + ) + flat_positions = tf.reshape(positions + flat_offsets, [-1]) + flat_sequence_tensor = tf.reshape( + sequence_tensor, [batch_size * seq_length, width] + ) + output_tensor = tf.gather(flat_sequence_tensor, flat_positions) + + return output_tensor + + +class BertLanguageModel(keras.Model): + """ + MLM + NSP model with BertEncoder. + """ + + def __init__(self, encoder, **kwargs): + super().__init__(**kwargs) + self.encoder = encoder + # TODO(jbischof): replace with keras_nlp.layers.MLMHead + self.masked_lm_head = MaskedLMHead( + embedding_table=encoder.get_embedding_table(), + initializer=encoder.initializer_fn, + ) + self.next_sentence_head = keras.layers.Dense( + encoder.type_vocab_size, + kernel_initializer=encoder.initializer_fn, + ) + self.loss_tracker = keras.metrics.Mean(name="loss") + self.lm_loss_tracker = keras.metrics.Mean(name="lm_loss") + self.nsp_loss_tracker = keras.metrics.Mean(name="nsp_loss") + self.lm_accuracy = keras.metrics.SparseCategoricalAccuracy( + name="lm_accuracy" + ) + self.nsp_accuracy = keras.metrics.SparseCategoricalAccuracy( + name="nsp_accuracy" + ) + + def call(self, data): + encoder_output = self.encoder( + { + "input_ids": data["input_ids"], + "input_mask": data["input_mask"], + "segment_ids": data["segment_ids"], + } + ) + sequence_output, pooled_output = encoder_output["sequence_output"], encoder_output["pooled_output"] + lm_preds = self.masked_lm_head( + sequence_output, data["masked_lm_positions"] + ) + nsp_preds = self.next_sentence_head(pooled_output) + return lm_preds, nsp_preds + + def train_step(self, data): + with tf.GradientTape() as tape: + lm_preds, nsp_preds = self(data, training=True) + lm_labels = data["masked_lm_ids"] + lm_weights = data["masked_lm_weights"] + nsp_labels = data["next_sentence_labels"] + + lm_loss = keras.losses.sparse_categorical_crossentropy( + lm_labels, lm_preds, from_logits=True + ) + lm_weights_summed = tf.reduce_sum(lm_weights, -1) + lm_loss = tf.reduce_sum(lm_loss * lm_weights, -1) + lm_loss = tf.math.divide_no_nan(lm_loss, lm_weights_summed) + nsp_loss = keras.losses.sparse_categorical_crossentropy( + nsp_labels, nsp_preds, from_logits=True + ) + nsp_loss = tf.reduce_mean(nsp_loss) + loss = lm_loss + nsp_loss + + # Compute gradients + trainable_vars = self.trainable_variables + gradients = tape.gradient(loss, trainable_vars) + # Update weights + self.optimizer.apply_gradients(zip(gradients, trainable_vars)) + + # Update metrics + self.loss_tracker.update_state(loss) + self.lm_loss_tracker.update_state(lm_loss) + self.nsp_loss_tracker.update_state(nsp_loss) + self.lm_accuracy.update_state(lm_labels, lm_preds, lm_weights) + self.nsp_accuracy.update_state(nsp_labels, nsp_preds) + return {m.name: m.result() for m in self.metrics} class LinearDecayWithWarmup(keras.optimizers.schedules.LearningRateSchedule): """ diff --git a/keras_nlp/applications/__init__.py b/keras_nlp/applications/__init__.py index 91901dd5a6..7b7fefd9da 100644 
--- a/keras_nlp/applications/__init__.py +++ b/keras_nlp/applications/__init__.py @@ -15,4 +15,3 @@ from keras_nlp.applications.bert import BertClassifier from keras_nlp.applications.bert import Bert from keras_nlp.applications.bert import BertBase -from keras_nlp.applications.bert import BertLanguageModel diff --git a/keras_nlp/applications/bert.py b/keras_nlp/applications/bert.py index 8d35302ff6..32a1757150 100644 --- a/keras_nlp/applications/bert.py +++ b/keras_nlp/applications/bert.py @@ -191,181 +191,6 @@ def get_config(self): return config -class MaskedLMHead(keras.layers.Layer): - """Masked language model network head for BERT. - - This layer implements a masked language model based on the provided - transformer based encoder. It assumes that the encoder network being passed - has a "get_embedding_table()" method. - - Example: - ```python - encoder=modeling.networks.BertEncoder(...) - lm_layer=MaskedLMHead(embedding_table=encoder.get_embedding_table()) - ``` - - Args: - embedding_table: The embedding table from encoder network. - inner_activation: The activation, if any, for the inner dense layer. - initializer: The initializer for the dense layer. Defaults to a Glorot - uniform initializer. - output: The output style for this layer. Can be either 'logits' or - 'predictions'. - """ - - def __init__( - self, - embedding_table, - inner_activation="gelu", - initializer="glorot_uniform", - **kwargs, - ): - super().__init__(**kwargs) - self.embedding_table = embedding_table - self.inner_activation = keras.activations.get(inner_activation) - self.initializer = initializer - - def build(self, input_shape): - self._vocab_size, hidden_size = self.embedding_table.shape - self.dense = keras.layers.Dense( - hidden_size, - activation=self.inner_activation, - kernel_initializer=self.initializer, - name="transform/dense", - ) - self.layer_norm = keras.layers.LayerNormalization( - axis=-1, epsilon=1e-12, name="transform/LayerNorm" - ) - self.bias = self.add_weight( - "output_bias/bias", - shape=(self._vocab_size,), - initializer="zeros", - trainable=True, - ) - - super().build(input_shape) - - def call(self, sequence_data, masked_positions): - masked_lm_input = self._gather_indexes(sequence_data, masked_positions) - lm_data = self.dense(masked_lm_input) - lm_data = self.layer_norm(lm_data) - lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True) - logits = tf.nn.bias_add(lm_data, self.bias) - masked_positions_length = ( - masked_positions.shape.as_list()[1] or tf.shape(masked_positions)[1] - ) - return tf.reshape( - logits, [-1, masked_positions_length, self._vocab_size] - ) - - def _gather_indexes(self, sequence_tensor, positions): - """Gathers the vectors at the specific positions, for performance. - - Args: - sequence_tensor: Sequence output of shape - (`batch_size`, `seq_length`, `hidden_size`) where `hidden_size` - is number of hidden units. - positions: Positions ids of tokens in sequence to mask for - pretraining of with dimension (batch_size, num_predictions) - where `num_predictions` is maximum number of tokens to mask out - and predict per each sequence. - - Returns: - Masked out sequence tensor of shape (batch_size * num_predictions, - `hidden_size`). 
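        As a self-contained illustration of the flattened-offset gather
        described above (toy shapes, independent of this class):

        ```python
        import tensorflow as tf

        batch_size, seq_length, hidden = 2, 4, 3
        sequence = tf.reshape(
            tf.range(batch_size * seq_length * hidden, dtype=tf.float32),
            (batch_size, seq_length, hidden),
        )
        positions = tf.constant([[0, 2], [1, 3]])  # (batch_size, num_predictions)

        flat_offsets = tf.reshape(tf.range(batch_size) * seq_length, (-1, 1))
        flat_positions = tf.reshape(positions + flat_offsets, [-1])
        flat_sequence = tf.reshape(sequence, (batch_size * seq_length, hidden))
        gathered = tf.gather(flat_sequence, flat_positions)
        # gathered.shape == (batch_size * num_predictions, hidden)
        ```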
- """ - sequence_shape = tf.shape(sequence_tensor) - batch_size, seq_length = sequence_shape[0], sequence_shape[1] - width = sequence_tensor.shape.as_list()[2] or sequence_shape[2] - - flat_offsets = tf.reshape( - tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1] - ) - flat_positions = tf.reshape(positions + flat_offsets, [-1]) - flat_sequence_tensor = tf.reshape( - sequence_tensor, [batch_size * seq_length, width] - ) - output_tensor = tf.gather(flat_sequence_tensor, flat_positions) - - return output_tensor - - -class BertLanguageModel(keras.Model): - """ - MLM + NSP model with BertEncoder. - """ - - def __init__(self, encoder, **kwargs): - super().__init__(**kwargs) - self.encoder = encoder - # TODO(jbischof): replace with keras_nlp.layers.MLMHead - self.masked_lm_head = MaskedLMHead( - embedding_table=encoder.get_embedding_table(), - initializer=encoder.initializer_fn, - ) - self.next_sentence_head = keras.layers.Dense( - encoder.type_vocab_size, - kernel_initializer=encoder.initializer_fn, - ) - self.loss_tracker = keras.metrics.Mean(name="loss") - self.lm_loss_tracker = keras.metrics.Mean(name="lm_loss") - self.nsp_loss_tracker = keras.metrics.Mean(name="nsp_loss") - self.lm_accuracy = keras.metrics.SparseCategoricalAccuracy( - name="lm_accuracy" - ) - self.nsp_accuracy = keras.metrics.SparseCategoricalAccuracy( - name="nsp_accuracy" - ) - - def call(self, data): - encoder_output = self.encoder( - { - "input_ids": data["input_ids"], - "input_mask": data["input_mask"], - "segment_ids": data["segment_ids"], - } - ) - sequence_output, pooled_output = encoder_output["sequence_output"], encoder_output["pooled_output"] - lm_preds = self.masked_lm_head( - sequence_output, data["masked_lm_positions"] - ) - nsp_preds = self.next_sentence_head(pooled_output) - return lm_preds, nsp_preds - - def train_step(self, data): - with tf.GradientTape() as tape: - lm_preds, nsp_preds = self(data, training=True) - lm_labels = data["masked_lm_ids"] - lm_weights = data["masked_lm_weights"] - nsp_labels = data["next_sentence_labels"] - - lm_loss = keras.losses.sparse_categorical_crossentropy( - lm_labels, lm_preds, from_logits=True - ) - lm_weights_summed = tf.reduce_sum(lm_weights, -1) - lm_loss = tf.reduce_sum(lm_loss * lm_weights, -1) - lm_loss = tf.math.divide_no_nan(lm_loss, lm_weights_summed) - nsp_loss = keras.losses.sparse_categorical_crossentropy( - nsp_labels, nsp_preds, from_logits=True - ) - nsp_loss = tf.reduce_mean(nsp_loss) - loss = lm_loss + nsp_loss - - # Compute gradients - trainable_vars = self.trainable_variables - gradients = tape.gradient(loss, trainable_vars) - # Update weights - self.optimizer.apply_gradients(zip(gradients, trainable_vars)) - - # Update metrics - self.loss_tracker.update_state(loss) - self.lm_loss_tracker.update_state(lm_loss) - self.nsp_loss_tracker.update_state(nsp_loss) - self.lm_accuracy.update_state(lm_labels, lm_preds, lm_weights) - self.nsp_accuracy.update_state(nsp_labels, nsp_preds) - return {m.name: m.result() for m in self.metrics} - - class BertClassifier(keras.Model): """Classifier model with BertEncoder.""" From 483d781f11e12b84644d2073104e76372d8cfb4e Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Wed, 10 Aug 2022 02:46:29 +0000 Subject: [PATCH 07/45] Rename to BertPretrainingModel --- examples/bert/bert_train.py | 11 ++++++----- keras_nlp/applications/bert.py | 7 ++++--- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index 668bd84576..bbd728f972 100644 --- 
a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -188,7 +188,7 @@ def _gather_indexes(self, sequence_tensor, positions): return output_tensor -class BertLanguageModel(keras.Model): +class BertPretrainingModel(keras.Model): """ MLM + NSP model with BertEncoder. """ @@ -196,7 +196,7 @@ class BertLanguageModel(keras.Model): def __init__(self, encoder, **kwargs): super().__init__(**kwargs) self.encoder = encoder - # TODO(jbischof): replace with keras_nlp.layers.MLMHead + # TODO(jbischof): replace with keras_nlp.layers.MLMHead (Issue #166) self.masked_lm_head = MaskedLMHead( embedding_table=encoder.get_embedding_table(), initializer=encoder.initializer_fn, @@ -263,6 +263,7 @@ def train_step(self, data): self.nsp_accuracy.update_state(nsp_labels, nsp_preds) return {m.name: m.result() for m in self.metrics} + class LinearDecayWithWarmup(keras.optimizers.schedules.LearningRateSchedule): """ A learning rate schedule with linear warmup and decay. @@ -421,8 +422,8 @@ def main(_): ) optimizer = keras.optimizers.Adam(learning_rate=learning_rate_schedule) - language_model = BertLanguageModel(encoder) - language_model.compile( + pretraining_model = BertPretrainingModel(encoder) + pretraining_model.compile( optimizer=optimizer, ) @@ -435,7 +436,7 @@ def main(_): if FLAGS.tensorboard_log_path: callbacks.append(get_tensorboard_callback()) - language_model.fit( + pretraining_model.fit( dataset, epochs=epochs, steps_per_epoch=steps_per_epoch, diff --git a/keras_nlp/applications/bert.py b/keras_nlp/applications/bert.py index 32a1757150..0f3098246b 100644 --- a/keras_nlp/applications/bert.py +++ b/keras_nlp/applications/bert.py @@ -19,7 +19,7 @@ import keras_nlp.layers # isort: off -# TODO(bischof): decide what to export or whether we are using these decorators +# TODO(jbischof): decide what to export or whether we are using these decorators #from tensorflow.python.util.tf_export import keras_export CLS_INDEX = 0 @@ -194,6 +194,7 @@ def get_config(self): class BertClassifier(keras.Model): """Classifier model with BertEncoder.""" + # TODO(jbischof): figure out initialization default def __init__(self, encoder, num_classes, initializer, **kwargs): super().__init__(**kwargs) self.encoder = encoder @@ -224,10 +225,10 @@ def BertBase(weights=None): dropout=0.1, ) - # TODO(bischof): add some documentation or magic to load our checkpoints + # TODO(jbischof): add some documentation or magic to load our checkpoints # Note: This is pure Keras and also intended to work with user checkpoints if weights is not None: model.load_weights(weights) - # TODO(bischof): attach the tokenizer + # TODO(jbischof): attach the tokenizer return model From 2d9fca7a274c5362bb29c201c76a00e42f8def05 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Wed, 10 Aug 2022 02:50:03 +0000 Subject: [PATCH 08/45] Small notes --- examples/bert/bert_config.py | 1 + keras_nlp/applications/bert.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/examples/bert/bert_config.py b/examples/bert/bert_config.py index e5febb9254..36c50e22dd 100644 --- a/examples/bert/bert_config.py +++ b/examples/bert/bert_config.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+# TODO(jbischof): remove in favor of BertBase, BertSmall, etc MODEL_CONFIGS = { "tiny": { "num_layers": 2, diff --git a/keras_nlp/applications/bert.py b/keras_nlp/applications/bert.py index 0f3098246b..2ccbe82384 100644 --- a/keras_nlp/applications/bert.py +++ b/keras_nlp/applications/bert.py @@ -88,6 +88,8 @@ def __init__( segment_id_input = keras.Input( shape=(None,), dtype="int32", name="segment_ids" ) + # TODO(jbischof): improve handling of masking following + # https://www.tensorflow.org/guide/keras/masking_and_padding input_mask = keras.Input(shape=(None,), dtype="int32", name="input_mask") # Embed tokens, positions, and segment ids. From e54e39c671f9802088670674dd66e588b0c23fca Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Wed, 10 Aug 2022 02:53:35 +0000 Subject: [PATCH 09/45] Formatting and notes --- examples/bert/bert_train.py | 6 +++++- keras_nlp/applications/__init__.py | 2 +- keras_nlp/applications/bert.py | 18 +++++++++++++----- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index bbd728f972..5cad8612be 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -89,6 +89,7 @@ "Override the pre-configured number of train steps..", ) + class MaskedLMHead(keras.layers.Layer): """Masked language model network head for BERT. @@ -223,7 +224,10 @@ def call(self, data): "segment_ids": data["segment_ids"], } ) - sequence_output, pooled_output = encoder_output["sequence_output"], encoder_output["pooled_output"] + sequence_output, pooled_output = ( + encoder_output["sequence_output"], + encoder_output["pooled_output"], + ) lm_preds = self.masked_lm_head( sequence_output, data["masked_lm_positions"] ) diff --git a/keras_nlp/applications/__init__.py b/keras_nlp/applications/__init__.py index 7b7fefd9da..c2171d9220 100644 --- a/keras_nlp/applications/__init__.py +++ b/keras_nlp/applications/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from keras_nlp.applications.bert import BertClassifier from keras_nlp.applications.bert import Bert from keras_nlp.applications.bert import BertBase +from keras_nlp.applications.bert import BertClassifier diff --git a/keras_nlp/applications/bert.py b/keras_nlp/applications/bert.py index 2ccbe82384..15b155ef1d 100644 --- a/keras_nlp/applications/bert.py +++ b/keras_nlp/applications/bert.py @@ -20,12 +20,13 @@ # isort: off # TODO(jbischof): decide what to export or whether we are using these decorators -#from tensorflow.python.util.tf_export import keras_export +# from tensorflow.python.util.tf_export import keras_export CLS_INDEX = 0 TOKEN_EMBEDDING_LAYER_NAME = "token_embedding" +# TODO(jbischof): move to keras_nlp/models class Bert(keras.Model): """Bi-directional Transformer-based encoder network. @@ -62,6 +63,7 @@ class Bert(keras.Model): dense layers. If set False, output of attention and intermediate dense layers is normalized. 
""" + def __init__( self, vocab_size, @@ -84,13 +86,17 @@ def __init__( ) # Functional version of model - token_id_input = keras.Input(shape=(None,), dtype="int32", name="input_ids") + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="input_ids" + ) segment_id_input = keras.Input( shape=(None,), dtype="int32", name="segment_ids" ) # TODO(jbischof): improve handling of masking following # https://www.tensorflow.org/guide/keras/masking_and_padding - input_mask = keras.Input(shape=(None,), dtype="int32", name="input_mask") + input_mask = keras.Input( + shape=(None,), dtype="int32", name="input_mask" + ) # Embed tokens, positions, and segment ids. token_embedding = keras.layers.Embedding( @@ -150,11 +156,13 @@ def __init__( "segment_ids": segment_id_input, "input_mask": input_mask, }, + # TODO(jbischof): Consider list output outputs={ "sequence_output": sequence_output, "pooled_output": pooled_output, }, - **kwargs) + **kwargs + ) # All references to `self` below this line self.inner_activation_fn = inner_activation_fn self.initializer_fn = initializer_fn @@ -168,7 +176,7 @@ def __init__( self.inner_activation = keras.activations.get(inner_activation) self.initializer_range = initializer_range self.dropout = dropout - + def get_embedding_table(self): return self.get_layer(TOKEN_EMBEDDING_LAYER_NAME).embeddings From 1ac992bbf3ee414c26031ebbc14bd1d38014cbe9 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Wed, 10 Aug 2022 02:56:05 +0000 Subject: [PATCH 10/45] Note --- examples/bert/bert_preprocess.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/bert/bert_preprocess.py b/examples/bert/bert_preprocess.py index 8436bfa432..00e0e8cbca 100644 --- a/examples/bert/bert_preprocess.py +++ b/examples/bert/bert_preprocess.py @@ -375,7 +375,8 @@ def create_masked_lm_predictions( ): """Creates the predictions for the masked LM objective.""" - # TODO(bischof): replace with keras_nlp.layers.MLMMaskGenerator + # TODO(jbischof): replace with keras_nlp.layers.MLMMaskGenerator + # (Issue #166) cand_indexes = [] for (i, token) in enumerate(tokens): From 66b5c7ca2acedcee7b744ac2bcc4b635a2314390 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Wed, 10 Aug 2022 18:55:26 +0000 Subject: [PATCH 11/45] Move Bert to models/ folder --- examples/bert/bert_finetune_glue.py | 2 +- examples/bert/bert_train.py | 6 ++---- keras_nlp/__init__.py | 2 +- keras_nlp/{applications => models}/__init__.py | 6 +++--- keras_nlp/{applications => models}/bert.py | 0 5 files changed, 7 insertions(+), 9 deletions(-) rename keras_nlp/{applications => models}/__init__.py (79%) rename keras_nlp/{applications => models}/bert.py (100%) diff --git a/examples/bert/bert_finetune_glue.py b/examples/bert/bert_finetune_glue.py index f73c7e0e4b..49c6d17c10 100644 --- a/examples/bert/bert_finetune_glue.py +++ b/examples/bert/bert_finetune_glue.py @@ -26,7 +26,7 @@ from examples.bert.bert_config import FINETUNING_CONFIG from examples.bert.bert_config import MODEL_CONFIGS from examples.bert.bert_config import PREPROCESSING_CONFIG -from keras_nlp.applications import BertClassifier +from keras_nlp.models import BertClassifier FLAGS = flags.FLAGS diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index 5cad8612be..0d31c72be5 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -24,7 +24,7 @@ from examples.bert.bert_config import MODEL_CONFIGS from examples.bert.bert_config import PREPROCESSING_CONFIG from examples.bert.bert_config import TRAINING_CONFIG -from 
keras_nlp.applications.bert import Bert +from keras_nlp.models import Bert FLAGS = flags.FLAGS @@ -190,9 +190,7 @@ def _gather_indexes(self, sequence_tensor, positions): class BertPretrainingModel(keras.Model): - """ - MLM + NSP model with BertEncoder. - """ + """MLM + NSP model with BertEncoder.""" def __init__(self, encoder, **kwargs): super().__init__(**kwargs) diff --git a/keras_nlp/__init__.py b/keras_nlp/__init__.py index 182436118f..06907cf33d 100644 --- a/keras_nlp/__init__.py +++ b/keras_nlp/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from keras_nlp import applications +from keras_nlp import models from keras_nlp import layers from keras_nlp import metrics from keras_nlp import tokenizers diff --git a/keras_nlp/applications/__init__.py b/keras_nlp/models/__init__.py similarity index 79% rename from keras_nlp/applications/__init__.py rename to keras_nlp/models/__init__.py index c2171d9220..261a563b09 100644 --- a/keras_nlp/applications/__init__.py +++ b/keras_nlp/models/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from keras_nlp.applications.bert import Bert -from keras_nlp.applications.bert import BertBase -from keras_nlp.applications.bert import BertClassifier +from keras_nlp.models.bert import Bert +from keras_nlp.models.bert import BertBase +from keras_nlp.models.bert import BertClassifier diff --git a/keras_nlp/applications/bert.py b/keras_nlp/models/bert.py similarity index 100% rename from keras_nlp/applications/bert.py rename to keras_nlp/models/bert.py From b3d22c0ffddad0258bdf2c9af8bd8513a67b6836 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Wed, 10 Aug 2022 22:41:33 +0000 Subject: [PATCH 12/45] Small style changes re: comments --- examples/bert/bert_config.py | 24 ++++++++--------- examples/bert/bert_train.py | 2 +- keras_nlp/__init__.py | 2 +- keras_nlp/models/bert.py | 50 +++++++++++++++++------------------- 4 files changed, 37 insertions(+), 41 deletions(-) diff --git a/examples/bert/bert_config.py b/examples/bert/bert_config.py index 36c50e22dd..24db7ef9d8 100644 --- a/examples/bert/bert_config.py +++ b/examples/bert/bert_config.py @@ -18,8 +18,8 @@ "num_layers": 2, "hidden_size": 128, "dropout": 0.1, - "num_attention_heads": 2, - "inner_size": 512, + "num_heads": 2, + "intermediate_dim": 512, "inner_activation": "gelu", "initializer_range": 0.02, }, @@ -27,8 +27,8 @@ "num_layers": 4, "hidden_size": 256, "dropout": 0.1, - "num_attention_heads": 4, - "inner_size": 1024, + "num_heads": 4, + "intermediate_dim": 1024, "inner_activation": "gelu", "initializer_range": 0.02, }, @@ -36,8 +36,8 @@ "num_layers": 4, "hidden_size": 512, "dropout": 0.1, - "num_attention_heads": 8, - "inner_size": 2048, + "num_heads": 8, + "intermediate_dim": 2048, "inner_activation": "gelu", "initializer_range": 0.02, }, @@ -45,8 +45,8 @@ "num_layers": 8, "hidden_size": 512, "dropout": 0.1, - "num_attention_heads": 8, - "inner_size": 2048, + "num_heads": 8, + "intermediate_dim": 2048, "inner_activation": "gelu", "initializer_range": 0.02, }, @@ -54,8 +54,8 @@ "num_layers": 12, "hidden_size": 768, "dropout": 0.1, - "num_attention_heads": 12, - "inner_size": 3072, + "num_heads": 12, + "intermediate_dim": 3072, "inner_activation": "gelu", "initializer_range": 0.02, }, @@ -63,8 +63,8 @@ "num_layers": 24, "hidden_size": 1024, "dropout": 0.1, - "num_attention_heads": 16, - "inner_size": 4096, + "num_heads": 16, + 
"intermediate_dim": 4096, "inner_activation": "gelu", "initializer_range": 0.02, }, diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index 0d31c72be5..a3d0c0aeac 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -201,7 +201,7 @@ def __init__(self, encoder, **kwargs): initializer=encoder.initializer_fn, ) self.next_sentence_head = keras.layers.Dense( - encoder.type_vocab_size, + encoder.num_segments, kernel_initializer=encoder.initializer_fn, ) self.loss_tracker = keras.metrics.Mean(name="loss") diff --git a/keras_nlp/__init__.py b/keras_nlp/__init__.py index 06907cf33d..67ab65bab1 100644 --- a/keras_nlp/__init__.py +++ b/keras_nlp/__init__.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from keras_nlp import models from keras_nlp import layers from keras_nlp import metrics +from keras_nlp import models from keras_nlp import tokenizers from keras_nlp import utils diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 15b155ef1d..2436a56b75 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -16,17 +16,12 @@ import tensorflow as tf from tensorflow import keras -import keras_nlp.layers - -# isort: off -# TODO(jbischof): decide what to export or whether we are using these decorators -# from tensorflow.python.util.tf_export import keras_export +from keras_nlp.layers import PositionEmbedding, TransformerEncoder CLS_INDEX = 0 TOKEN_EMBEDDING_LAYER_NAME = "token_embedding" -# TODO(jbischof): move to keras_nlp/models class Bert(keras.Model): """Bi-directional Transformer-based encoder network. @@ -44,10 +39,10 @@ class Bert(keras.Model): vocab_size: The size of the token vocabulary. num_layers: The number of transformer layers. hidden_size: The size of the transformer hidden layers. - num_attention_heads: The number of attention heads for each transformer. + num_heads: The number of attention heads for each transformer. The hidden size must be divisible by the number of attention heads. - inner_size: The output dimension of the first Dense layer in a two-layer - feedforward network for each transformer. + intermediate_dim: The output dimension of the first Dense layer in a + two-layer feedforward network for each transformer. inner_activation: The activation for the first Dense layer in a two-layer feedforward network for each transformer. initializer_range: The initialzer range to use for a truncated normal @@ -57,7 +52,7 @@ class Bert(keras.Model): consume. If None, max_sequence_length uses the value from sequence length. This determines the variable shape for positional embeddings. - type_vocab_size: The number of types that the 'segment_ids' input can + num_segments: The number of types that the 'segment_ids' input can take. norm_first: Whether to normalize inputs to attention and intermediate dense layers. 
If set False, output of attention and intermediate @@ -69,14 +64,14 @@ def __init__( vocab_size, num_layers, hidden_size, - num_attention_heads, - inner_size, + num_heads, + intermediate_dim, inner_activation="gelu", initializer_range=0.02, dropout=0.1, max_sequence_length=512, - type_vocab_size=2, - **kwargs + num_segments=2, + **kwargs, ): # Create lambda functions from input params @@ -104,13 +99,13 @@ def __init__( output_dim=hidden_size, name=TOKEN_EMBEDDING_LAYER_NAME, )(token_id_input) - position_embedding = keras_nlp.layers.PositionEmbedding( + position_embedding = PositionEmbedding( initializer=initializer_fn, sequence_length=max_sequence_length, name="position_embedding", )(token_embedding) segment_embedding = keras.layers.Embedding( - input_dim=type_vocab_size, + input_dim=num_segments, output_dim=hidden_size, name="segment_embedding", )(segment_id_input) @@ -132,9 +127,9 @@ def __init__( # Apply successive transformer encoder blocks. for i in range(num_layers): - x = keras_nlp.layers.TransformerEncoder( - num_heads=num_attention_heads, - intermediate_dim=inner_size, + x = TransformerEncoder( + num_heads=num_heads, + intermediate_dim=intermediate_dim, activation=inner_activation_fn, dropout=dropout, kernel_initializer=initializer_fn, @@ -169,10 +164,10 @@ def __init__( self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_layers = num_layers - self.num_attention_heads = num_attention_heads + self.num_heads = num_heads self.max_sequence_length = max_sequence_length - self.type_vocab_size = type_vocab_size - self.inner_size = inner_size + self.num_segments = num_segments + self.intermediate_dim = intermediate_dim self.inner_activation = keras.activations.get(inner_activation) self.initializer_range = initializer_range self.dropout = dropout @@ -187,10 +182,10 @@ def get_config(self): "vocab_size": self.vocab_size, "hidden_size": self.hidden_size, "num_layers": self.num_layers, - "num_attention_heads": self.num_attention_heads, + "num_heads": self.num_heads, "max_sequence_length": self.max_sequence_length, - "type_vocab_size": self.type_vocab_size, - "inner_size": self.inner_size, + "num_segments": self.num_segments, + "intermediate_dim": self.intermediate_dim, "inner_activation": keras.activations.serialize( self.inner_activation ), @@ -223,13 +218,14 @@ def call(self, inputs): def BertBase(weights=None): """Factory for BertEncoder using "Base" architecture.""" + # TODO(jbischof): add docstring for `Bert` model = Bert( vocab_size=30522, num_layers=12, hidden_size=768, - num_attention_heads=12, - inner_size=3072, + num_heads=12, + intermediate_dim=3072, inner_activation="gelu", initializer_range=0.02, dropout=0.1, From b704f289e1ffd1e504f5bd99f37dd8766068f88e Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Wed, 10 Aug 2022 23:09:22 +0000 Subject: [PATCH 13/45] Fix Bert docstrings and remove `weights` param --- keras_nlp/models/bert.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 2436a56b75..333dc98810 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -31,10 +31,6 @@ class Bert(keras.Model): embedding lookups and transformer layers, but not the masked language model or classification task networks. - The default values for this object are taken from the BERT-Base - implementation in "BERT: Pre-training of Deep Bidirectional Transformers for - Language Understanding". - Args: vocab_size: The size of the token vocabulary. 
num_layers: The number of transformer layers. @@ -54,9 +50,6 @@ class Bert(keras.Model): embeddings. num_segments: The number of types that the 'segment_ids' input can take. - norm_first: Whether to normalize inputs to attention and intermediate - dense layers. If set False, output of attention and intermediate - dense layers is normalized. """ def __init__( @@ -216,9 +209,16 @@ def call(self, inputs): return self._logit_layer(pooled_output) -def BertBase(weights=None): - """Factory for BertEncoder using "Base" architecture.""" - # TODO(jbischof): add docstring for `Bert` +def BertBase(): + """ + Factory for BertEncoder using "Base" architecture. + + This network implements a bi-directional Transformer-based encoder as + described in "BERT: Pre-training of Deep Bidirectional Transformers for + Language Understanding" (https://arxiv.org/abs/1810.04805). It includes the + embedding lookups and transformer layers, but not the masked language model + or classification task networks. + """ model = Bert( vocab_size=30522, @@ -232,9 +232,5 @@ def BertBase(weights=None): ) # TODO(jbischof): add some documentation or magic to load our checkpoints - # Note: This is pure Keras and also intended to work with user checkpoints - if weights is not None: - model.load_weights(weights) - # TODO(jbischof): attach the tokenizer return model From 4f1ecddbb56b65eed3648a566e44cfc9b5e52256 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Thu, 11 Aug 2022 01:04:09 +0000 Subject: [PATCH 14/45] Initialization and docstring for classifier --- examples/bert/bert_finetune_glue.py | 3 --- keras_nlp/models/bert.py | 29 +++++++++++++++++++++++------ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/examples/bert/bert_finetune_glue.py b/examples/bert/bert_finetune_glue.py index 49c6d17c10..5d33f55b15 100644 --- a/examples/bert/bert_finetune_glue.py +++ b/examples/bert/bert_finetune_glue.py @@ -122,9 +122,6 @@ def build(self, hp): finetuning_model = BertClassifier( encoder=model, num_classes=3 if FLAGS.task_name in ("mnli", "ax") else 2, - initializer=keras.initializers.TruncatedNormal( - stddev=model_config["initializer_range"] - ), ) finetuning_model.compile( optimizer=keras.optimizers.Adam( diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 333dc98810..020fc5d3b0 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -16,7 +16,8 @@ import tensorflow as tf from tensorflow import keras -from keras_nlp.layers import PositionEmbedding, TransformerEncoder +from keras_nlp.layers import PositionEmbedding +from keras_nlp.layers import TransformerEncoder CLS_INDEX = 0 TOKEN_EMBEDDING_LAYER_NAME = "token_embedding" @@ -149,7 +150,7 @@ def __init__( "sequence_output": sequence_output, "pooled_output": pooled_output, }, - **kwargs + **kwargs, ) # All references to `self` below this line self.inner_activation_fn = inner_activation_fn @@ -190,16 +191,32 @@ def get_config(self): class BertClassifier(keras.Model): - """Classifier model with BertEncoder.""" + """ + Classifier model with BertEncoder. + + Args: + encoder: A `Bert` Model to encode inputs. + num_classes: Number of classes to predict. + kernel_initializer: Initializer for the `kernel` weights matrix. + bias_initializer: Initializer for the bias vector. 
+ """ + + def __init__( + self, + encoder, + num_classes, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + **kwargs, + ): - # TODO(jbischof): figure out initialization default - def __init__(self, encoder, num_classes, initializer, **kwargs): super().__init__(**kwargs) self.encoder = encoder self.num_classes = num_classes self._logit_layer = keras.layers.Dense( num_classes, - kernel_initializer=initializer, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, name="logits", ) From caf9dc59171659b06525e19c2a90b5213a5ba4cb Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Thu, 11 Aug 2022 01:13:58 +0000 Subject: [PATCH 15/45] Decouple finetuning from model config --- examples/bert/bert_finetune_glue.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/examples/bert/bert_finetune_glue.py b/examples/bert/bert_finetune_glue.py index 5d33f55b15..ec82371203 100644 --- a/examples/bert/bert_finetune_glue.py +++ b/examples/bert/bert_finetune_glue.py @@ -24,18 +24,11 @@ import keras_nlp from examples.bert.bert_config import FINETUNING_CONFIG -from examples.bert.bert_config import MODEL_CONFIGS from examples.bert.bert_config import PREPROCESSING_CONFIG from keras_nlp.models import BertClassifier FLAGS = flags.FLAGS -flags.DEFINE_string( - "model_size", - "tiny", - "One of: tiny, mini, small, medium, base, or large.", -) - flags.DEFINE_string( "vocab_file", None, @@ -113,12 +106,8 @@ def to_tf_dataset(split): class BertHyperModel(keras_tuner.HyperModel): """Creates a hypermodel to help with the search space for finetuning.""" - def __init__(self, model_config): - self.model_config = model_config - def build(self, hp): model = keras.models.load_model(FLAGS.saved_model_input, compile=False) - model_config = self.model_config finetuning_model = BertClassifier( encoder=model, num_classes=3 if FLAGS.task_name in ("mnli", "ax") else 2, @@ -148,8 +137,6 @@ def main(_): end_value=tokenizer.token_to_id("[SEP]"), ) - model_config = MODEL_CONFIGS[FLAGS.model_size] - def preprocess_data(inputs, labels): inputs = [tokenizer(x) for x in inputs] token_ids, segment_ids = packer(inputs) @@ -174,7 +161,7 @@ def preprocess_data(inputs, labels): ) # Create a hypermodel object for a RandomSearch. - hypermodel = BertHyperModel(model_config) + hypermodel = BertHyperModel() # Initialize the random search over the 4 learning rate parameters, for 4 # trials and 3 epochs for each trial. 
From 772f201be998dd63fdd5ed28bfa17ca75adda054 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Thu, 11 Aug 2022 18:10:10 +0000 Subject: [PATCH 16/45] Replace `inner` -> `intermediate` --- examples/bert/bert_config.py | 12 ++++++------ examples/bert/bert_train.py | 10 ++++++---- keras_nlp/models/bert.py | 27 ++++++++++++++++----------- 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/examples/bert/bert_config.py b/examples/bert/bert_config.py index 24db7ef9d8..cfed803d70 100644 --- a/examples/bert/bert_config.py +++ b/examples/bert/bert_config.py @@ -20,7 +20,7 @@ "dropout": 0.1, "num_heads": 2, "intermediate_dim": 512, - "inner_activation": "gelu", + "intermediate_activiation": "gelu", "initializer_range": 0.02, }, "mini": { @@ -29,7 +29,7 @@ "dropout": 0.1, "num_heads": 4, "intermediate_dim": 1024, - "inner_activation": "gelu", + "intermediate_activiation": "gelu", "initializer_range": 0.02, }, "small": { @@ -38,7 +38,7 @@ "dropout": 0.1, "num_heads": 8, "intermediate_dim": 2048, - "inner_activation": "gelu", + "intermediate_activiation": "gelu", "initializer_range": 0.02, }, "medium": { @@ -47,7 +47,7 @@ "dropout": 0.1, "num_heads": 8, "intermediate_dim": 2048, - "inner_activation": "gelu", + "intermediate_activiation": "gelu", "initializer_range": 0.02, }, "base": { @@ -56,7 +56,7 @@ "dropout": 0.1, "num_heads": 12, "intermediate_dim": 3072, - "inner_activation": "gelu", + "intermediate_activiation": "gelu", "initializer_range": 0.02, }, "large": { @@ -65,7 +65,7 @@ "dropout": 0.1, "num_heads": 16, "intermediate_dim": 4096, - "inner_activation": "gelu", + "intermediate_activiation": "gelu", "initializer_range": 0.02, }, } diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index a3d0c0aeac..8bc20ba8d5 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -105,7 +105,7 @@ class MaskedLMHead(keras.layers.Layer): Args: embedding_table: The embedding table from encoder network. - inner_activation: The activation, if any, for the inner dense layer. + intermediate_activiation: The activation, if any, for the inner dense layer. initializer: The initializer for the dense layer. Defaults to a Glorot uniform initializer. output: The output style for this layer. Can be either 'logits' or @@ -115,20 +115,22 @@ class MaskedLMHead(keras.layers.Layer): def __init__( self, embedding_table, - inner_activation="gelu", + intermediate_activiation="gelu", initializer="glorot_uniform", **kwargs, ): super().__init__(**kwargs) self.embedding_table = embedding_table - self.inner_activation = keras.activations.get(inner_activation) + self.intermediate_activiation = keras.activations.get( + intermediate_activiation + ) self.initializer = initializer def build(self, input_shape): self._vocab_size, hidden_size = self.embedding_table.shape self.dense = keras.layers.Dense( hidden_size, - activation=self.inner_activation, + activation=self.intermediate_activiation, kernel_initializer=self.initializer, name="transform/dense", ) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 020fc5d3b0..914347c277 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -40,7 +40,7 @@ class Bert(keras.Model): The hidden size must be divisible by the number of attention heads. intermediate_dim: The output dimension of the first Dense layer in a two-layer feedforward network for each transformer. 
- inner_activation: The activation for the first Dense layer in a + intermediate_activiation: The activation for the first Dense layer in a two-layer feedforward network for each transformer. initializer_range: The initialzer range to use for a truncated normal initializer. @@ -60,7 +60,7 @@ def __init__( hidden_size, num_heads, intermediate_dim, - inner_activation="gelu", + intermediate_activiation="gelu", initializer_range=0.02, dropout=0.1, max_sequence_length=512, @@ -69,7 +69,9 @@ def __init__( ): # Create lambda functions from input params - inner_activation_fn = keras.activations.get(inner_activation) + intermediate_activiation_fn = keras.activations.get( + intermediate_activiation + ) initializer_fn = keras.initializers.TruncatedNormal( stddev=initializer_range ) @@ -124,7 +126,7 @@ def __init__( x = TransformerEncoder( num_heads=num_heads, intermediate_dim=intermediate_dim, - activation=inner_activation_fn, + activation=intermediate_activiation_fn, dropout=dropout, kernel_initializer=initializer_fn, name="transformer/layer_%d" % i, @@ -139,7 +141,7 @@ def __init__( )(x[:, CLS_INDEX, :]) # Instantiate using Functional API Model constructor - super(Bert, self).__init__( + super().__init__( inputs={ "input_ids": token_id_input, "segment_ids": segment_id_input, @@ -153,7 +155,7 @@ def __init__( **kwargs, ) # All references to `self` below this line - self.inner_activation_fn = inner_activation_fn + self.intermediate_activiation_fn = intermediate_activiation_fn self.initializer_fn = initializer_fn self.vocab_size = vocab_size self.hidden_size = hidden_size @@ -162,7 +164,9 @@ def __init__( self.max_sequence_length = max_sequence_length self.num_segments = num_segments self.intermediate_dim = intermediate_dim - self.inner_activation = keras.activations.get(inner_activation) + self.intermediate_activiation = keras.activations.get( + intermediate_activiation + ) self.initializer_range = initializer_range self.dropout = dropout @@ -180,8 +184,8 @@ def get_config(self): "max_sequence_length": self.max_sequence_length, "num_segments": self.num_segments, "intermediate_dim": self.intermediate_dim, - "inner_activation": keras.activations.serialize( - self.inner_activation + "intermediate_activiation": keras.activations.serialize( + self.intermediate_activiation ), "dropout": self.dropout, "initializer_range": self.initializer_range, @@ -226,7 +230,7 @@ def call(self, inputs): return self._logit_layer(pooled_output) -def BertBase(): +def BertBase(**kwargs): """ Factory for BertEncoder using "Base" architecture. @@ -243,9 +247,10 @@ def BertBase(): hidden_size=768, num_heads=12, intermediate_dim=3072, - inner_activation="gelu", + intermediate_activiation="gelu", initializer_range=0.02, dropout=0.1, + **kwargs, ) # TODO(jbischof): add some documentation or magic to load our checkpoints From 9bbc465dbcedc436ff3eed6cf9d94f365ed83e34 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Thu, 11 Aug 2022 19:05:38 +0000 Subject: [PATCH 17/45] Do not expose initializer and activation --- keras_nlp/models/bert.py | 31 +++++-------------------------- 1 file changed, 5 insertions(+), 26 deletions(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 914347c277..a1039417dd 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -40,10 +40,6 @@ class Bert(keras.Model): The hidden size must be divisible by the number of attention heads. intermediate_dim: The output dimension of the first Dense layer in a two-layer feedforward network for each transformer. 
- intermediate_activiation: The activation for the first Dense layer in a - two-layer feedforward network for each transformer. - initializer_range: The initialzer range to use for a truncated normal - initializer. dropout: Dropout probability for the Transformer encoder. max_sequence_length: The maximum sequence length that this encoder can consume. If None, max_sequence_length uses the value from sequence @@ -60,21 +56,13 @@ def __init__( hidden_size, num_heads, intermediate_dim, - intermediate_activiation="gelu", - initializer_range=0.02, dropout=0.1, max_sequence_length=512, num_segments=2, **kwargs, ): - # Create lambda functions from input params - intermediate_activiation_fn = keras.activations.get( - intermediate_activiation - ) - initializer_fn = keras.initializers.TruncatedNormal( - stddev=initializer_range - ) + initializer_fn = keras.initializers.TruncatedNormal(stddev=0.02) # Functional version of model token_id_input = keras.Input( @@ -126,7 +114,9 @@ def __init__( x = TransformerEncoder( num_heads=num_heads, intermediate_dim=intermediate_dim, - activation=intermediate_activiation_fn, + activation=lambda x: keras.activations.gelu( + x, approximate=True + ), dropout=dropout, kernel_initializer=initializer_fn, name="transformer/layer_%d" % i, @@ -155,7 +145,6 @@ def __init__( **kwargs, ) # All references to `self` below this line - self.intermediate_activiation_fn = intermediate_activiation_fn self.initializer_fn = initializer_fn self.vocab_size = vocab_size self.hidden_size = hidden_size @@ -164,10 +153,6 @@ def __init__( self.max_sequence_length = max_sequence_length self.num_segments = num_segments self.intermediate_dim = intermediate_dim - self.intermediate_activiation = keras.activations.get( - intermediate_activiation - ) - self.initializer_range = initializer_range self.dropout = dropout def get_embedding_table(self): @@ -184,11 +169,7 @@ def get_config(self): "max_sequence_length": self.max_sequence_length, "num_segments": self.num_segments, "intermediate_dim": self.intermediate_dim, - "intermediate_activiation": keras.activations.serialize( - self.intermediate_activiation - ), "dropout": self.dropout, - "initializer_range": self.initializer_range, } ) return config @@ -196,7 +177,7 @@ def get_config(self): class BertClassifier(keras.Model): """ - Classifier model with BertEncoder. + Adds a classification head to a Bert encoder model. Args: encoder: A `Bert` Model to encode inputs. 
@@ -247,8 +228,6 @@ def BertBase(**kwargs): hidden_size=768, num_heads=12, intermediate_dim=3072, - intermediate_activiation="gelu", - initializer_range=0.02, dropout=0.1, **kwargs, ) From 50774e9cd2acb9a29da1c3ddc5cf825ba9e06d92 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Thu, 11 Aug 2022 19:06:09 +0000 Subject: [PATCH 18/45] Do not exposed initializer and activation --- examples/bert/bert_config.py | 12 ------------ examples/bert/bert_train.py | 11 ++++++----- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/examples/bert/bert_config.py b/examples/bert/bert_config.py index cfed803d70..b4db8448c7 100644 --- a/examples/bert/bert_config.py +++ b/examples/bert/bert_config.py @@ -20,8 +20,6 @@ "dropout": 0.1, "num_heads": 2, "intermediate_dim": 512, - "intermediate_activiation": "gelu", - "initializer_range": 0.02, }, "mini": { "num_layers": 4, @@ -29,8 +27,6 @@ "dropout": 0.1, "num_heads": 4, "intermediate_dim": 1024, - "intermediate_activiation": "gelu", - "initializer_range": 0.02, }, "small": { "num_layers": 4, @@ -38,8 +34,6 @@ "dropout": 0.1, "num_heads": 8, "intermediate_dim": 2048, - "intermediate_activiation": "gelu", - "initializer_range": 0.02, }, "medium": { "num_layers": 8, @@ -47,8 +41,6 @@ "dropout": 0.1, "num_heads": 8, "intermediate_dim": 2048, - "intermediate_activiation": "gelu", - "initializer_range": 0.02, }, "base": { "num_layers": 12, @@ -56,8 +48,6 @@ "dropout": 0.1, "num_heads": 12, "intermediate_dim": 3072, - "intermediate_activiation": "gelu", - "initializer_range": 0.02, }, "large": { "num_layers": 24, @@ -65,8 +55,6 @@ "dropout": 0.1, "num_heads": 16, "intermediate_dim": 4096, - "intermediate_activiation": "gelu", - "initializer_range": 0.02, }, } diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index 8bc20ba8d5..8cf2200b67 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -105,7 +105,8 @@ class MaskedLMHead(keras.layers.Layer): Args: embedding_table: The embedding table from encoder network. - intermediate_activiation: The activation, if any, for the inner dense layer. + intermediate_activation: The activation, if any, for the inner dense + layer. initializer: The initializer for the dense layer. Defaults to a Glorot uniform initializer. output: The output style for this layer. 
Can be either 'logits' or @@ -115,14 +116,14 @@ class MaskedLMHead(keras.layers.Layer): def __init__( self, embedding_table, - intermediate_activiation="gelu", + intermediate_activation="gelu", initializer="glorot_uniform", **kwargs, ): super().__init__(**kwargs) self.embedding_table = embedding_table - self.intermediate_activiation = keras.activations.get( - intermediate_activiation + self.intermediate_activation = keras.activations.get( + intermediate_activation ) self.initializer = initializer @@ -130,7 +131,7 @@ def build(self, input_shape): self._vocab_size, hidden_size = self.embedding_table.shape self.dense = keras.layers.Dense( hidden_size, - activation=self.intermediate_activiation, + activation=self.intermediate_activation, kernel_initializer=self.initializer, name="transform/dense", ) From ba4b7797b9e7459ffe56f1edd801035b32de9b7e Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Thu, 11 Aug 2022 23:22:33 +0000 Subject: [PATCH 19/45] Move BertClassifier to functional API --- keras_nlp/models/bert.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index a1039417dd..a6262143c9 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -16,6 +16,7 @@ import tensorflow as tf from tensorflow import keras +# TODO(jbischof): fix import style from keras_nlp.layers import PositionEmbedding from keras_nlp.layers import TransformerEncoder @@ -137,7 +138,6 @@ def __init__( "segment_ids": segment_id_input, "input_mask": input_mask, }, - # TODO(jbischof): Consider list output outputs={ "sequence_output": sequence_output, "pooled_output": pooled_output, @@ -186,6 +186,7 @@ class BertClassifier(keras.Model): bias_initializer: Initializer for the bias vector. """ + # TODO(jbischof): decide how to set defaults from `num_segments` def __init__( self, encoder, @@ -194,21 +195,17 @@ def __init__( bias_initializer="zeros", **kwargs, ): - - super().__init__(**kwargs) - self.encoder = encoder - self.num_classes = num_classes - self._logit_layer = keras.layers.Dense( + inputs = encoder.input + pooled = encoder(inputs)["pooled_output"] + outputs = keras.layers.Dense( num_classes, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, name="logits", - ) - - def call(self, inputs): - # Ignore the sequence output, use the pooled output. - pooled_output = self.encoder(inputs)["pooled_output"] - return self._logit_layer(pooled_output) + )(pooled) + super().__init__(inputs=inputs, outputs=outputs, **kwargs) + self.encoder = encoder + self.num_classes = num_classes def BertBase(**kwargs): From 7a75f4c7b82fa40322ac2880c648bf2ebb2022ed Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Thu, 11 Aug 2022 23:33:47 +0000 Subject: [PATCH 20/45] Fix token embedding exposure in encoder --- keras_nlp/models/bert.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index a6262143c9..2f88d1acda 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -21,7 +21,6 @@ from keras_nlp.layers import TransformerEncoder CLS_INDEX = 0 -TOKEN_EMBEDDING_LAYER_NAME = "token_embedding" class Bert(keras.Model): @@ -79,11 +78,12 @@ def __init__( ) # Embed tokens, positions, and segment ids. 
- token_embedding = keras.layers.Embedding( + token_embedding_layer = keras.layers.Embedding( input_dim=vocab_size, output_dim=hidden_size, - name=TOKEN_EMBEDDING_LAYER_NAME, - )(token_id_input) + name="token_embedding", + ) + token_embedding = token_embedding_layer(token_id_input) position_embedding = PositionEmbedding( initializer=initializer_fn, sequence_length=max_sequence_length, @@ -145,6 +145,7 @@ def __init__( **kwargs, ) # All references to `self` below this line + self.token_embedding = token_embedding_layer self.initializer_fn = initializer_fn self.vocab_size = vocab_size self.hidden_size = hidden_size @@ -156,7 +157,7 @@ def __init__( self.dropout = dropout def get_embedding_table(self): - return self.get_layer(TOKEN_EMBEDDING_LAYER_NAME).embeddings + return self.token_embedding.embeddings def get_config(self): config = super().get_config() From 33e7d2133a715a5dffc71bf9f8c8a1292aa45fc4 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Fri, 12 Aug 2022 00:57:32 +0000 Subject: [PATCH 21/45] Fix imports in `examples/` --- examples/bert/bert_finetune_glue.py | 3 +-- examples/bert/bert_train.py | 14 +++++++------- keras_nlp/models/bert.py | 15 +++++++++------ 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/examples/bert/bert_finetune_glue.py b/examples/bert/bert_finetune_glue.py index ec82371203..7815fc1c19 100644 --- a/examples/bert/bert_finetune_glue.py +++ b/examples/bert/bert_finetune_glue.py @@ -25,7 +25,6 @@ import keras_nlp from examples.bert.bert_config import FINETUNING_CONFIG from examples.bert.bert_config import PREPROCESSING_CONFIG -from keras_nlp.models import BertClassifier FLAGS = flags.FLAGS @@ -108,7 +107,7 @@ class BertHyperModel(keras_tuner.HyperModel): def build(self, hp): model = keras.models.load_model(FLAGS.saved_model_input, compile=False) - finetuning_model = BertClassifier( + finetuning_model = keras_nlp.models.BertClassifier( encoder=model, num_classes=3 if FLAGS.task_name in ("mnli", "ax") else 2, ) diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index 8cf2200b67..e29d1f7f07 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -21,10 +21,10 @@ from absl import logging from tensorflow import keras +import keras_nlp from examples.bert.bert_config import MODEL_CONFIGS from examples.bert.bert_config import PREPROCESSING_CONFIG from examples.bert.bert_config import TRAINING_CONFIG -from keras_nlp.models import Bert FLAGS = flags.FLAGS @@ -91,7 +91,7 @@ class MaskedLMHead(keras.layers.Layer): - """Masked language model network head for BERT. + """Masked language model network head for Bert. This layer implements a masked language model based on the provided transformer based encoder. It assumes that the encoder network being passed @@ -99,8 +99,8 @@ class MaskedLMHead(keras.layers.Layer): Example: ```python - encoder=modeling.networks.BertEncoder(...) - lm_layer=MaskedLMHead(embedding_table=encoder.get_embedding_table()) + encoder = keras_nlp.models.BertBase() + lm_layer = MaskedLMHead(embedding_table=encoder.get_embedding_table()) ``` Args: @@ -193,7 +193,7 @@ def _gather_indexes(self, sequence_tensor, positions): class BertPretrainingModel(keras.Model): - """MLM + NSP model with BertEncoder.""" + """MLM + NSP model with Bert encoder.""" def __init__(self, encoder, **kwargs): super().__init__(**kwargs) @@ -406,8 +406,8 @@ def main(_): dataset = dataset.repeat() with strategy.scope(): - # Create a BERT model the input config. 
- encoder = Bert(vocab_size=len(vocab), **model_config) + # Create a Bert model the input config. + encoder = keras_nlp.models.Bert(vocab_size=len(vocab), **model_config) # Make sure model has been called. encoder(encoder.inputs) encoder.summary() diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 2f88d1acda..6bf90d2945 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -49,6 +49,8 @@ class Bert(keras.Model): take. """ + # TODO(bischof): add tests + def __init__( self, vocab_size, @@ -71,8 +73,6 @@ def __init__( segment_id_input = keras.Input( shape=(None,), dtype="int32", name="segment_ids" ) - # TODO(jbischof): improve handling of masking following - # https://www.tensorflow.org/guide/keras/masking_and_padding input_mask = keras.Input( shape=(None,), dtype="int32", name="input_mask" ) @@ -97,7 +97,7 @@ def __init__( # Sum, normailze and apply dropout to embeddings. x = keras.layers.Add( - name="embedding_sum", + name="embeddings/sum", )((token_embedding, position_embedding, segment_embedding)) x = keras.layers.LayerNormalization( name="embeddings/layer_norm", @@ -107,7 +107,7 @@ def __init__( )(x) x = keras.layers.Dropout( dropout, - name="embedding_dropout", + name="embeddings/dropout", )(x) # Apply successive transformer encoder blocks. @@ -123,7 +123,8 @@ def __init__( name="transformer/layer_%d" % i, )(x, padding_mask=input_mask) - # Construct the two BERT outputs, and apply a dense to the pooled output. + # Construct the two BERT outputs. The pooled output is a dense layer on + # top of the [CLS] token. sequence_output = x pooled_output = keras.layers.Dense( hidden_size, @@ -204,14 +205,16 @@ def __init__( bias_initializer=bias_initializer, name="logits", )(pooled) + # Instantiate using Functional API Model constructor super().__init__(inputs=inputs, outputs=outputs, **kwargs) + # All references to `self` below this line self.encoder = encoder self.num_classes = num_classes def BertBase(**kwargs): """ - Factory for BertEncoder using "Base" architecture. + Factory for Bert using "Base" architecture. This network implements a bi-directional Transformer-based encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers for From 7ef349cac70aa6373f77cae4918fd3a57a730897 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Fri, 12 Aug 2022 01:00:17 +0000 Subject: [PATCH 22/45] remove TODO --- keras_nlp/models/bert.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 6bf90d2945..76363de715 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -16,7 +16,6 @@ import tensorflow as tf from tensorflow import keras -# TODO(jbischof): fix import style from keras_nlp.layers import PositionEmbedding from keras_nlp.layers import TransformerEncoder From c59626732b8f7dc079305321e41fcb872795aae8 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Fri, 12 Aug 2022 03:27:28 +0000 Subject: [PATCH 23/45] Add docstring test for Bert encoder --- keras_nlp/layers/transformer_encoder.py | 1 - keras_nlp/models/bert.py | 27 ++++++++++++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/keras_nlp/layers/transformer_encoder.py b/keras_nlp/layers/transformer_encoder.py index c67409d911..08503447ba 100644 --- a/keras_nlp/layers/transformer_encoder.py +++ b/keras_nlp/layers/transformer_encoder.py @@ -68,7 +68,6 @@ class TransformerEncoder(keras.layers.Layer): # Call encoder on the inputs. 
input_data = tf.random.uniform(shape=[2, 10, 64]) output = model(input_data) - ``` References: diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 76363de715..7b3082a649 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -45,7 +45,32 @@ class Bert(keras.Model): length. This determines the variable shape for positional embeddings. num_segments: The number of types that the 'segment_ids' input can - take. + take. + + Example: + ```python + # Randomly initialized Bert encoder + encoder = keras_nlp.models.Bert( + vocab_size=30522, + num_layers=12, + hidden_size=768, + num_heads=12, + intermediate_dim=3072, + dropout=0.1, + max_sequence_length=12 + ) + + # Call encoder on the inputs. + input_data = { + "input_ids": tf.random.uniform( + shape=(1, 12), dtype=tf.int64, maxval=30522), + "segment_ids": tf.constant( + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)), + "input_mask": tf.constant( + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)), + } + output = encoder(input_data) + ``` """ # TODO(bischof): add tests From 52d08a04830b0fc293c184ebe34708f7b1ec2510 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Fri, 12 Aug 2022 03:30:51 +0000 Subject: [PATCH 24/45] Format --- keras_nlp/models/bert.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 7b3082a649..a67ea28eee 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -45,8 +45,8 @@ class Bert(keras.Model): length. This determines the variable shape for positional embeddings. num_segments: The number of types that the 'segment_ids' input can - take. - + take. + Example: ```python # Randomly initialized Bert encoder From 3cbf70960a5cf5bbee0cba0dd28c6135897d1e25 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Fri, 12 Aug 2022 03:40:22 +0000 Subject: [PATCH 25/45] Add docstring test for classifier --- keras_nlp/models/bert.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index a67ea28eee..f67ccae4e8 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -47,7 +47,7 @@ class Bert(keras.Model): num_segments: The number of types that the 'segment_ids' input can take. - Example: + Example usage: ```python # Randomly initialized Bert encoder encoder = keras_nlp.models.Bert( @@ -210,6 +210,30 @@ class BertClassifier(keras.Model): num_classes: Number of classes to predict. kernel_initializer: Initializer for the `kernel` weights matrix. bias_initializer: Initializer for the bias vector. + + Example usage: + # Randomly initialized Bert encoder + encoder = keras_nlp.models.Bert( + vocab_size=30522, + num_layers=12, + hidden_size=768, + num_heads=12, + intermediate_dim=3072, + dropout=0.1, + max_sequence_length=12 + ) + + # Call classifier on the inputs. 
+ input_data = { + "input_ids": tf.random.uniform( + shape=(1, 12), dtype=tf.int64, maxval=30522), + "segment_ids": tf.constant( + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)), + "input_mask": tf.constant( + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)), + } + classifier = keras_nlp.models.BertClassifier(encoder, 4) + logits = classifier(input_data) """ # TODO(jbischof): decide how to set defaults from `num_segments` From b2d605369b67eea51e0fbdac6445f3395f62a05b Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Fri, 12 Aug 2022 03:40:34 +0000 Subject: [PATCH 26/45] Format --- keras_nlp/models/bert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index f67ccae4e8..23f6917144 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -210,7 +210,7 @@ class BertClassifier(keras.Model): num_classes: Number of classes to predict. kernel_initializer: Initializer for the `kernel` weights matrix. bias_initializer: Initializer for the bias vector. - + Example usage: # Randomly initialized Bert encoder encoder = keras_nlp.models.Bert( From 5fc84edce184787fb23dba7eb2237787034e816d Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Fri, 12 Aug 2022 03:41:55 +0000 Subject: [PATCH 27/45] Set `max_sequence_length` for BertBase --- keras_nlp/models/bert.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 23f6917144..0a6c10defe 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -278,6 +278,7 @@ def BertBase(**kwargs): num_heads=12, intermediate_dim=3072, dropout=0.1, + max_sequence_length=512, **kwargs, ) From a4c7f8d647687b7d0b348ff9fc4bbe30c435998b Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Fri, 12 Aug 2022 03:50:44 +0000 Subject: [PATCH 28/45] Move TODO --- keras_nlp/models/bert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 0a6c10defe..d6e93fd054 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -236,7 +236,6 @@ class BertClassifier(keras.Model): logits = classifier(input_data) """ - # TODO(jbischof): decide how to set defaults from `num_segments` def __init__( self, encoder, @@ -271,6 +270,7 @@ def BertBase(**kwargs): or classification task networks. 
""" + # TODO(jbischof): decide how to set defaults from `num_segments` model = Bert( vocab_size=30522, num_layers=12, From 94f81b2bfb1d93b42f429efdd807f3d676399451 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Fri, 12 Aug 2022 19:21:39 +0000 Subject: [PATCH 29/45] Standarize initializers to match Bert paper --- examples/bert/bert_train.py | 4 ++-- keras_nlp/models/bert.py | 29 ++++++++++++++++------------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index e29d1f7f07..1aaf20b421 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -201,11 +201,11 @@ def __init__(self, encoder, **kwargs): # TODO(jbischof): replace with keras_nlp.layers.MLMHead (Issue #166) self.masked_lm_head = MaskedLMHead( embedding_table=encoder.get_embedding_table(), - initializer=encoder.initializer_fn, + initializer=keras.initializers.TruncatedNormal(stddev=0.02), ) self.next_sentence_head = keras.layers.Dense( encoder.num_segments, - kernel_initializer=encoder.initializer_fn, + kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02), ) self.loss_tracker = keras.metrics.Mean(name="loss") self.lm_loss_tracker = keras.metrics.Mean(name="lm_loss") diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index d6e93fd054..210ee24d4b 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -22,6 +22,10 @@ CLS_INDEX = 0 +def _bert_kernel_initializer(stddev=0.02): + return keras.initializers.TruncatedNormal(stddev=stddev) + + class Bert(keras.Model): """Bi-directional Transformer-based encoder network. @@ -88,8 +92,6 @@ def __init__( **kwargs, ): - initializer_fn = keras.initializers.TruncatedNormal(stddev=0.02) - # Functional version of model token_id_input = keras.Input( shape=(None,), dtype="int32", name="input_ids" @@ -105,17 +107,19 @@ def __init__( token_embedding_layer = keras.layers.Embedding( input_dim=vocab_size, output_dim=hidden_size, + embeddings_initializer=_bert_kernel_initializer(), name="token_embedding", ) token_embedding = token_embedding_layer(token_id_input) position_embedding = PositionEmbedding( - initializer=initializer_fn, + initializer=_bert_kernel_initializer(), sequence_length=max_sequence_length, name="position_embedding", )(token_embedding) segment_embedding = keras.layers.Embedding( input_dim=num_segments, output_dim=hidden_size, + embeddings_initializer=_bert_kernel_initializer(), name="segment_embedding", )(segment_id_input) @@ -143,7 +147,7 @@ def __init__( x, approximate=True ), dropout=dropout, - kernel_initializer=initializer_fn, + kernel_initializer=_bert_kernel_initializer(), name="transformer/layer_%d" % i, )(x, padding_mask=input_mask) @@ -152,6 +156,7 @@ def __init__( sequence_output = x pooled_output = keras.layers.Dense( hidden_size, + kernel_initializer=_bert_kernel_initializer(), activation="tanh", name="pooled_dense", )(x[:, CLS_INDEX, :]) @@ -171,7 +176,6 @@ def __init__( ) # All references to `self` below this line self.token_embedding = token_embedding_layer - self.initializer_fn = initializer_fn self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_layers = num_layers @@ -208,9 +212,9 @@ class BertClassifier(keras.Model): Args: encoder: A `Bert` Model to encode inputs. num_classes: Number of classes to predict. - kernel_initializer: Initializer for the `kernel` weights matrix. - bias_initializer: Initializer for the bias vector. 
+ ``` + python Example usage: # Randomly initialized Bert encoder encoder = keras_nlp.models.Bert( @@ -234,22 +238,20 @@ class BertClassifier(keras.Model): } classifier = keras_nlp.models.BertClassifier(encoder, 4) logits = classifier(input_data) + ``` """ def __init__( self, encoder, num_classes, - kernel_initializer="glorot_uniform", - bias_initializer="zeros", **kwargs, ): inputs = encoder.input pooled = encoder(inputs)["pooled_output"] outputs = keras.layers.Dense( num_classes, - kernel_initializer=kernel_initializer, - bias_initializer=bias_initializer, + kernel_initializer=_bert_kernel_initializer(), name="logits", )(pooled) # Instantiate using Functional API Model constructor @@ -261,7 +263,8 @@ def __init__( def BertBase(**kwargs): """ - Factory for Bert using "Base" architecture. + Bi-directional Transformer-based encoder network (Bert) using "Base" + architecture. This network implements a bi-directional Transformer-based encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers for @@ -270,7 +273,6 @@ def BertBase(**kwargs): or classification task networks. """ - # TODO(jbischof): decide how to set defaults from `num_segments` model = Bert( vocab_size=30522, num_layers=12, @@ -279,6 +281,7 @@ def BertBase(**kwargs): intermediate_dim=3072, dropout=0.1, max_sequence_length=512, + num_segments=2, **kwargs, ) From b99a84c3cc04b405b62d2c1d551709297f597936 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Fri, 12 Aug 2022 22:13:00 +0000 Subject: [PATCH 30/45] Respond to minor comments --- examples/bert/bert_train.py | 6 +++-- keras_nlp/models/bert.py | 54 +++++++++++++++++-------------------- 2 files changed, 28 insertions(+), 32 deletions(-) diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index 1aaf20b421..55eadbede2 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -200,7 +200,7 @@ def __init__(self, encoder, **kwargs): self.encoder = encoder # TODO(jbischof): replace with keras_nlp.layers.MLMHead (Issue #166) self.masked_lm_head = MaskedLMHead( - embedding_table=encoder.get_embedding_table(), + embedding_table=encoder.token_embedding.embeddings, initializer=keras.initializers.TruncatedNormal(stddev=0.02), ) self.next_sentence_head = keras.layers.Dense( @@ -407,7 +407,9 @@ def main(_): with strategy.scope(): # Create a Bert model the input config. - encoder = keras_nlp.models.Bert(vocab_size=len(vocab), **model_config) + encoder = keras_nlp.models.Bert( + vocabulary_size=len(vocab), **model_config + ) # Make sure model has been called. encoder(encoder.inputs) encoder.summary() diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 210ee24d4b..08fdb66c6c 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -19,8 +19,6 @@ from keras_nlp.layers import PositionEmbedding from keras_nlp.layers import TransformerEncoder -CLS_INDEX = 0 - def _bert_kernel_initializer(stddev=0.02): return keras.initializers.TruncatedNormal(stddev=stddev) @@ -36,7 +34,7 @@ class Bert(keras.Model): or classification task networks. Args: - vocab_size: The size of the token vocabulary. + vocabulary_size: The size of the token vocabulary. num_layers: The number of transformer layers. hidden_size: The size of the transformer hidden layers. num_heads: The number of attention heads for each transformer. @@ -50,17 +48,17 @@ class Bert(keras.Model): embeddings. num_segments: The number of types that the 'segment_ids' input can take. + cls_token_index: Index of [CLS] token in the vocabulary. 
Example usage: ```python # Randomly initialized Bert encoder encoder = keras_nlp.models.Bert( - vocab_size=30522, + vocabulary_size=30522, num_layers=12, hidden_size=768, num_heads=12, intermediate_dim=3072, - dropout=0.1, max_sequence_length=12 ) @@ -81,7 +79,7 @@ class Bert(keras.Model): def __init__( self, - vocab_size, + vocabulary_size, num_layers, hidden_size, num_heads, @@ -89,10 +87,10 @@ def __init__( dropout=0.1, max_sequence_length=512, num_segments=2, + cls_token_index=0, **kwargs, ): - # Functional version of model token_id_input = keras.Input( shape=(None,), dtype="int32", name="input_ids" ) @@ -105,7 +103,7 @@ def __init__( # Embed tokens, positions, and segment ids. token_embedding_layer = keras.layers.Embedding( - input_dim=vocab_size, + input_dim=vocabulary_size, output_dim=hidden_size, embeddings_initializer=_bert_kernel_initializer(), name="token_embedding", @@ -124,18 +122,18 @@ def __init__( )(segment_id_input) # Sum, normailze and apply dropout to embeddings. - x = keras.layers.Add( - name="embeddings/sum", - )((token_embedding, position_embedding, segment_embedding)) + x = keras.layers.Add()( + (token_embedding, position_embedding, segment_embedding) + ) x = keras.layers.LayerNormalization( - name="embeddings/layer_norm", + name="embeddings_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32, )(x) x = keras.layers.Dropout( dropout, - name="embeddings/dropout", + name="embeddings_dropout", )(x) # Apply successive transformer encoder blocks. @@ -148,10 +146,10 @@ def __init__( ), dropout=dropout, kernel_initializer=_bert_kernel_initializer(), - name="transformer/layer_%d" % i, + name=f"""transformer_layer_{i}""", )(x, padding_mask=input_mask) - # Construct the two BERT outputs. The pooled output is a dense layer on + # Construct the two Bert outputs. The pooled output is a dense layer on # top of the [CLS] token. sequence_output = x pooled_output = keras.layers.Dense( @@ -159,7 +157,7 @@ def __init__( kernel_initializer=_bert_kernel_initializer(), activation="tanh", name="pooled_dense", - )(x[:, CLS_INDEX, :]) + )(x[:, cls_token_index, :]) # Instantiate using Functional API Model constructor super().__init__( @@ -176,7 +174,7 @@ def __init__( ) # All references to `self` below this line self.token_embedding = token_embedding_layer - self.vocab_size = vocab_size + self.vocabulary_size = vocabulary_size self.hidden_size = hidden_size self.num_layers = num_layers self.num_heads = num_heads @@ -184,15 +182,13 @@ def __init__( self.num_segments = num_segments self.intermediate_dim = intermediate_dim self.dropout = dropout - - def get_embedding_table(self): - return self.token_embedding.embeddings + self.cls_token_index = cls_token_index def get_config(self): config = super().get_config() config.update( { - "vocab_size": self.vocab_size, + "vocabulary_size": self.vocabulary_size, "hidden_size": self.hidden_size, "num_layers": self.num_layers, "num_heads": self.num_heads, @@ -200,30 +196,29 @@ def get_config(self): "num_segments": self.num_segments, "intermediate_dim": self.intermediate_dim, "dropout": self.dropout, + "cls_token_index": self.cls_token_index, } ) return config class BertClassifier(keras.Model): - """ - Adds a classification head to a Bert encoder model. + """Bert encoder model with a classification head. Args: - encoder: A `Bert` Model to encode inputs. + encoder: A `keras_nlp.models.Bert` to encode inputs. num_classes: Number of classes to predict. 
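Because `BertClassifier` is assembled with the functional `keras.Model` constructor, the usual compile and fit workflow applies. A minimal fine-tuning sketch follows, assuming `encoder` is built as in the usage example below; the optimizer settings and dataset name are assumptions, not part of the patch.

```python
# Illustrative sketch only, not part of the patch. `encoder` is assumed to be
# a `keras_nlp.models.Bert` built as in the usage example below; the learning
# rate and dataset are placeholders.
from tensorflow import keras

classifier = keras_nlp.models.BertClassifier(encoder, num_classes=4)
classifier.compile(
    optimizer=keras.optimizers.Adam(learning_rate=5e-5),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
# classifier.fit(finetuning_ds, epochs=3)  # dataset assumed to exist elsewhere
```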
+ Example usage: ``` python - Example usage: # Randomly initialized Bert encoder encoder = keras_nlp.models.Bert( - vocab_size=30522, + vocabulary_size=30522, num_layers=12, hidden_size=768, num_heads=12, intermediate_dim=3072, - dropout=0.1, max_sequence_length=12 ) @@ -262,8 +257,7 @@ def __init__( def BertBase(**kwargs): - """ - Bi-directional Transformer-based encoder network (Bert) using "Base" + """Bi-directional Transformer-based encoder network (Bert) using "Base" architecture. This network implements a bi-directional Transformer-based encoder as @@ -274,7 +268,7 @@ def BertBase(**kwargs): """ model = Bert( - vocab_size=30522, + vocabulary_size=30522, num_layers=12, hidden_size=768, num_heads=12, From a8a6dc862c6c0ebb0e5fe4038a92a6dac16527a0 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Fri, 12 Aug 2022 22:33:53 +0000 Subject: [PATCH 31/45] More minor comment fixes --- examples/bert/bert_config.py | 12 ++++---- examples/bert/bert_finetune_glue.py | 2 +- examples/bert/bert_train.py | 8 +++--- keras_nlp/layers/sine_position_encoding.py | 2 ++ keras_nlp/models/bert.py | 32 ++++++++++++---------- 5 files changed, 30 insertions(+), 26 deletions(-) diff --git a/examples/bert/bert_config.py b/examples/bert/bert_config.py index b4db8448c7..6a0b890209 100644 --- a/examples/bert/bert_config.py +++ b/examples/bert/bert_config.py @@ -16,42 +16,42 @@ MODEL_CONFIGS = { "tiny": { "num_layers": 2, - "hidden_size": 128, + "hidden_dim": 128, "dropout": 0.1, "num_heads": 2, "intermediate_dim": 512, }, "mini": { "num_layers": 4, - "hidden_size": 256, + "hidden_dim": 256, "dropout": 0.1, "num_heads": 4, "intermediate_dim": 1024, }, "small": { "num_layers": 4, - "hidden_size": 512, + "hidden_dim": 512, "dropout": 0.1, "num_heads": 8, "intermediate_dim": 2048, }, "medium": { "num_layers": 8, - "hidden_size": 512, + "hidden_dim": 512, "dropout": 0.1, "num_heads": 8, "intermediate_dim": 2048, }, "base": { "num_layers": 12, - "hidden_size": 768, + "hidden_dim": 768, "dropout": 0.1, "num_heads": 12, "intermediate_dim": 3072, }, "large": { "num_layers": 24, - "hidden_size": 1024, + "hidden_dim": 1024, "dropout": 0.1, "num_heads": 16, "intermediate_dim": 4096, diff --git a/examples/bert/bert_finetune_glue.py b/examples/bert/bert_finetune_glue.py index 7815fc1c19..6b1beec6e9 100644 --- a/examples/bert/bert_finetune_glue.py +++ b/examples/bert/bert_finetune_glue.py @@ -108,7 +108,7 @@ class BertHyperModel(keras_tuner.HyperModel): def build(self, hp): model = keras.models.load_model(FLAGS.saved_model_input, compile=False) finetuning_model = keras_nlp.models.BertClassifier( - encoder=model, + base_model=model, num_classes=3 if FLAGS.task_name in ("mnli", "ax") else 2, ) finetuning_model.compile( diff --git a/examples/bert/bert_train.py b/examples/bert/bert_train.py index 55eadbede2..ff5480f3ba 100644 --- a/examples/bert/bert_train.py +++ b/examples/bert/bert_train.py @@ -128,9 +128,9 @@ def __init__( self.initializer = initializer def build(self, input_shape): - self._vocab_size, hidden_size = self.embedding_table.shape + self._vocab_size, hidden_dim = self.embedding_table.shape self.dense = keras.layers.Dense( - hidden_size, + hidden_dim, activation=self.intermediate_activation, kernel_initializer=self.initializer, name="transform/dense", @@ -165,7 +165,7 @@ def _gather_indexes(self, sequence_tensor, positions): Args: sequence_tensor: Sequence output of shape - (`batch_size`, `seq_length`, `hidden_size`) where `hidden_size` + (`batch_size`, `seq_length`, `hidden_dim`) where `hidden_dim` is number of 
hidden units. positions: Positions ids of tokens in sequence to mask for pretraining of with dimension (batch_size, num_predictions) @@ -174,7 +174,7 @@ def _gather_indexes(self, sequence_tensor, positions): Returns: Masked out sequence tensor of shape (batch_size * num_predictions, - `hidden_size`). + `hidden_dim`). """ sequence_shape = tf.shape(sequence_tensor) batch_size, seq_length = sequence_shape[0], sequence_shape[1] diff --git a/keras_nlp/layers/sine_position_encoding.py b/keras_nlp/layers/sine_position_encoding.py index 0ce8ec3c5f..e3aca37838 100644 --- a/keras_nlp/layers/sine_position_encoding.py +++ b/keras_nlp/layers/sine_position_encoding.py @@ -62,6 +62,8 @@ def __init__( self.max_wavelength = max_wavelength def call(self, inputs): + # TODO(jbischof): replace `hidden_size` with`hidden_dim` for consistency + # with other layers. input_shape = tf.shape(inputs) # length of sequence is the second last dimension of the inputs seq_length = input_shape[-2] diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 08fdb66c6c..5dd9cdcb72 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -36,9 +36,9 @@ class Bert(keras.Model): Args: vocabulary_size: The size of the token vocabulary. num_layers: The number of transformer layers. - hidden_size: The size of the transformer hidden layers. num_heads: The number of attention heads for each transformer. The hidden size must be divisible by the number of attention heads. + hidden_dim: The size of the transformer hidden layers. intermediate_dim: The output dimension of the first Dense layer in a two-layer feedforward network for each transformer. dropout: Dropout probability for the Transformer encoder. @@ -56,8 +56,8 @@ class Bert(keras.Model): encoder = keras_nlp.models.Bert( vocabulary_size=30522, num_layers=12, - hidden_size=768, num_heads=12, + hidden_dim=768, intermediate_dim=3072, max_sequence_length=12 ) @@ -81,8 +81,8 @@ def __init__( self, vocabulary_size, num_layers, - hidden_size, num_heads, + hidden_dim, intermediate_dim, dropout=0.1, max_sequence_length=512, @@ -104,7 +104,7 @@ def __init__( # Embed tokens, positions, and segment ids. token_embedding_layer = keras.layers.Embedding( input_dim=vocabulary_size, - output_dim=hidden_size, + output_dim=hidden_dim, embeddings_initializer=_bert_kernel_initializer(), name="token_embedding", ) @@ -116,7 +116,7 @@ def __init__( )(token_embedding) segment_embedding = keras.layers.Embedding( input_dim=num_segments, - output_dim=hidden_size, + output_dim=hidden_dim, embeddings_initializer=_bert_kernel_initializer(), name="segment_embedding", )(segment_id_input) @@ -153,7 +153,7 @@ def __init__( # top of the [CLS] token. 
sequence_output = x pooled_output = keras.layers.Dense( - hidden_size, + hidden_dim, kernel_initializer=_bert_kernel_initializer(), activation="tanh", name="pooled_dense", @@ -175,7 +175,8 @@ def __init__( # All references to `self` below this line self.token_embedding = token_embedding_layer self.vocabulary_size = vocabulary_size - self.hidden_size = hidden_size + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim self.num_layers = num_layers self.num_heads = num_heads self.max_sequence_length = max_sequence_length @@ -189,7 +190,8 @@ def get_config(self): config.update( { "vocabulary_size": self.vocabulary_size, - "hidden_size": self.hidden_size, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, "num_layers": self.num_layers, "num_heads": self.num_heads, "max_sequence_length": self.max_sequence_length, @@ -206,7 +208,7 @@ class BertClassifier(keras.Model): """Bert encoder model with a classification head. Args: - encoder: A `keras_nlp.models.Bert` to encode inputs. + base_model: A `keras_nlp.models.Bert` to encode inputs. num_classes: Number of classes to predict. Example usage: @@ -216,8 +218,8 @@ class BertClassifier(keras.Model): encoder = keras_nlp.models.Bert( vocabulary_size=30522, num_layers=12, - hidden_size=768, num_heads=12, + hidden_dim=768, intermediate_dim=3072, max_sequence_length=12 ) @@ -238,12 +240,12 @@ class BertClassifier(keras.Model): def __init__( self, - encoder, + base_model, num_classes, **kwargs, ): - inputs = encoder.input - pooled = encoder(inputs)["pooled_output"] + inputs = base_model.input + pooled = base_model(inputs)["pooled_output"] outputs = keras.layers.Dense( num_classes, kernel_initializer=_bert_kernel_initializer(), @@ -252,7 +254,7 @@ def __init__( # Instantiate using Functional API Model constructor super().__init__(inputs=inputs, outputs=outputs, **kwargs) # All references to `self` below this line - self.encoder = encoder + self.base_model = base_model self.num_classes = num_classes @@ -270,8 +272,8 @@ def BertBase(**kwargs): model = Bert( vocabulary_size=30522, num_layers=12, - hidden_size=768, num_heads=12, + hidden_dim=768, intermediate_dim=3072, dropout=0.1, max_sequence_length=512, From 4053bcb8cbbd99578ae18f448d1d0029adc03b07 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Fri, 12 Aug 2022 22:46:38 +0000 Subject: [PATCH 32/45] Format fix --- keras_nlp/models/bert.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 5dd9cdcb72..0c086340bf 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -48,7 +48,7 @@ class Bert(keras.Model): embeddings. num_segments: The number of types that the 'segment_ids' input can take. - cls_token_index: Index of [CLS] token in the vocabulary. + cls_token_index: Index of [CLS] token in the vocabulary. 
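The `cls_token_index` argument documented above drives the pooled output: the encoder slices the sequence output at the [CLS] position and passes it through a tanh dense layer, as the `pooled_dense` code earlier in this series shows. A standalone toy sketch of that computation, not part of the patch:

```python
# Illustrative sketch only, not part of the patch: what "pooled_output" is,
# assuming toy shapes and cls_token_index=0.
import tensorflow as tf
from tensorflow import keras

batch_size, seq_length, hidden_dim = 2, 12, 8
sequence_output = tf.random.normal((batch_size, seq_length, hidden_dim))

cls_token_index = 0
pooled_dense = keras.layers.Dense(hidden_dim, activation="tanh")
pooled_output = pooled_dense(sequence_output[:, cls_token_index, :])
print(pooled_output.shape)  # (2, 8): one vector per example, fed to task heads
```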
Example usage: ```python @@ -196,7 +196,6 @@ def get_config(self): "num_heads": self.num_heads, "max_sequence_length": self.max_sequence_length, "num_segments": self.num_segments, - "intermediate_dim": self.intermediate_dim, "dropout": self.dropout, "cls_token_index": self.cls_token_index, } From 07753077b5b3bde36c4d6d8462ffa92d55ecbe78 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Mon, 15 Aug 2022 21:13:12 +0000 Subject: [PATCH 33/45] Improve documentation --- keras_nlp/models/bert.py | 58 +++++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 15 deletions(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 0c086340bf..1655c3b67b 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -15,6 +15,7 @@ import tensorflow as tf from tensorflow import keras +from absl import logging from keras_nlp.layers import PositionEmbedding from keras_nlp.layers import TransformerEncoder @@ -28,11 +29,15 @@ class Bert(keras.Model): """Bi-directional Transformer-based encoder network. This network implements a bi-directional Transformer-based encoder as - described in "BERT: Pre-training of Deep Bidirectional Transformers for - Language Understanding" (https://arxiv.org/abs/1810.04805). It includes the + described in ["BERT: Pre-training of Deep Bidirectional Transformers for + Language Understanding"](https://arxiv.org/abs/1810.04805). It includes the embedding lookups and transformer layers, but not the masked language model or classification task networks. + This class gives a fully configurable Bert model with any number of layers, + heads, and embedding dimensions. For specific specific bert architectures + defined in the paper, see for example `keras_nlp.models.BertBase`. + Args: vocabulary_size: The size of the token vocabulary. num_layers: The number of transformer layers. @@ -262,23 +267,46 @@ def BertBase(**kwargs): architecture. This network implements a bi-directional Transformer-based encoder as - described in "BERT: Pre-training of Deep Bidirectional Transformers for - Language Understanding" (https://arxiv.org/abs/1810.04805). It includes the + described in ["BERT: Pre-training of Deep Bidirectional Transformers for + Language Understanding"](https://arxiv.org/abs/1810.04805). It includes the embedding lookups and transformer layers, but not the masked language model or classification task networks. + + Example usage: + ```python + # Randomly initialized BertBase encoder + encoder = keras_nlp.models.BertBase() + + # Call encoder on the inputs. 
+ input_data = { + "input_ids": tf.random.uniform( + shape=(1, 512), dtype=tf.int64, maxval=encoder.vocabulary_size), + "segment_ids": tf.constant( + [0] * 200 + [1] * 312, shape=(1, 512)), + "input_mask": tf.constant( + [1] * 512, shape=(1, 512)), + } + output = encoder(input_data) """ - model = Bert( - vocabulary_size=30522, - num_layers=12, - num_heads=12, - hidden_dim=768, - intermediate_dim=3072, - dropout=0.1, - max_sequence_length=512, - num_segments=2, - **kwargs, - ) + base_args = { + "vocabulary_size": 30522, + "num_layers": 12, + "num_heads": 12, + "hidden_dim": 768, + "intermediate_dim": 3072, + "dropout": 0.1, + "max_sequence_length": 512, + "num_segments": 2, + } + + for arg in kwargs: + if arg in base_args: + logging.error( + f"""`{arg}` fixed to {base_args[arg]} in BertBase and cannot """ + f"""be changed.""") + + model = Bert({**base_args, **kwargs}) # TODO(jbischof): add some documentation or magic to load our checkpoints # TODO(jbischof): attach the tokenizer From 8fde61650a31a996e8f9fc96ab6824ceb39da88f Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Mon, 15 Aug 2022 21:15:03 +0000 Subject: [PATCH 34/45] Tiny fix --- keras_nlp/models/bert.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 1655c3b67b..f93a9c20a7 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -287,6 +287,7 @@ def BertBase(**kwargs): [1] * 512, shape=(1, 512)), } output = encoder(input_data) + ``` """ base_args = { From f3751753657a192de17f920f2bbc52c6d7f40e0f Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Mon, 15 Aug 2022 21:15:25 +0000 Subject: [PATCH 35/45] Tiny fix --- keras_nlp/models/bert.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index f93a9c20a7..b7ebd6e2e0 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -14,8 +14,8 @@ """Bert model and layer implementations.""" import tensorflow as tf -from tensorflow import keras from absl import logging +from tensorflow import keras from keras_nlp.layers import PositionEmbedding from keras_nlp.layers import TransformerEncoder @@ -35,7 +35,7 @@ class Bert(keras.Model): or classification task networks. This class gives a fully configurable Bert model with any number of layers, - heads, and embedding dimensions. For specific specific bert architectures + heads, and embedding dimensions. For specific specific bert architectures defined in the paper, see for example `keras_nlp.models.BertBase`. Args: @@ -305,7 +305,8 @@ def BertBase(**kwargs): if arg in base_args: logging.error( f"""`{arg}` fixed to {base_args[arg]} in BertBase and cannot """ - f"""be changed.""") + f"""be changed.""" + ) model = Bert({**base_args, **kwargs}) From 0d8d1d58d0d08da0f5696e4bc6c547be3b422531 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Mon, 15 Aug 2022 21:40:37 +0000 Subject: [PATCH 36/45] Clarifying comments in `dim` args --- keras_nlp/models/bert.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index b7ebd6e2e0..99511657f0 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -43,7 +43,7 @@ class Bert(keras.Model): num_layers: The number of transformer layers. num_heads: The number of attention heads for each transformer. The hidden size must be divisible by the number of attention heads. - hidden_dim: The size of the transformer hidden layers. 
+ hidden_dim: The size of the transformer encoding and pooler layers. intermediate_dim: The output dimension of the first Dense layer in a two-layer feedforward network for each transformer. dropout: Dropout probability for the Transformer encoder. @@ -80,7 +80,9 @@ class Bert(keras.Model): ``` """ - # TODO(bischof): add tests + # TODO(jbischof): add tests + # TODO(jbischof): consider changing `intermediate_dim` to less confusing + # name here and in TransformerEncoder def __init__( self, From 7b34f1a9717efc4ef71d37dc27c3f958f708e156 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Mon, 15 Aug 2022 21:41:48 +0000 Subject: [PATCH 37/45] Remove unnecessary comment --- keras_nlp/models/bert.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 99511657f0..06d787a637 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Bert model and layer implementations.""" import tensorflow as tf from absl import logging From ae000032317f1085a4c6cd41b2739aec682fcae4 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Mon, 15 Aug 2022 21:48:46 +0000 Subject: [PATCH 38/45] Add typehints in the comments. --- keras_nlp/models/bert.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 06d787a637..7271cba510 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -38,21 +38,21 @@ class Bert(keras.Model): defined in the paper, see for example `keras_nlp.models.BertBase`. Args: - vocabulary_size: The size of the token vocabulary. - num_layers: The number of transformer layers. - num_heads: The number of attention heads for each transformer. + vocabulary_size: Int. The size of the token vocabulary. + num_layers: Int. The number of transformer layers. + num_heads: Int. The number of attention heads for each transformer. The hidden size must be divisible by the number of attention heads. - hidden_dim: The size of the transformer encoding and pooler layers. - intermediate_dim: The output dimension of the first Dense layer in a - two-layer feedforward network for each transformer. - dropout: Dropout probability for the Transformer encoder. - max_sequence_length: The maximum sequence length that this encoder can - consume. If None, max_sequence_length uses the value from sequence - length. This determines the variable shape for positional + hidden_dim: Int. The size of the transformer encoding and pooler layers. + intermediate_dim: Int. The output dimension of the first Dense layer in + a two-layer feedforward network for each transformer. + dropout: Float. Dropout probability for the Transformer encoder. + max_sequence_length: Int. The maximum sequence length that this encoder + can consume. If None, `max_sequence_length` uses the value from + sequence length. This determines the variable shape for positional embeddings. - num_segments: The number of types that the 'segment_ids' input can + num_segments: Int. The number of types that the 'segment_ids' input can take. - cls_token_index: Index of [CLS] token in the vocabulary. + cls_token_index: Int. Index of [CLS] token in the vocabulary. 
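The `num_segments` and `cls_token_index` arguments above assume BERT-style packed inputs. As a toy sketch (not part of the patch), the `segment_ids` input marks which packed sentence each token belongs to, and `input_mask` flags real tokens versus padding:

```python
# Illustrative sketch only, not part of the patch: building segment ids and an
# input mask for a packed sentence pair. Lengths are toy assumptions.
import tensorflow as tf

len_a, len_b, pad = 5, 4, 3  # tokens in sentence A, sentence B, and padding
segment_ids = tf.constant([[0] * len_a + [1] * len_b + [0] * pad])  # shape (1, 12)
input_mask = tf.constant([[1] * (len_a + len_b) + [0] * pad])       # 1 = real token, 0 = padding
```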
Example usage: ```python @@ -81,7 +81,7 @@ class Bert(keras.Model): # TODO(jbischof): add tests # TODO(jbischof): consider changing `intermediate_dim` to less confusing - # name here and in TransformerEncoder + # name here and in TransformerEncoder (`feed_forward_dim`?) def __init__( self, @@ -214,7 +214,7 @@ class BertClassifier(keras.Model): Args: base_model: A `keras_nlp.models.Bert` to encode inputs. - num_classes: Number of classes to predict. + num_classes: Int. Number of classes to predict. Example usage: ``` From b908604b0fb4b2066024ca6dc095c0fd8924d8be Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Mon, 15 Aug 2022 22:11:52 +0000 Subject: [PATCH 39/45] Restore comment --- keras_nlp/models/bert.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 7271cba510..164d6ddd97 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Bert model configurable class, preconfigured versions, and task heads.""" + import tensorflow as tf from absl import logging from tensorflow import keras From 628bda3d59ce0605be088eadc8ee76ba73d740c6 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Mon, 15 Aug 2022 22:45:40 +0000 Subject: [PATCH 40/45] Improve handling of `super` args --- keras_nlp/models/bert.py | 46 +++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index 164d6ddd97..f668572a99 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -55,6 +55,9 @@ class Bert(keras.Model): num_segments: Int. The number of types that the 'segment_ids' input can take. cls_token_index: Int. Index of [CLS] token in the vocabulary. + name: String, optional. Name of the model. + trainable: Boolean, optional. If the model's variables should be + trainable. Example usage: ```python @@ -96,7 +99,8 @@ def __init__( max_sequence_length=512, num_segments=2, cls_token_index=0, - **kwargs, + name=None, + trainable=True, ): token_id_input = keras.Input( @@ -178,7 +182,8 @@ def __init__( "sequence_output": sequence_output, "pooled_output": pooled_output, }, - **kwargs, + name=name, + trainable=trainable, ) # All references to `self` below this line self.token_embedding = token_embedding_layer @@ -265,7 +270,7 @@ def __init__( self.num_classes = num_classes -def BertBase(**kwargs): +def BertBase(name=None, trainable=True): """Bi-directional Transformer-based encoder network (Bert) using "Base" architecture. @@ -275,6 +280,11 @@ def BertBase(**kwargs): embedding lookups and transformer layers, but not the masked language model or classification task networks. + Args: + name: String, optional. Name of the model. + trainable: Boolean, optional. If the model's variables should be + trainable. 
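Because `BertBase` now accepts explicit `name` and `trainable` arguments and forwards them to the functional model constructor, a named, frozen encoder can be requested directly. A sketch of the intended call, not part of the patch:

```python
# Illustrative sketch only, not part of the patch: requesting a named, frozen
# "Base" encoder, e.g. for feature extraction.
import keras_nlp

frozen_encoder = keras_nlp.models.BertBase(name="bert_base", trainable=False)
print(frozen_encoder.name)       # "bert_base"
print(frozen_encoder.trainable)  # False
```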
+ Example usage: ```python # Randomly initialized BertBase encoder @@ -293,25 +303,17 @@ def BertBase(**kwargs): ``` """ - base_args = { - "vocabulary_size": 30522, - "num_layers": 12, - "num_heads": 12, - "hidden_dim": 768, - "intermediate_dim": 3072, - "dropout": 0.1, - "max_sequence_length": 512, - "num_segments": 2, - } - - for arg in kwargs: - if arg in base_args: - logging.error( - f"""`{arg}` fixed to {base_args[arg]} in BertBase and cannot """ - f"""be changed.""" - ) - - model = Bert({**base_args, **kwargs}) + model = Bert( + vocab_size=30522, + num_layers=12, + hidden_size=768, + num_heads=12, + intermediate_dim=3072, + dropout=0.1, + max_sequence_length=512, + name=name, + trainable=trainable, + ) # TODO(jbischof): add some documentation or magic to load our checkpoints # TODO(jbischof): attach the tokenizer From ce8b2c46c4c6924008b18b24a2426f4ba30da51c Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Mon, 15 Aug 2022 23:12:03 +0000 Subject: [PATCH 41/45] Initial tests for model call --- keras_nlp/models/bert.py | 5 +-- keras_nlp/models/bert_test.py | 78 +++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 3 deletions(-) create mode 100644 keras_nlp/models/bert_test.py diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index f668572a99..d3067473f9 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -15,7 +15,6 @@ """Bert model configurable class, preconfigured versions, and task heads.""" import tensorflow as tf -from absl import logging from tensorflow import keras from keras_nlp.layers import PositionEmbedding @@ -304,10 +303,10 @@ def BertBase(name=None, trainable=True): """ model = Bert( - vocab_size=30522, + vocabulary_size=30522, num_layers=12, - hidden_size=768, num_heads=12, + hidden_dim=768, intermediate_dim=3072, dropout=0.1, max_sequence_length=512, diff --git a/keras_nlp/models/bert_test.py b/keras_nlp/models/bert_test.py new file mode 100644 index 0000000000..54c2af5258 --- /dev/null +++ b/keras_nlp/models/bert_test.py @@ -0,0 +1,78 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for Bert model.""" + +import tensorflow as tf + +from keras_nlp.models import bert + + +class BertTest(tf.test.TestCase): + def test_valid_call_bert(self): + model = bert.Bert( + vocabulary_size=30522, + num_layers=12, + num_heads=12, + hidden_dim=768, + intermediate_dim=3072, + max_sequence_length=12, + name="encoder", + ) + input_data = { + "input_ids": tf.random.uniform( + shape=(1, 12), dtype=tf.int64, maxval=30522 + ), + "segment_ids": tf.constant( + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12) + ), + "input_mask": tf.constant( + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12) + ), + } + model(input_data) + + def test_valid_call_classifier(self): + model = bert.Bert( + vocabulary_size=30522, + num_layers=12, + num_heads=12, + hidden_dim=768, + intermediate_dim=3072, + max_sequence_length=12, + name="encoder", + ) + input_data = { + "input_ids": tf.random.uniform( + shape=(1, 12), dtype=tf.int64, maxval=30522 + ), + "segment_ids": tf.constant( + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12) + ), + "input_mask": tf.constant( + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12) + ), + } + classifier = bert.BertClassifier(model, 4, name="classifier") + classifier(input_data) + + def test_valid_call_bert_base(self): + model = bert.BertBase(name="encoder") + input_data = { + "input_ids": tf.random.uniform( + shape=(1, 512), dtype=tf.int64, maxval=model.vocabulary_size + ), + "segment_ids": tf.constant([0] * 200 + [1] * 312, shape=(1, 512)), + "input_mask": tf.constant([1] * 512, shape=(1, 512)), + } + model(input_data) From 49fd2e57a6b77cd40e81dd2bb183b873e3dc03cb Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Mon, 15 Aug 2022 23:19:38 +0000 Subject: [PATCH 42/45] Make kwargs passing consistent --- keras_nlp/models/bert.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index d3067473f9..acac424c1f 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -221,6 +221,9 @@ class BertClassifier(keras.Model): Args: base_model: A `keras_nlp.models.Bert` to encode inputs. num_classes: Int. Number of classes to predict. + name: String, optional. Name of the model. + trainable: Boolean, optional. If the model's variables should be + trainable. Example usage: ``` @@ -253,7 +256,8 @@ def __init__( self, base_model, num_classes, - **kwargs, + name=None, + trainable=True, ): inputs = base_model.input pooled = base_model(inputs)["pooled_output"] @@ -263,7 +267,9 @@ def __init__( name="logits", )(pooled) # Instantiate using Functional API Model constructor - super().__init__(inputs=inputs, outputs=outputs, **kwargs) + super().__init__( + inputs=inputs, outputs=outputs, name=name, trainable=trainable + ) # All references to `self` below this line self.base_model = base_model self.num_classes = num_classes From ddb59f1701793da68653b348dd9fc29183f7fed2 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Mon, 15 Aug 2022 23:46:20 +0000 Subject: [PATCH 43/45] Saving model test --- keras_nlp/models/bert_test.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/keras_nlp/models/bert_test.py b/keras_nlp/models/bert_test.py index 54c2af5258..a329e0f054 100644 --- a/keras_nlp/models/bert_test.py +++ b/keras_nlp/models/bert_test.py @@ -13,7 +13,10 @@ # limitations under the License. 
"""Tests for Bert model.""" +import os + import tensorflow as tf +from tensorflow import keras from keras_nlp.models import bert @@ -76,3 +79,34 @@ def test_valid_call_bert_base(self): "input_mask": tf.constant([1] * 512, shape=(1, 512)), } model(input_data) + + def test_saving_model(self): + model = bert.Bert( + vocabulary_size=30522, + num_layers=12, + num_heads=12, + hidden_dim=768, + intermediate_dim=3072, + max_sequence_length=12, + name="encoder", + ) + input_data = { + "input_ids": tf.random.uniform( + shape=(1, 12), dtype=tf.int64, maxval=30522 + ), + "segment_ids": tf.constant( + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12) + ), + "input_mask": tf.constant( + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12) + ), + } + model_output = model(input_data) + save_path = os.path.join(self.get_temp_dir(), "model") + model.save(save_path) + restored_model = keras.models.load_model(save_path) + + restored_output = restored_model(input_data) + self.assertAllClose( + model_output["pooled_output"], restored_output["pooled_output"] + ) From cfdbbb147b7e78189c41cea657db16073e1d6405 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Tue, 16 Aug 2022 00:38:22 +0000 Subject: [PATCH 44/45] Fix TODOs --- keras_nlp/models/bert.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index acac424c1f..d2a15bb59b 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -83,9 +83,8 @@ class Bert(keras.Model): ``` """ - # TODO(jbischof): add tests - # TODO(jbischof): consider changing `intermediate_dim` to less confusing - # name here and in TransformerEncoder (`feed_forward_dim`?) + # TODO(jbischof): consider changing `intermediate_dim` and `hidden_dim` to + # less confusing name here and in TransformerEncoder (`feed_forward_dim`?) def __init__( self, From 091ae34d75b4ba087d5caadcc9e518df450ff335 Mon Sep 17 00:00:00 2001 From: Jonathan Bischof Date: Tue, 16 Aug 2022 00:47:55 +0000 Subject: [PATCH 45/45] Format fix --- keras_nlp/models/bert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/models/bert.py b/keras_nlp/models/bert.py index d2a15bb59b..adaf3d9305 100644 --- a/keras_nlp/models/bert.py +++ b/keras_nlp/models/bert.py @@ -83,7 +83,7 @@ class Bert(keras.Model): ``` """ - # TODO(jbischof): consider changing `intermediate_dim` and `hidden_dim` to + # TODO(jbischof): consider changing `intermediate_dim` and `hidden_dim` to # less confusing name here and in TransformerEncoder (`feed_forward_dim`?) def __init__(