examples/bert/main.py (6 additions, 0 deletions)
@@ -47,6 +47,11 @@ def main(cfg: DictConfig,
     # Get batch size info
     cfg = update_batch_size_info(cfg)
 
+    # Read FSDP Config as a dict
+    fsdp_config = cfg.get('fsdp_config', None)
+    fsdp_config = om.to_container(fsdp_config,
+                                  resolve=True) if fsdp_config else None
+
     # Build Model
     print('Initializing model...')
     model = build_model(cfg.model)
@@ -110,6 +115,7 @@ def main(cfg: DictConfig,
         device=cfg.get('device', None),
         device_train_microbatch_size=cfg.get('device_train_microbatch_size',
                                              'auto'),
+        fsdp_config=fsdp_config,  # type: ignore
         save_folder=cfg.get('save_folder', None),
         save_interval=cfg.get('save_interval', '1000ba'),
         save_num_checkpoints_to_keep=cfg.get('save_num_checkpoints_to_keep',
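Reviewer note: the hunks above read an optional fsdp_config block out of the OmegaConf config, convert it to a plain dict, and hand it to Composer's Trainer. Below is a minimal runnable sketch of that path, not part of this diff; the two fsdp_config keys shown are assumed illustrative FSDP options, not values taken from this PR.

# Minimal sketch of the new code path (illustrative, not part of this diff).
from omegaconf import OmegaConf as om

# Stand-in for the parsed YAML; the fsdp_config keys here are assumed examples.
cfg = om.create({
    'fsdp_config': {
        'sharding_strategy': 'FULL_SHARD',  # assumed example value
        'activation_checkpointing': True,   # assumed example value
    },
})

# Same conversion as the diff: DictConfig -> plain dict, or None if absent.
fsdp_config = cfg.get('fsdp_config', None)
fsdp_config = om.to_container(fsdp_config,
                              resolve=True) if fsdp_config else None
print(fsdp_config)
# {'sharding_strategy': 'FULL_SHARD', 'activation_checkpointing': True}

Passing None leaves FSDP disabled, so single-device runs keep working through the same entrypoint.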
examples/bert/src/bert_layers.py (8 additions, 0 deletions)
@@ -915,6 +915,14 @@ def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
 
         return {'input_ids': input_ids, 'attention_mask': attention_mask}
 
+    # FSDP Wrap function
+    def fsdp_wrap_fn(self, module):
+        return isinstance(module, BertLayer)
+
+    # Activation Checkpointing
+    def activation_checkpointing_fn(self, module):
+        return isinstance(module, BertLayer)
+
 
 class BertForNextSentencePrediction(BertPreTrainedModel):
     #TBD: Push in future commit
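Reviewer note: these two hooks follow the convention Composer uses to let a model advertise its own FSDP policy. When wrapping the model, the trainer consults fsdp_wrap_fn to decide which submodules become their own FSDP units, and activation_checkpointing_fn to decide which submodules get activation checkpointing. The sketch below shows the shape of that convention; TinyLayer and TinyModel are hypothetical stand-ins for BertLayer and the BERT model, and the filtering loop only imitates how a wrapper might call the hook.

import torch.nn as nn

class TinyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(TinyLayer() for _ in range(2))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    # Shard at TinyLayer boundaries: each layer becomes one FSDP unit.
    def fsdp_wrap_fn(self, module):
        return isinstance(module, TinyLayer)

    # Recompute each TinyLayer's activations during the backward pass.
    def activation_checkpointing_fn(self, module):
        return isinstance(module, TinyLayer)

model = TinyModel()
to_wrap = [m for m in model.modules() if model.fsdp_wrap_fn(m)]
print(len(to_wrap))  # 2: one FSDP unit per TinyLayer

Wrapping and checkpointing at BertLayer granularity is the usual transformer-block trade-off between communication overhead and peak activation memory.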
examples/bert/yamls/main/hf-bert-base-uncased.yaml (2 additions, 2 deletions)
@@ -25,8 +25,8 @@ model:
   model_config:
     num_attention_heads: 12 # bert-base default
     num_hidden_layers: 12 # bert-base default
-    max_position_embeddings: 512
-    attention_probs_dropout_prob: 0.0
+    max_position_embeddings: 512 # bert-base default
+    attention_probs_dropout_prob: 0.1 # bert-base default
 
 # Dataloaders
 train_loader:
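Reviewer note: the added comments match the stock Hugging Face bert-base defaults. max_position_embeddings defaults to 512 and attention_probs_dropout_prob to 0.1, so the old value of 0.0 was silently disabling attention dropout. A quick check, assuming the transformers library is installed:

from transformers import BertConfig

cfg = BertConfig()  # defaults correspond to bert-base-uncased
print(cfg.max_position_embeddings)       # 512
print(cfg.attention_probs_dropout_prob)  # 0.1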