support float8 weight caching for gradient accumulation/PP (#164)
Summary:
In cases where the optimizer update does not happen after every forward, such as gradient accumulation (microbatching) or pipeline parallelism (PP), we can cache the cast float8 weight, trading some memory for time.
For now I'm just testing out performance and accuracy; we can improve the API in future PRs. The current code is torch.compile-friendly, which is nice.
In terms of accuracy this should be a no-op; I will validate this further if we want to land this.
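As a rough illustration (not the actual API in this PR), here is a minimal sketch of the caching idea: cast the weight to float8 once per optimizer step, reuse that cast for every microbatch forward, and invalidate the cache after the weight update. The `CachedFP8Linear` module, its `invalidate_cache` method, and the straight-through trick that keeps gradients flowing to the high-precision weight are hypothetical simplifications for this sketch; the real code also tracks scales and integrates with the float8 linear modules in this repo.
```
import torch
import torch.nn as nn

class CachedFP8Linear(nn.Module):
    """Linear layer that reuses the float8-cast weight across microbatch forwards."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self._fp8_weight = None  # cleared whenever the high-precision weight changes

    def forward(self, x):
        if self._fp8_weight is None:
            # Cast once per optimizer step instead of once per microbatch.
            # A real float8 implementation would also compute and store a scale here.
            self._fp8_weight = self.weight.detach().to(torch.float8_e4m3fn)
        # Upcast for the matmul in this sketch; a real float8 kernel consumes the
        # fp8 tensor and its scale directly.
        w_val = self._fp8_weight.to(self.weight.dtype)
        # Straight-through: the forward uses the cached fp8 value while gradients
        # flow to the high-precision weight.
        w = self.weight + (w_val - self.weight).detach()
        return x @ w.t()

    def invalidate_cache(self):
        self._fp8_weight = None


# The cache stays valid across all microbatch forwards within one optimizer step
# and must be invalidated once the weights are updated.
model = CachedFP8Linear(512, 512)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for mb in [torch.randn(4, 512) for _ in range(8)]:  # several forwards per step
    model(mb).sum().backward()                      # grads accumulate in .grad
optimizer.step()
optimizer.zero_grad()
model.invalidate_cache()  # weight changed, so re-cast on the next forward
```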
For performance, on drisspg's LLaMa 7B pretrain script, with bsz == 128 and micro_bsz == 1:
1. baseline bf16 + compile: 2.38 it/s
2. delayed scaling + compile: 2.80 it/s (1.18x over baseline)
3. delayed scaling + compile + this PR: 3.04 it/s (1.28x over baseline)
Pull Request resolved: #164
Test Plan:
```
pytest test/test_base.py -s -k test_weight_caching
```
Reviewed By: drisspg
Differential Revision: D52356785
Pulled By: vkuzo
fbshipit-source-id: e0173666a6c7639246dfde636734900b9fc1657e