## Problem

We exposed an API for fully configurable and fixed Bert architectures in PR #288. Most users, however, will not have the resources to do pretraining on their own. We should offer checkpoints for `models.BertBase` similar to those used in the original paper.
## Implementation
### Weights

Model checkpoints from the original paper repo (gh) or a similar source can be converted to be compatible with our implementation from PR #288. The conversion code will be a one-off task and will not be added to the repo.
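As a rough illustration, the conversion could start by reading the original variables with `tf.train.load_checkpoint` and mapping them onto our layers; the checkpoint path below is hypothetical and the actual name mapping is omitted:

```python
import tensorflow as tf

# Minimal sketch of the one-off conversion, assuming a checkpoint downloaded
# from the original BERT repo at this (hypothetical) local path.
reader = tf.train.load_checkpoint("uncased_L-12_H-768_A-12/bert_model.ckpt")

# Inspect the original variable names/shapes to work out the mapping onto the
# weights of a freshly constructed models.BertBase.
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)

# Each original variable (reader.get_tensor(name)) would then be assigned to
# the matching Keras weight before calling model.save_weights().
```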
### Format

`keras-cv` is currently using `h5`, but that format is deprecated. An obvious alternative is a `tf.checkpoint` created with `model.save_weights()`. The resulting `data` and `index` files will have to be zipped and uploaded to our cloud storage bucket.
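A minimal sketch of that flow, using a stand-in Keras model (any model saved in the TF checkpoint format produces the same file layout) and assuming the tf.keras 2.x behavior where `save_weights()` defaults to the TF checkpoint format; the directory and archive names are placeholders:

```python
import glob
import os
import zipfile

import tensorflow as tf

# Stand-in for a built models.BertBase; any Keras model saved in the TF
# checkpoint format produces the same index/data file layout.
model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])

# Writes bert_base_uncased/model.ckpt.index and model.ckpt.data-00000-of-00001.
os.makedirs("bert_base_uncased", exist_ok=True)
model.save_weights("bert_base_uncased/model.ckpt")

# Zip the index and data files together for upload to the storage bucket.
with zipfile.ZipFile("bert_base_uncased.zip", "w") as zf:
    for path in glob.glob("bert_base_uncased/model.ckpt*"):
        zf.write(path)
```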
### API

There is some prior art in `keras-cv` and our pretrained vocab of using a string to identify a model of interest, which is then checked against a dictionary (e.g., `weights.py`). However, this requires the user to dive into the code to find the possible names, and it requires a lot of input validation and error handling in the implementation.

I propose creating an `enum` for each `Model` class (e.g., `models.BertBase`) listing the available checkpoints, which will expose these options as part of the API and work with users' autocomplete engines:
```python
from enum import Enum


class BertBaseCkpt(Enum):
    uncased = 0
    cased = 1
    zh = 2
```
Under the hood, this `enum` would then be the key into a dictionary of metadata for each model, including the checkpoint location, the vocabulary location, and a description string explaining where the model came from and what data was used to train it. We could also write a method to print the description for each/all `enum` value(s), again to protect the user from having to read through utility functions (see the sketch after the snippet below).
```python
base_models = {
    BertBaseCkpt.uncased: {
        "ckpt": ...,
        "description": ...,
        "vocab": ...,
    },
}
```
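For the description printer mentioned above, a possible sketch (the helper name is hypothetical; it assumes the `BertBaseCkpt` enum and `base_models` dictionary defined above):

```python
def describe_checkpoints(ckpt=None):
    """Print the description for one checkpoint, or for all of them."""
    ckpts = [ckpt] if ckpt is not None else list(BertBaseCkpt)
    for c in ckpts:
        print(f"{c.name}: {base_models[c]['description']}")
```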
Model checkpoints will be hosted in our Google Cloud storage bucket and loaded behind the scenes given an enum value. The user should never have to interact with this dictionary.
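Behind the scenes, loading could look roughly like the sketch below, which assumes the zipped checkpoints from the Format section and uses `tf.keras.utils.get_file` for download and caching; the helper name and the extracted file layout are assumptions, not a settled design:

```python
import os

import tensorflow as tf


def load_pretrained_weights(model, ckpt):
    """Hypothetical helper: resolve an enum value to weights and load them."""
    metadata = base_models[ckpt]
    # Download and cache the zipped checkpoint; extract=True unpacks the
    # index/data files next to the archive in the Keras cache directory.
    archive = tf.keras.utils.get_file(
        fname=f"bert_base_{ckpt.name}.zip",
        origin=metadata["ckpt"],
        extract=True,
    )
    # Assumes the archive contains a bert_base_<name>/ folder holding the
    # model.ckpt.* files produced by model.save_weights() above.
    ckpt_prefix = os.path.join(os.path.splitext(archive)[0], "model.ckpt")
    model.load_weights(ckpt_prefix)
```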
Here is the enhanced API:
```python
def BertBase(ckpt=None, name=None, trainable=True):
    """Bi-directional Transformer-based encoder network (Bert) using "Base"
    architecture.

    This network implements a bi-directional Transformer-based encoder as
    described in ["BERT: Pre-training of Deep Bidirectional Transformers for
    Language Understanding"](https://arxiv.org/abs/1810.04805). It includes the
    embedding lookups and transformer layers, but not the masked language model
    or classification task networks.

    Args:
        ckpt: BertBaseCkpt. A checkpoint of pretrained weights to load into
            the model. If `None`, the model will be randomly initialized.
        name: String, optional. Name of the model.
        trainable: Boolean, optional. If the model's variables should be
            trainable.
    """
```
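For illustration, usage of the proposed API would then look something like this (assuming the symbols are exported under `models` as in PR #288):

```python
# Randomly initialized, as today.
encoder = models.BertBase(name="bert_scratch")

# Pretrained, selected via the checkpoint enum with autocomplete support.
pretrained = models.BertBase(ckpt=BertBaseCkpt.uncased)
```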
## Vocabulary

Note that this proposal does not cover the vocabulary file, which will be incorporated in our preprocessing work. One potential approach is to host the `vocab.txt` file used to train the model in the same folder, to be loaded if the user does not supply their own. The rationale is that, except for very advanced use cases, the vocabulary must match the one used in training, and requiring the user to specify their own would expose our Cloud storage implementation to the user for little gain. However, this is beyond the scope of this proposal.