Add pretrained checkpoints for Bert #297

@jbischof

Description

Problem

We exposed an API for fully configurable and fixed Bert architectures in PR #288. However, most users will not have the resources to do pretraining on their own. We should offer checkpoints for models.BertBase similar to those used in the original paper.

Implementation

Weights

Model checkpoints from the original paper repo (gh) or a similar source can be converted to be compatible with our implementation in PR #288. This conversion is a one-off task and the code will not be added to the repo.
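
For concreteness, here is a heavily hedged sketch of such a conversion script, assuming the original TF1 checkpoint layout and the models.BertBase symbol from PR #288; the source variable names and target layer names are illustrative only:

import tensorflow as tf
import keras_nlp

# Read the original TF1 BERT checkpoint (path is illustrative).
reader = tf.train.load_checkpoint("uncased_L-12_H-768_A-12/bert_model.ckpt")
model = keras_nlp.models.BertBase()

# Copy the word embedding table. The source variable name comes from the
# original repo; the target layer name is an assumption about PR #288.
word_embeddings = reader.get_tensor("bert/embeddings/word_embeddings")
model.get_layer("token_embedding").set_weights([word_embeddings])
# ...repeat for position embeddings, layer norms, and each transformer layer.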

Format

keras-cv is currently using h5, but that format is deprecated. An obvious alternative is a TensorFlow checkpoint written with model.save_weights(). The resulting index and data files will have to be zipped and uploaded to our cloud storage bucket.
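
A minimal sketch of producing such an archive (paths and filenames are placeholders):

import shutil
import keras_nlp

model = keras_nlp.models.BertBase()

# The TensorFlow checkpoint format produces an index file plus one or more
# data shards (model.index, model.data-00000-of-00001, ...).
model.save_weights("bert_base_uncased/model")

# Zip the folder for upload to the cloud storage bucket.
shutil.make_archive("bert_base_uncased", "zip", "bert_base_uncased")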

API

There is some prior art in keras-cv and our pretrained vocab of using a string to identify the model of interest, which is then checked against a dictionary (e.g., weights.py). However, this requires the user to dive into the code to find the possible names, and it requires a lot of input validation and error handling in the implementation.

I propose creating an enum for each model class (e.g., models.BertBase) listing the possible checkpoints. This exposes the options as part of the API and works with users' autocomplete engines:

from enum import Enum

class BertBaseCkpt(Enum):
    uncased = 0
    cased = 1
    zh = 2

Under the hood, each enum value would be the key into a dictionary of metadata for the corresponding model, including the checkpoint location, the vocabulary location, and a description string explaining where the model came from and what data it was trained on. We could also write a method that prints the description for each (or all) enum value(s), again protecting the user from having to read through utility functions; a sketch of such a helper follows the dictionary below.

base_models = {
    BertBaseCkpt.uncased: {
        "ckpt": ...,
        "description": ...,
        "vocab": ...,
    },
}
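
One possible shape for the description helper mentioned above; the name describe_checkpoints is hypothetical:

def describe_checkpoints(ckpt=None):
    """Print the description for one checkpoint, or for all of them."""
    ckpts = [ckpt] if ckpt is not None else list(BertBaseCkpt)
    for c in ckpts:
        print(f"{c.name}: {base_models[c]['description']}")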

Model checkpoints will be hosted in our Google Cloud storage bucket and loaded behind the scenes given an enum value. The user should never have to interact with this dictionary.
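
As a hedged sketch of that behind-the-scenes loading (names here are illustrative, not part of the proposed public API), the zipped checkpoint could be fetched and cached with keras.utils.get_file and restored with load_weights:

import os
import zipfile
import keras

def _load_checkpoint(model, ckpt):
    metadata = base_models[ckpt]
    # Download and cache the zipped checkpoint from the storage bucket.
    archive = keras.utils.get_file(
        fname=f"bert_base_{ckpt.name}.zip", origin=metadata["ckpt"]
    )
    # Unzip next to the cached archive and restore the weights.
    extract_dir = archive[: -len(".zip")]
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(extract_dir)
    model.load_weights(os.path.join(extract_dir, "model"))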

Here is the enhanced API:

def BertBase(ckpt=None, name=None, trainable=True):
    """Bi-directional Transformer-based encoder network (Bert) using "Base"
    architecture.

    This network implements a bi-directional Transformer-based encoder as
    described in ["BERT: Pre-training of Deep Bidirectional Transformers for
    Language Understanding"](https://arxiv.org/abs/1810.04805). It includes the
    embedding lookups and transformer layers, but not the masked language model
    or classification task networks.

    Args:
        ckpt: BertBaseCkpt. A checkpoint of pretrained weights to load into
            the model. If `None`, the model will be randomly initialized.
        name: String, optional. Name of the model.
        trainable: Boolean, optional. If the model's variables should be
            trainable.
    """
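
A user would then load pretrained weights with a single, autocompletable argument (the keras_nlp.models namespace is assumed here):

import keras_nlp

# Pretrained English uncased weights.
model = keras_nlp.models.BertBase(ckpt=keras_nlp.models.BertBaseCkpt.uncased)

# Randomly initialized, as before.
model = keras_nlp.models.BertBase()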

Vocabulary

Note that this proposal does not cover the vocabulary file, which will be incorporated in our preprocessing work. One option is to host the vocab.txt file used to train the model in the same folder and load it if the user does not supply their own. The rationale is that, except for very advanced use cases, the vocabulary must match the one used in training, and requiring the user to specify their own would expose our Cloud storage implementation for little gain. However, this is beyond the scope of this proposal.
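
Purely for illustration, such a default could reuse the metadata dictionary above; the "vocab" key and the download call are assumptions, not part of this proposal:

import keras

# Fetch the hosted vocab.txt as a default when the user supplies none.
vocab_path = keras.utils.get_file(origin=base_models[BertBaseCkpt.uncased]["vocab"])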
