add float8 to README #509

Merged 4 commits on Aug 7, 2024

10 changes: 5 additions & 5 deletions README.md
@@ -43,18 +43,18 @@ Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes
6. Learning rate scheduler, meta init, Optional Fused RMSNorm
7. All options easily configured via [toml files](train_configs/)
8. [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine tuning
+9. [Float8 support](docs/float8.md)

We report our [Performance](docs/performance.md) verified on 64 A100 GPUs


### Coming soon

1. Async checkpointing
-2. Float8 support
-3. Context Parallel
-4. 3D Pipeline Parallel
-5. `torch.compile` support
-6. Scalable data loading solution
+2. Context Parallel
+3. 3D Pipeline Parallel
+4. `torch.compile` support
+5. Scalable data loading solution


## Installation
18 changes: 18 additions & 0 deletions docs/float8.md
@@ -0,0 +1,18 @@
## Enable Float8 Training on H100s

Please install the latest [TorchAO](https://github.com/pytorch/ao/tree/main/torchao/float8) to get float8 dtype support:
```
USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
```
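As a quick sanity check that the install worked, you can confirm the float8 module is importable. The module path and version attribute below are assumptions about the torchao package layout, not part of this PR:

```python
# Quick check that torchao and its float8 utilities are importable.
# Module path assumed from the torchao repo layout linked above.
import torchao
import torchao.float8  # float8 training utilities

print(torchao.__version__)
```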

Launch a training job with the following command (or, alternatively, set the corresponding options in the toml config files):
```
CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
```
* `--float8.enable_float8_linear`: swap `nn.Linear` with `Float8Linear` to perform float8 matmuls.
* `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before the FSDP all-gather, so the communication happens in float8 and saves bandwidth.
* `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters, instead of many small all-reduces, one per parameter (see the Python sketch after this list).
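To make the flags above more concrete, here is a rough sketch of the torchao calls they correspond to. The names (`convert_to_float8_training`, `Float8LinearConfig`, `precompute_float8_dynamic_scale_for_fsdp`) follow the torchao float8 README and should be read as assumptions about that API, not as torchtitan's actual integration code:

```python
# Sketch of the torchao float8 API the flags above roughly map to.
# API names are assumptions based on the torchao float8 README,
# not torchtitan's actual integration code.
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))

# --float8.enable_float8_linear (+ --float8.enable_fsdp_float8_all_gather):
# swap eligible nn.Linear modules with Float8Linear; with the FSDP option
# enabled, weights are cast to float8 before the FSDP all-gather.
config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
convert_to_float8_training(model, config=config)

# ... wrap `model` with FSDP2 and build the optimizer here ...

# --float8.precompute_float8_dynamic_scale_for_fsdp: in the training loop,
# after each optimizer.step(), call
#     torchao.float8.precompute_float8_dynamic_scale_for_fsdp(model)
# to compute the float8 scales for all weights with a single all-reduce.
```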

On the parallelism side, we support float8 all-gather for FSDP (optional) and for TP (on by default for `Float8Linear`).
**Collaborator:** I think later we could expand more on parallelisms and have separate subsections for FSDP fp8 and TP fp8, to explain more details for advanced users if necessary.

**Contributor Author:** Sounds good. Will expand the parallelism part into sections.

For the scaling strategy, we currently support tensor-wise scaling with dynamic scales and are actively working on tensor-wise scaling with delayed scales. Row-wise scaling is under exploration.
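To make the scaling terminology concrete, below is a minimal, self-contained illustration of tensor-wise dynamic scaling: one scale per tensor, recomputed from the tensor's current absolute maximum at every cast (delayed scaling instead derives the scale from an amax history collected over previous iterations). This is for intuition only and is not torchao's implementation:

```python
# Illustration of tensor-wise dynamic scaling: one scale per tensor,
# recomputed from the current absolute maximum on every cast.
# For intuition only -- this is not torchao's implementation.
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max  # 448.0 for e4m3fn

def cast_tensorwise_dynamic(x: torch.Tensor):
    # Scale so the largest magnitude in the tensor maps to the float8 max.
    amax = x.abs().max().to(torch.float32).clamp(min=1e-12)
    scale = FP8_MAX / amax
    x_fp8 = (x.to(torch.float32) * scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    # The scale must be kept around to de-scale the float8 matmul output.
    return x_fp8, scale

x = torch.randn(1024, 1024)
x_fp8, scale = cast_tensorwise_dynamic(x)
print(x_fp8.dtype, float(scale))
```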