diff --git a/examples/bert/main.py b/examples/bert/main.py
index 721cfa4f4..ae624a2b6 100644
--- a/examples/bert/main.py
+++ b/examples/bert/main.py
@@ -47,6 +47,11 @@ def main(cfg: DictConfig,
     # Get batch size info
     cfg = update_batch_size_info(cfg)
 
+    # Read FSDP Config as a dict
+    fsdp_config = cfg.get('fsdp_config', None)
+    fsdp_config = om.to_container(fsdp_config,
+                                  resolve=True) if fsdp_config else None
+
     # Build Model
     print('Initializing model...')
     model = build_model(cfg.model)
@@ -110,6 +115,7 @@ def main(cfg: DictConfig,
         device=cfg.get('device', None),
         device_train_microbatch_size=cfg.get('device_train_microbatch_size',
                                              'auto'),
+        fsdp_config=fsdp_config,  # type: ignore
         save_folder=cfg.get('save_folder', None),
         save_interval=cfg.get('save_interval', '1000ba'),
         save_num_checkpoints_to_keep=cfg.get('save_num_checkpoints_to_keep',
diff --git a/examples/bert/src/bert_layers.py b/examples/bert/src/bert_layers.py
index 8055e0c86..e37f695bb 100644
--- a/examples/bert/src/bert_layers.py
+++ b/examples/bert/src/bert_layers.py
@@ -915,6 +915,14 @@ def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
 
         return {'input_ids': input_ids, 'attention_mask': attention_mask}
 
+    # FSDP Wrap function
+    def fsdp_wrap_fn(self, module):
+        return isinstance(module, BertLayer)
+
+    # Activation Checkpointing
+    def activation_checkpointing_fn(self, module):
+        return isinstance(module, BertLayer)
+
 
 class BertForNextSentencePrediction(BertPreTrainedModel):
     #TBD: Push in future commit
diff --git a/examples/bert/yamls/main/hf-bert-base-uncased.yaml b/examples/bert/yamls/main/hf-bert-base-uncased.yaml
index 61958d921..da1858f50 100644
--- a/examples/bert/yamls/main/hf-bert-base-uncased.yaml
+++ b/examples/bert/yamls/main/hf-bert-base-uncased.yaml
@@ -25,8 +25,8 @@ model:
   model_config:
     num_attention_heads: 12 # bert-base default
     num_hidden_layers: 12 # bert-base default
-    max_position_embedding: 512
-    attention_probs_dropout_prob: 0.0
+    max_position_embedding: 512 # bert-base default
+    attention_probs_dropout_prob: 0.1 # bert-base default
 
 # Dataloaders
 train_loader:
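
Note: main.py now reads an optional fsdp_config section from the run YAML, but none of the YAMLs touched by this diff define one, so FSDP stays disabled by default. Below is a minimal sketch of what such a section might look like; the exact keys and values depend on the Composer version, so treat every entry here as an illustrative assumption rather than the library's required schema.

fsdp_config:
  sharding_strategy: FULL_SHARD   # assumed key: how parameters/grads are sharded across ranks
  mixed_precision: DEFAULT        # assumed key: FSDP mixed-precision policy
  activation_checkpointing: true  # assumed key: pairs with activation_checkpointing_fn on BertLayer
  activation_cpu_offload: false   # assumed key: keep checkpointed activations on GPU
  verbose: false                  # assumed key: print the wrapping plan at startup

When a section like this is present, the dict is passed straight to the Trainer (see the fsdp_config kwarg above), and the model's fsdp_wrap_fn is expected to be consulted so that each BertLayer becomes its own FSDP-wrapped unit, with activation_checkpointing_fn marking the same layers for checkpointing.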