diff --git a/README.md b/README.md
index e762a492f..a02a7cc4a 100644
--- a/README.md
+++ b/README.md
@@ -43,6 +43,7 @@ Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes
 6. Learning rate scheduler, meta init, Optional Fused RMSNorm
 7. All options easily configured via [toml files](train_configs/)
 8. [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine tuning
+9. [Float8 support](docs/float8.md)
 
 We report our [Performance](docs/performance.md) verified on 64 A100 GPUs
 
@@ -50,11 +51,10 @@ We report our [Performance](docs/performance.md) verified on 64 A100 GPUs
 ### Coming soon
 
 1. Async checkpointing
-2. Float8 support
-3. Context Parallel
-4. 3D Pipeline Parallel
-5. `torch.compile` support
-6. Scalable data loading solution
+2. Context Parallel
+3. 3D Pipeline Parallel
+4. `torch.compile` support
+5. Scalable data loading solution
 
 ## Installation
 
diff --git a/docs/float8.md b/docs/float8.md
new file mode 100644
index 000000000..ad481c5e8
--- /dev/null
+++ b/docs/float8.md
@@ -0,0 +1,27 @@
+## Enable Float8 Training on H100s
+
+Please install the latest [TorchAO](https://github.com/pytorch/ao/tree/main/torchao/float8) for float8 dtype support:
+```
+USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
+```
+
+Launch a training job with the following command (or, alternatively, set the corresponding options in the toml config files):
+```
+CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
+```
+* `--float8.enable_float8_linear`: swap `nn.Linear` with `Float8Linear` to perform matmuls in float8.
+* `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before the FSDP all-gather, so the communication happens in float8 and saves bandwidth.
+* `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAXes/scales in a single all-reduce over all parameters instead of many small all-reduces, one per parameter.
+
+For parallelisms, we support float8 all-gather for FSDP (optional) and for TP (on by default for `Float8Linear`).
+
+For the scaling strategy, we currently support tensor-wise scaling with dynamic scales, and are actively working on tensor-wise scaling with delayed scales. Row-wise scaling is under exploration.
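+
+The command-line overrides above can also be set in the toml config. Below is a minimal sketch of the corresponding `[float8]` section, assuming the toml keys mirror the `--float8.*` flag names; see the toml files under `train_configs/` for the exact schema:
+```
+# Assumed layout: section/key names are taken from the --float8.* CLI flags above.
+[float8]
+enable_float8_linear = true
+enable_fsdp_float8_all_gather = true
+precompute_float8_dynamic_scale_for_fsdp = true
+```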