Merged

45 commits
de94eab  first draft (jbischof, Aug 5, 2022)
d3934b7  first working version (jbischof, Aug 5, 2022)
5b8572a  Partial draft of functional API + BertBase (jbischof, Aug 9, 2022)
7add3a7  Change Bert to Model subclass API (jbischof, Aug 9, 2022)
2cfa58c  Get fine-tuning script working (jbischof, Aug 10, 2022)
9eeff1a  Move pretraining head back to `examples/` (jbischof, Aug 10, 2022)
483d781  Rename to BertPretrainingModel (jbischof, Aug 10, 2022)
2d9fca7  Small notes (jbischof, Aug 10, 2022)
e54e39c  Formatting and notes (jbischof, Aug 10, 2022)
1ac992b  Note (jbischof, Aug 10, 2022)
66b5c7c  Move Bert to models/ folder (jbischof, Aug 10, 2022)
b3d22c0  Small style changes re: comments (jbischof, Aug 10, 2022)
b704f28  Fix Bert docstrings and remove `weights` param (jbischof, Aug 10, 2022)
4f1ecdd  Initialization and docstring for classifier (jbischof, Aug 11, 2022)
caf9dc5  Decouple finetuning from model config (jbischof, Aug 11, 2022)
772f201  Replace `inner` -> `intermediate` (jbischof, Aug 11, 2022)
9bbc465  Do not expose initializer and activation (jbischof, Aug 11, 2022)
50774e9  Do not exposed initializer and activation (jbischof, Aug 11, 2022)
ba4b779  Move BertClassifier to functional API (jbischof, Aug 11, 2022)
7a75f4c  Fix token embedding exposure in encoder (jbischof, Aug 11, 2022)
33e7d21  Fix imports in `examples/` (jbischof, Aug 12, 2022)
7ef349c  remove TODO (jbischof, Aug 12, 2022)
c596267  Add docstring test for Bert encoder (jbischof, Aug 12, 2022)
52d08a0  Format (jbischof, Aug 12, 2022)
3cbf709  Add docstring test for classifier (jbischof, Aug 12, 2022)
b2d6053  Format (jbischof, Aug 12, 2022)
5fc84ed  Set `max_sequence_length` for BertBase (jbischof, Aug 12, 2022)
a4c7f8d  Move TODO (jbischof, Aug 12, 2022)
94f81b2  Standarize initializers to match Bert paper (jbischof, Aug 12, 2022)
b99a84c  Respond to minor comments (jbischof, Aug 12, 2022)
a8a6dc8  More minor comment fixes (jbischof, Aug 12, 2022)
4053bcb  Format fix (jbischof, Aug 12, 2022)
0775307  Improve documentation (jbischof, Aug 15, 2022)
8fde616  Tiny fix (jbischof, Aug 15, 2022)
f375175  Tiny fix (jbischof, Aug 15, 2022)
0d8d1d5  Clarifying comments in `dim` args (jbischof, Aug 15, 2022)
7b34f1a  Remove unnecessary comment (jbischof, Aug 15, 2022)
ae00003  Add typehints in the comments. (jbischof, Aug 15, 2022)
b908604  Restore comment (jbischof, Aug 15, 2022)
628bda3  Improve handling of `super` args (jbischof, Aug 15, 2022)
ce8b2c4  Initial tests for model call (jbischof, Aug 15, 2022)
49fd2e5  Make kwargs passing consistent (jbischof, Aug 15, 2022)
ddb59f1  Saving model test (jbischof, Aug 15, 2022)
cfdbbb1  Fix TODOs (jbischof, Aug 16, 2022)
091ae34  Format fix (jbischof, Aug 16, 2022)
49 changes: 19 additions & 30 deletions examples/bert/bert_config.py
@@ -12,60 +12,49 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# TODO(jbischof): remove in favor of BertBase, BertSmall, etc
 MODEL_CONFIGS = {
     "tiny": {
         "num_layers": 2,
-        "hidden_size": 128,
+        "hidden_dim": 128,
         "dropout": 0.1,
-        "num_attention_heads": 2,
-        "inner_size": 512,
-        "inner_activation": "gelu",
-        "initializer_range": 0.02,
+        "num_heads": 2,
+        "intermediate_dim": 512,
     },
     "mini": {
         "num_layers": 4,
-        "hidden_size": 256,
+        "hidden_dim": 256,
         "dropout": 0.1,
-        "num_attention_heads": 4,
-        "inner_size": 1024,
-        "inner_activation": "gelu",
-        "initializer_range": 0.02,
+        "num_heads": 4,
+        "intermediate_dim": 1024,
     },
     "small": {
         "num_layers": 4,
-        "hidden_size": 512,
+        "hidden_dim": 512,
         "dropout": 0.1,
-        "num_attention_heads": 8,
-        "inner_size": 2048,
-        "inner_activation": "gelu",
-        "initializer_range": 0.02,
+        "num_heads": 8,
+        "intermediate_dim": 2048,
     },
     "medium": {
         "num_layers": 8,
-        "hidden_size": 512,
+        "hidden_dim": 512,
         "dropout": 0.1,
-        "num_attention_heads": 8,
-        "inner_size": 2048,
-        "inner_activation": "gelu",
-        "initializer_range": 0.02,
+        "num_heads": 8,
+        "intermediate_dim": 2048,
     },
     "base": {
         "num_layers": 12,
-        "hidden_size": 768,
+        "hidden_dim": 768,
         "dropout": 0.1,
-        "num_attention_heads": 12,
-        "inner_size": 3072,
-        "inner_activation": "gelu",
-        "initializer_range": 0.02,
+        "num_heads": 12,
+        "intermediate_dim": 3072,
     },
     "large": {
         "num_layers": 24,
-        "hidden_size": 1024,
+        "hidden_dim": 1024,
        "dropout": 0.1,
-        "num_attention_heads": 16,
-        "inner_size": 4096,
-        "inner_activation": "gelu",
-        "initializer_range": 0.02,
+        "num_heads": 16,
+        "intermediate_dim": 4096,
     },
 }
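For reference, the renamed keys mirror the constructor arguments of the new keras_nlp Bert encoder. A minimal sketch of wiring a config dict into the model (the `keras_nlp.models.Bert` name, the `vocabulary_size` argument, and its value are assumptions; only the renamed keys are confirmed by this diff):

import keras_nlp

from examples.bert.bert_config import MODEL_CONFIGS

# Hypothetical wiring; assumes the encoder is exposed as
# keras_nlp.models.Bert and accepts these exact keyword arguments.
config = MODEL_CONFIGS["tiny"]
encoder = keras_nlp.models.Bert(
    vocabulary_size=30522,  # assumed; the standard BERT WordPiece vocab size
    num_layers=config["num_layers"],
    num_heads=config["num_heads"],
    hidden_dim=config["hidden_dim"],
    intermediate_dim=config["intermediate_dim"],
    dropout=config["dropout"],
)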
40 changes: 3 additions & 37 deletions examples/bert/bert_finetune_glue.py
@@ -24,17 +24,10 @@
 
 import keras_nlp
 from examples.bert.bert_config import FINETUNING_CONFIG
-from examples.bert.bert_config import MODEL_CONFIGS
 from examples.bert.bert_config import PREPROCESSING_CONFIG
 
 FLAGS = flags.FLAGS
 
-flags.DEFINE_string(
-    "model_size",
-    "tiny",
-    "One of: tiny, mini, small, medium, base, or large.",
-)
-
 flags.DEFINE_string(
     "vocab_file",
     None,
@@ -109,39 +102,14 @@ def to_tf_dataset(split):
     return train_ds, test_ds, validation_ds
 
 
-class BertClassificationFinetuner(keras.Model):
-    """Adds a classification head to a pre-trained BERT model for finetuning"""
-
-    def __init__(self, bert_model, num_classes, initializer, **kwargs):
-        super().__init__(**kwargs)
-        self.bert_model = bert_model
-        self._logit_layer = keras.layers.Dense(
-            num_classes,
-            kernel_initializer=initializer,
-            name="logits",
-        )
-
-    def call(self, inputs):
-        # Ignore the sequence output, use the pooled output.
-        _, pooled_output = self.bert_model(inputs)
-        return self._logit_layer(pooled_output)
-
-
 class BertHyperModel(keras_tuner.HyperModel):
     """Creates a hypermodel to help with the search space for finetuning."""
 
-    def __init__(self, model_config):
-        self.model_config = model_config
-
     def build(self, hp):
         model = keras.models.load_model(FLAGS.saved_model_input, compile=False)
-        model_config = self.model_config
-        finetuning_model = BertClassificationFinetuner(
-            bert_model=model,
+        finetuning_model = keras_nlp.models.BertClassifier(
+            base_model=model,
             num_classes=3 if FLAGS.task_name in ("mnli", "ax") else 2,
-            initializer=keras.initializers.TruncatedNormal(
-                stddev=model_config["initializer_range"]
-            ),
         )
         finetuning_model.compile(
             optimizer=keras.optimizers.Adam(
@@ -168,8 +136,6 @@ def main(_):
         end_value=tokenizer.token_to_id("[SEP]"),
     )
 
-    model_config = MODEL_CONFIGS[FLAGS.model_size]
-
    def preprocess_data(inputs, labels):
         inputs = [tokenizer(x) for x in inputs]
         token_ids, segment_ids = packer(inputs)
@@ -194,7 +160,7 @@ def preprocess_data(inputs, labels):
     )
 
     # Create a hypermodel object for a RandomSearch.
-    hypermodel = BertHyperModel(model_config)
+    hypermodel = BertHyperModel()
 
     # Initialize the random search over the 4 learning rate parameters, for 4
     # trials and 3 epochs for each trial.
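The net effect is that the hand-rolled finetuning head disappears in favor of the library class. A minimal sketch of the new call path in isolation (only `base_model` and `num_classes` are confirmed by this diff; the path and compile settings are assumed stand-ins):

import keras_nlp
from tensorflow import keras

# Load a pretrained encoder and attach the library's classification head.
encoder = keras.models.load_model("path/to/saved_bert", compile=False)
classifier = keras_nlp.models.BertClassifier(
    base_model=encoder,
    num_classes=2,  # e.g. a binary GLUE task
)
classifier.compile(
    optimizer=keras.optimizers.Adam(learning_rate=5e-5),  # assumed
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)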
200 changes: 0 additions & 200 deletions examples/bert/bert_model.py

This file was deleted.

3 changes: 3 additions & 0 deletions examples/bert/bert_preprocess.py
@@ -375,6 +375,9 @@ def create_masked_lm_predictions(
 ):
     """Creates the predictions for the masked LM objective."""
 
+    # TODO(jbischof): replace with keras_nlp.layers.MLMMaskGenerator
+    # (Issue #166)
+
     cand_indexes = []
     for (i, token) in enumerate(tokens):
         if token == "[CLS]" or token == "[SEP]":
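For context on the TODO above, a rough sketch of what the `keras_nlp.layers.MLMMaskGenerator` replacement could look like (the layer's exact signature is an assumption here, and `tokenizer`/`token_ids` stand in for the script's WordPiece tokenizer and tokenized sequence):

import keras_nlp

# Hypothetical replacement for the hand-written masking loop; argument
# names are assumed from the MLMMaskGenerator API and may differ.
masker = keras_nlp.layers.MLMMaskGenerator(
    vocabulary_size=tokenizer.vocabulary_size(),
    mask_selection_rate=0.15,  # BERT's standard 15% masking rate
    mask_selection_length=76,  # assumed cap on predictions per sequence
    mask_token_id=tokenizer.token_to_id("[MASK]"),
    unselectable_token_ids=[
        tokenizer.token_to_id(t) for t in ("[CLS]", "[SEP]", "[PAD]")
    ],
)
# Returns the masked token ids along with the positions and original
# ids of the masked-out tokens.
masked = masker(token_ids)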