# float8_experimental

This is an early version of a library for accelerating training with float8 in native PyTorch
according to the recipes laid out in https://arxiv.org/pdf/2209.05433.pdf.
The codebase strives to stay small, easily hackable, and debuggable with native PyTorch tooling.
``torch.compile`` is supported out of the box. With ``torch.compile`` on, initial results show
throughput speedups of up to 1.2x on small scale (8 GPUs) LLaMa pretraining jobs.

:warning: <em>See the [feature tracker](https://github.com/pytorch-labs/float8_experimental/issues/187) for upcoming features. Key features such as weight cast recomputation in backward and large scale distributed support are not ready yet.</em>

:warning: <em>Backwards compatibility is not guaranteed at this point. The codebase is in active development and will change rapidly.</em>

# installation

:warning: <em>For now, use the latest PyTorch nightly for best results with torch.compile.</em>

```Shell
pip install .
pip install -e .
pip install -e ".[dev]"
```

# User API

We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details.

## float8 linear with dynamic scaling
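
A minimal sketch of the dynamic-scaling flow is shown below. The `Float8DynamicLinear` class name and the import paths are assumptions made for illustration; only `swap_linear_with_float8_linear` is confirmed by the snippets in this README.

```python
import torch
import torch.nn as nn

# NOTE: the import paths and the Float8DynamicLinear name are assumptions for illustration
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# toy model; float8 matmuls require an H100-class GPU
m = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096))
m = m.to(device="cuda", dtype=torch.bfloat16)

# convert all nn.Linear modules to float8 linears with per-tensor dynamic scaling
swap_linear_with_float8_linear(m, Float8DynamicLinear)

# torch.compile is supported out of the box
m = torch.compile(m)

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
m(x).sum().backward()
```

With dynamic scaling, each float8 cast derives its scale from the tensor being cast, so no extra bookkeeping is needed in the training loop.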

## float8 linear with delayed scaling

```python
m = Model(...)
swap_linear_with_float8_linear(m, Float8Linear)

# optional: use FSDP. Note that workarounds gated with config.enable_amax_init and
# config.enable_pre_and_post_forward are needed for autocast + compile + FSDP + float8 to work
from float8_experimental import config
config.enable_amax_init = False  # only needed for autocast + compile + FSDP + float8 delayed
config.enable_pre_and_post_forward = False  # only needed for autocast + compile + FSDP + float8 delayed
```
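
Putting the delayed-scaling pieces above into a training loop with FSDP, a sketch might look like the following. The `sync_float8_amax_and_scale_history` helper, both library import paths, and the `use_orig_params=True` choice are assumptions made for illustration; only `swap_linear_with_float8_linear`, `Float8Linear`, and the two config flags are confirmed above.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from float8_experimental import config
# NOTE: these import paths and the amax/scale sync helper are assumptions for illustration
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

# launch with torchrun so each process owns one GPU
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# workarounds needed for autocast + compile + FSDP + float8 delayed, as noted above
config.enable_amax_init = False
config.enable_pre_and_post_forward = False

m = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096)).cuda()
swap_linear_with_float8_linear(m, Float8Linear)
m = FSDP(m, use_orig_params=True)  # use_orig_params is an assumption, not prescribed here
m = torch.compile(m)
opt = torch.optim.AdamW(m.parameters(), lr=1e-4)

for _ in range(10):
    x = torch.randn(16, 4096, device="cuda")
    with torch.autocast("cuda", dtype=torch.bfloat16):
        loss = m(x).float().pow(2).mean()
    loss.backward()
    # delayed scaling tracks amax history; sync amaxes and scales across ranks before stepping
    sync_float8_amax_and_scale_history(m)
    opt.step()
    opt.zero_grad()
```

Delayed scaling reuses scales derived from previous iterations' amax history, which is why the loop synchronizes that history explicitly on every step.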

# testing

```Shell
# run a two-GPU integration test on FSDP
./test/test_fsdp.sh

# run integration tests for TP/SP (outdated)
./test/test_tp.sh

# run all of these tests
```

# benchmarking

```Shell
# benchmark the torch._scaled_mm function on LLaMa 2 70B shapes
./benchmarks/bench_matmul.py

# benchmark fw/bw of `Linear` and `Float8Linear` on LLaMa 2 70B shapes
# make sure to turn on torch.compile to get the best performance
```