Commit 05db84d

add float8 to README (#509)
add float8 link in README so we can redirect people from the dev-discuss post to the torchtitan repo.

README looks like this after rendering:
<img width="518" alt="Screenshot 2024-08-06 at 5 42 10 PM" src="https://github.com/user-attachments/assets/50af99d7-93be-459a-89d7-8c08b8fb95d4">

float8.md looks like this:
<img width="563" alt="Screenshot 2024-08-06 at 5 04 17 PM" src="https://github.com/user-attachments/assets/06d30aad-4133-4cec-9037-cfcf155b45c4">

I tried the command locally and the traces look good:
<img width="726" alt="Screenshot 2024-08-06 at 5 00 00 PM" src="https://github.com/user-attachments/assets/bdfa3d7e-efe1-4009-92a1-0f5c310013fb">
1 parent: c9184b9 · commit: 05db84d

File tree

2 files changed: +23 −5 lines

- README.md
- docs/float8.md

README.md

Lines changed: 5 additions & 5 deletions
@@ -43,18 +43,18 @@ Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes
 6. Learning rate scheduler, meta init, Optional Fused RMSNorm
 7. All options easily configured via [toml files](train_configs/)
 8. [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine tuning
+9. [Float8 support](docs/float8.md)

 We report our [Performance](docs/performance.md) verified on 64 A100 GPUs


 ### Coming soon

 1. Async checkpointing
-2. Float8 support
-3. Context Parallel
-4. 3D Pipeline Parallel
-5. `torch.compile` support
-6. Scalable data loading solution
+2. Context Parallel
+3. 3D Pipeline Parallel
+4. `torch.compile` support
+5. Scalable data loading solution


 ## Installation

docs/float8.md

Lines changed: 18 additions & 0 deletions
This is a new file; its full contents follow.

## Enable Float8 Training on H100s

Please install the latest [TorchAO](https://github.com/pytorch/ao/tree/main/torchao/float8) to support the float8 dtype:
```
USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
```
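The heading above calls out H100s specifically. If you want to confirm the local GPU is Hopper-class before launching, a minimal optional check in plain PyTorch looks like the following; the 9.0 compute-capability threshold is an assumption tied to the H100 wording, and this step is not required by torchtitan or TorchAO.

```python
import torch

# H100s report CUDA compute capability 9.0 (Hopper); the float8 recipe in this
# doc targets that class of hardware.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"GPU compute capability: {major}.{minor}")
    print(f"Hopper-class (>= 9.0): {(major, minor) >= (9, 0)}")
else:
    print("No CUDA device visible; this float8 recipe expects an H100-class GPU.")
```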

Launch the training job with the following command (or alternatively set the configs in toml files):
```
CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
```
* `--float8.enable_float8_linear`: swap `nn.Linear` with `Float8Linear` to perform float8 matmuls (see the sketch after this list).
* `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before the FSDP all-gather, so the communication happens in float8 and saves bandwidth.
* `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce over all parameters, instead of many small all-reduces, one per parameter.
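To make the first option concrete, here is a rough sketch of what swapping `nn.Linear` for `Float8Linear` means mechanically. `SketchFloat8Linear` and `swap_linears` below are illustrative stand-ins rather than the torchao.float8 API; a real float8 linear would also attach scaling state and run its matmul in float8 inside `forward`.

```python
import torch.nn as nn

class SketchFloat8Linear(nn.Linear):
    """Illustrative stand-in for a float8 linear layer (not the torchao class)."""

    @classmethod
    def from_float(cls, mod: nn.Linear) -> "SketchFloat8Linear":
        new_mod = cls(mod.in_features, mod.out_features, bias=mod.bias is not None)
        new_mod.weight = mod.weight              # reuse the original parameters
        if mod.bias is not None:
            new_mod.bias = mod.bias
        # A real Float8Linear would also set up scaling state and override
        # forward() so the matmul itself runs in float8.
        return new_mod

def swap_linears(module: nn.Module) -> None:
    """Recursively replace every nn.Linear submodule with the float8 variant."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and not isinstance(child, SketchFloat8Linear):
            setattr(module, name, SketchFloat8Linear.from_float(child))
        else:
            swap_linears(child)
```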

For parallelisms, we support float8 all-gather for FSDP (optional) and for TP (enabled by default for `Float8Linear`).

For the scaling strategy, we currently support tensor-wise scaling with dynamic scales, and are actively working on tensor-wise scaling with delayed scales. Row-wise scaling is under exploration.
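To make the scaling terminology concrete, below is a minimal sketch of tensor-wise scaling with dynamic scales in plain PyTorch. The names are illustrative rather than torchao APIs; 448 is the largest finite value of `torch.float8_e4m3fn`, and the real implementation differs in details such as epsilon and dtype handling.

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in torch.float8_e4m3fn

def cast_to_float8_dynamic(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Tensor-wise dynamic scaling: recompute the scale from the current amax."""
    amax = x.abs().max().clamp(min=1e-12)   # "dynamic": derived from this tensor, right now
    scale = FP8_E4M3_MAX / amax             # map the observed range onto the float8 range
    x_fp8 = (x * scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale                     # keep the scale to dequantize later

# Delayed scaling (in progress, per the note above) would instead derive the scale
# from a history of amax values recorded in previous iterations, avoiding a
# blocking reduction over the current tensor.
```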
