Update on "[BE][5/n] simplify pp vs. non-pp set up"
This PR refactors the PP vs. non-PP setup in `train.py`:
- moves `build_pipeline_schedule` into `pipeline_llama`, which shrinks the PP interface exposed to `train.py`
- refactors the setup flow so that there are only two main PP vs. non-PP if-else branches: one in the setup phase and one in the training phase (see the sketch after this list)
- the resulting code is already clear to read and copy-paste, so separate sub-functions to hold it are unnecessary
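As a rough illustration of the intended two-branch layout, here is a minimal sketch. It is not the actual `train.py` code: `pp_enabled`, the `pipeline_llama` / `parallelize_llama` call signatures, and the schedule's `step` usage are all simplified placeholders.

```python
import torch.nn as nn

def train_loop(model: nn.Module, dataloader, pp_enabled: bool,
               pipeline_llama=None, parallelize_llama=None):
    # setup phase: the single PP vs. non-PP branch
    if pp_enabled:
        # pipeline_llama splits the model and builds the schedule internally
        pp_schedule, model_parts = pipeline_llama(model)
    else:
        parallelize_llama(model)  # modifies the model in place, no return

    # training phase: the other PP vs. non-PP branch
    for batch in dataloader:
        if pp_enabled:
            pp_schedule.step(batch)  # the schedule drives forward/backward
        else:
            loss = model(batch).mean()
            loss.backward()
```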
This PR also removes the unnecessary module returns in `parallelize_llama`, since the module is modified in place. Note that `torch.compile` and AC normally require returning and reassigning the wrapped module. But because compile and AC are applied per transformer block, the whole model can be updated in place, as shown below:
```python
# wrap one block with compile or activation checkpointing, then re-register
# it so the parent module holds the wrapped block
transformer_block = torch.compile(transformer_block)  # or an AC wrapper
model.layers.register_module(layer_id, transformer_block)
```
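Expanded into a loop, the pattern looks roughly like the sketch below. This is an assumed shape, not the exact `parallelize_llama` code; in particular, combining `checkpoint_wrapper` and `torch.compile` on every block here is only illustrative.

```python
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

def apply_compile_and_ac(model: torch.nn.Module) -> None:
    # iterate over a snapshot so re-registering doesn't disturb iteration
    for layer_id, transformer_block in list(model.layers.named_children()):
        transformer_block = checkpoint_wrapper(transformer_block)  # per-block AC
        transformer_block = torch.compile(transformer_block)       # per-block compile
        # re-register the wrapped block under the same name; the top-level
        # model object is unchanged, so callers need no return value
        model.layers.register_module(layer_id, transformer_block)
```

Because each wrapped block is re-registered on `model.layers` under its original name, the caller's reference to `model` stays valid and `parallelize_llama` can return nothing.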
[ghstack-poisoned]