
Commit ff53569

Update on "[BE][5/n] simplify pp vs. non-pp set up"
This PR refactors the PP vs. non-PP setup in `train.py`:

- moves `build_pipeline_schedule` into `pipeline_llama`, which shrinks the PP interface exposed to `train.py`
- refactors the setup flow so that there are only two main if-else branches for PP vs. non-PP: one in the setup phase, the other in the training phase
- the resulting code is already clear to read or copy-paste, so separate sub-functions to hold it are unnecessary

This PR also removes unnecessary module returns from `parallelize_llama`, since the modules are modified in place. Note that torch.compile and AC require returning and reassigning the module; but because compile and AC are applied per transformer block, the in-place effect for the whole model is achieved by

```
transformer_block = compile/AC(transformer_block)
model.layers.register_module(layer_id, transformer_block)
```

[ghstack-poisoned]
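The in-place swap described above can be illustrated with a small torch-free sketch. `Layers`, `wrap_block`, and the string "blocks" below are hypothetical stand-ins for `model.layers`, per-block `torch.compile`/AC, and real transformer blocks; the point is only that re-registering the wrapped child mutates the parent container, so nothing needs to be returned:

```python
class Layers:
    """Minimal stand-in for a container of transformer blocks
    (e.g. a torch ModuleDict with register_module)."""

    def __init__(self, blocks):
        self._modules = dict(blocks)

    def register_module(self, layer_id, module):
        # Swap the named child in place: the parent keeps the same
        # Layers object, only the child binding changes.
        self._modules[layer_id] = module

    def __getitem__(self, layer_id):
        return self._modules[layer_id]


def wrap_block(block):
    """Hypothetical per-block transform (compile or AC in the real code)."""
    return ("wrapped", block)


layers = Layers({"0": "block0", "1": "block1"})
for layer_id in ("0", "1"):
    transformer_block = wrap_block(layers[layer_id])
    layers.register_module(layer_id, transformer_block)
```

After the loop, every block reachable through `layers` is the wrapped version, even though the loop returned nothing to the caller.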
1 parent f58ca70 commit ff53569

File tree

1 file changed: +3 −3 lines


train.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -133,7 +133,7 @@ def main(job_config: JobConfig):
         f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
     )

-    # loss function to be shared by Pipeline Parallel and spmd training
+    # loss function to be shared by Pipeline Parallel and SPMD training
     def loss_fn(pred, labels):
         return torch.nn.functional.cross_entropy(
             pred.flatten(0, 1), labels.flatten(0, 1)
@@ -150,7 +150,7 @@ def loss_fn(pred, labels):
     # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
     # optimizer, and checkpointing
     for m in model_parts:
-        # apply spmd-style PT-D techniques
+        # apply SPMD-style PT-D techniques
         models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)

         # In PP, we cannot call init_weights directly because some layers are missing.
@@ -269,7 +269,7 @@ def loss_fn(pred, labels):
         optimizers.zero_grad()

         if parallel_dims.pp_enabled:
-            # pipeline parallel forward / backward inside step() call
+            # Pipeline Parallel forward / backward inside step() call
             is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

             with train_context():
```
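The "one if-else in the training phase" structure the commit describes can be sketched as follows. `ParallelDims`, `step_fn`, and `fwd_bwd_fn` here are hypothetical stand-ins for the objects `train.py` builds during setup (the parallelism config, the PP schedule's `step()`, and the ordinary forward/loss/backward path), not the actual implementation:

```python
class ParallelDims:
    """Stand-in for the parallelism config; only pp_enabled matters here."""

    def __init__(self, pp_enabled):
        self.pp_enabled = pp_enabled


def train_step(parallel_dims, step_fn, fwd_bwd_fn, batch):
    # Single PP vs. non-PP branch in the training phase.
    if parallel_dims.pp_enabled:
        # PP: forward and backward both happen inside the schedule's step().
        return step_fn(batch)
    else:
        # non-PP (SPMD): ordinary forward, loss, and backward.
        return fwd_bwd_fn(batch)
```

Collapsing the two execution paths into this one branch is what lets the setup and training phases each carry a single PP/non-PP decision instead of scattering them through `train.py`.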
