
Commit f86dd67

vkuzo authored and facebook-github-bot committed
Update README.md (#189)
Summary: Add more product context, link to feature tracker, fix some typos
Pull Request resolved: #189
Reviewed By: drisspg, malfet
Differential Revision: D52806367
Pulled By: vkuzo
fbshipit-source-id: f6d9b549ae697cf75cb00d0ee7e03989e3c4175c
1 parent d272138 commit f86dd67

File tree

1 file changed (+15, −8 lines)

README.md

Lines changed: 15 additions & 8 deletions
@@ -1,13 +1,20 @@
 # float8_experimental
 
-This is a prototype of a float8 training UX in native PyTorch, with full torch.compile and distributed support.
+This is an early version of a library for accelerating training with float8 in native PyTorch
+according to the recipes laid out in https://arxiv.org/pdf/2209.05433.pdf.
 The codebase strives to stay small, easily hackable, and debuggable with native PyTorch tooling.
+``torch.compile`` is supported out of the box. With ``torch.compile`` on, initial results show
+throughput speedups of up to 1.2x on small scale (8 GPUs) LLaMa pretraining jobs.
 
-Backwards compatibility is not guaranteed at this point. The codebase is in active development and
-will change rapidly.
+:warning: <em>See the [feature tracker](https://github.com/pytorch-labs/float8_experimental/issues/187) for upcoming features. Key features such as weight cast recomputation in backward and large scale distributed support are not ready yet. </em>
+
+:warning: <em>Backwards compatibility is not guaranteed at this point. The codebase is in active development and
+will change rapidly.</em>
 
 # installation
 
+:warning: <em>For now, use the latest PyTorch nightly for best results with torch.compile.</em>
+
 ```Shell
 pip install .
 
@@ -18,9 +25,9 @@ pip install -e .
 pip install -e ".[dev]"
 ```
 
-# User API, subject to change
+# User API
 
-We provide two scaling strategies: per-tensor dynamic and delayed.
+We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details.
 
 ## float8 linear with dynamic scaling
 
@@ -61,7 +68,7 @@ m = Model(...)
 swap_linear_with_float8_linear(m, Float8Linear)
 
 # optional: use FSDP. Note that workarounds gated with config.enable_amax_init and
-# config.enable_pre_and_post_forward are needed for autocast+compile+FSDP+float8 to work
+# config.enable_pre_and_post_forward are needed for autocast + compile + FSDP + float8 to work
 from float8_experimental import config
 config.enable_amax_init = False # only needed for autocast + compile + FSDP + float8 delayed
 config.enable_pre_and_post_forward = False # only needed for autocast + compile + FSDP + float8 delayed
@@ -103,7 +110,7 @@ pytest test/test_compile.py
 # run a two-GPU integration test on FSDP
 ./test/test_fsdp.sh
 
-# run integration tests for TP/SP
+# run integration tests for TP/SP (outdated)
 ./test/test_tp.sh
 
 # run all of these tests
@@ -116,7 +123,7 @@ pytest test/test_compile.py
 # benchmark the torch._scaled_mm function on LLaMa 2 70B shapes
 ./benchmarks/bench_matmul.py
 
-# benchmark fw/bw of `Linear`, `Float8Linear` on LLaMa 2 70B shapes
+# benchmark fw/bw of `Linear` and `Float8Linear` on LLaMa 2 70B shapes
 # make sure to turn on torch.compile to get the best performance
 ./benchmarks/bench_linear_float8.py -o ../tmp/test.txt --compile
 ```
