Commit 34fedff

Add 4-bit Adam (#478)

* add 4bit
* rename
* simplify 4bit
* add rank1 scaling
* add lpmm to benchmark
* remove rank-1 scaling
* update
* clean
* rename
* update test
* fix
* fix
* update adam
* add AdamW 4bit
* update
* remove lpmm from dev cuz CI can't compile
* fix test
* update README
* Update README.md
* update readme. small fixes
* remove zero padding

1 parent 9f85488 commit 34fedff

File tree

12 files changed: +492 −198 lines changed

benchmarks/benchmark_adam_8bit.py renamed to benchmarks/benchmark_low_bit_adam.py

Lines changed: 22 additions & 13 deletions

```diff
@@ -1,22 +1,25 @@
-# pip install timm wandb tqdm datasets
+# pip install timm wandb tqdm datasets yacs bitsandbytes git+https://github.com/thu-ml/low-bit-optimizers.git
 # To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default Adam optimizer from PyTorch core
-#
-# python benchmarks_adam_8bit.py \
+#
+# python benchmark_low_bit_adam.py \
 # --model "timm/vit_base_patch16_224.augreg_in21k" \
 # --amp bf16 \
 # --optim Adam
-#
-# To use bnb 8-bit optimizer, set --optim Adam8bitBnb. To use 8-bit optimizer implemented in torchao, set --optim Adam8bitAo
+#
+# See OPTIM_MAP for the available optimizer options
 # To profile and export chrome trace, set --profile
 # To enable cosine learning rate scheduler, set --cosine_lr_scheduler
 
 import argparse
+import datetime
 import math
 from contextlib import nullcontext
+from functools import partial
 from pathlib import Path
 
 import bitsandbytes as bnb
 import datasets
+import lpmm
 import timm
 import torch
 import torch.nn.functional as F
@@ -25,7 +28,16 @@
 from torchvision.transforms import v2
 from tqdm import tqdm
 
-from torchao.prototype.optim_8bit import Adam8bit
+from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit
+
+# lpmm doesn't have Adam, only AdamW
+OPTIM_MAP = dict(
+    Adam=torch.optim.Adam,
+    Adam8bitBnb=bnb.optim.Adam8bit,
+    Adam8bitAo=Adam8bit,
+    Adam4bitLpmm=partial(lpmm.optim.AdamW, weight_decay=0, fused=True),
+    Adam4bitAo=Adam4bit,
+)
 
 
 class CosineSchedule:
@@ -72,7 +84,7 @@ def get_parser():
     parser.add_argument("--batch_size", type=int, default=64)
     parser.add_argument("--n_workers", type=int, default=4)
 
-    parser.add_argument("--optim", default="Adam")
+    parser.add_argument("--optim", default="Adam", choices=OPTIM_MAP.keys())
     parser.add_argument("--lr", type=float, default=1e-4)
     parser.add_argument("--weight_decay", type=float, default=0)
     parser.add_argument("--cosine_lr_scheduler", action="store_true")
@@ -159,16 +171,12 @@ def evaluate_model(model, args):
     model.compile(fullgraph=True)
     print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
 
-    OPTIM_MAP = dict(
-        Adam=torch.optim.Adam,
-        Adam8bitBnb=bnb.optim.Adam8bit,
-        Adam8bitAo=Adam8bit,
-    )
     optim = OPTIM_MAP[args.optim](model.parameters(), args.lr, weight_decay=args.weight_decay)
     lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)
 
     grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")
 
+    start_time = datetime.datetime.now()
     step = 0
     for epoch_idx in range(args.n_epochs):
         model.train()
@@ -208,4 +216,5 @@ def evaluate_model(model, args):
         print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
         logger.log(dict(val_acc=val_acc), step=step)
 
-    print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / (1 << 30):.2f} GB")
+    print(f"Time taken: {(datetime.datetime.now() - start_time)}")
+    print(f"Max used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
```

test/prototype/test_low_bit_optim.py

Lines changed: 148 additions & 0 deletions (new file)

```python
import copy
from functools import partial

import pytest
import torch
from torch import nn
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torchao.prototype import low_bit_optim
from torchao.prototype.low_bit_optim import subclass_8bit, subclass_4bit
from torchao.utils import TORCH_VERSION_AFTER_2_3

try:
    import bitsandbytes as bnb
except ImportError:
    bnb = None

try:
    import lpmm
except ImportError:
    lpmm = None


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


class TestQuantize(TestCase):
    @parametrize("device", _DEVICES)
    def test_quantize_8bit_with_qmap_correctness(self, device):
        x = torch.randn(32, 1024, device=device)
        qmap = torch.tensor(subclass_8bit.QMAP_SIGNED, device=device)

        actual_codes, actual_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256, implementation=1)
        expected_codes, expected_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256, implementation=0)

        torch.testing.assert_close(actual_codes, expected_codes)
        torch.testing.assert_close(actual_scale, expected_scale)

    @parametrize("device", _DEVICES)
    def test_quantize_8bit_with_qmap_compile(self, device):
        x = torch.randn(32, 1024, device=device)
        qmap = torch.tensor(subclass_8bit.QMAP_SIGNED, device=device)

        compiled_f = torch.compile(subclass_8bit.quantize_8bit_with_qmap, fullgraph=True)
        actual_codes, actual_scale = compiled_f(x, qmap, 256)
        expected_codes, expected_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256)

        torch.testing.assert_close(actual_codes, expected_codes)
        torch.testing.assert_close(actual_scale, expected_scale)

    @parametrize("device", _DEVICES)
    def test_quantize_4bit_with_qmap_correctness(self, device):
        x = torch.randn(32, 1024, device=device)
        qmap = torch.tensor(subclass_4bit.QMAP_SIGNED, device=device)

        actual_codes, actual_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256, implementation=1)
        expected_codes, expected_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256, implementation=0)

        torch.testing.assert_close(actual_codes, expected_codes)
        torch.testing.assert_close(actual_scale, expected_scale)

    @parametrize("device", _DEVICES)
    def test_quantize_4bit_with_qmap_compile(self, device):
        x = torch.randn(32, 1024, device=device)
        qmap = torch.tensor(subclass_4bit.QMAP_SIGNED, device=device)

        compiled_f = torch.compile(subclass_4bit.quantize_4bit_with_qmap, fullgraph=True)
        actual_codes, actual_scale = compiled_f(x, qmap, 256)
        expected_codes, expected_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256)

        torch.testing.assert_close(actual_codes, expected_codes)
        torch.testing.assert_close(actual_scale, expected_scale)


class TestOptim(TestCase):
    @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
    @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
    @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
    def test_optim_8bit_correctness(self, optim_name):
        device = "cuda"
        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
        model2 = copy.deepcopy(model1)

        optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())

        for _ in range(2):
            x = torch.randn(4, 32, device=device)

            loss1 = model1(x).sum()
            loss1.backward()
            optim1.step()
            optim1.zero_grad()

            loss2 = model2(x).sum()
            loss2.backward()
            optim2.step()
            optim2.zero_grad()

            for p1, p2 in zip(model1.parameters(), model2.parameters()):
                torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

    @pytest.mark.skipif(lpmm is None, reason="lpmm is not available")
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA")
    @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
    @parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
    def test_optim_4bit_correctness(self, optim_name):
        device = "cuda"
        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
        model2 = copy.deepcopy(model1)

        # lpmm doesn't have Adam. Use AdamW with no weight decay instead.
        if optim_name == "Adam4bit":
            optim1 = lpmm.optim.AdamW(model1.parameters(), weight_decay=0)
        elif optim_name == "AdamW4bit":
            optim1 = lpmm.optim.AdamW(model1.parameters())
        else:
            raise ValueError(f"Unsupported {optim_name} optimizer for lpmm")
        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())

        for _ in range(2):
            x = torch.randn(4, 32, device=device)

            loss1 = model1(x).sum()
            loss1.backward()
            optim1.step()
            optim1.zero_grad()

            loss2 = model2(x).sum()
            loss2.backward()
            optim2.step()
            optim2.zero_grad()

            for p1, p2 in zip(model1.parameters(), model2.parameters()):
                torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)


instantiate_parametrized_tests(TestQuantize)
instantiate_parametrized_tests(TestOptim)


if __name__ == "__main__":
    run_tests()
```
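For readers unfamiliar with qmap-based quantization, the tests above compare two implementations of the same round-trip: normalize each block of values by its absmax, then snap every normalized value to the nearest entry of a small codebook (the qmap). The following is an illustrative reference only, not the torchao implementation; the helper name is hypothetical, and it assumes `x.numel()` is divisible by `block_size` and that `qmap` is 1-D and sorted ascending:

```python
import torch

def quantize_with_qmap_reference(x: torch.Tensor, qmap: torch.Tensor, block_size: int):
    # Split into blocks and compute one absmax scale per block.
    blocks = x.reshape(-1, block_size)
    scale = blocks.abs().amax(dim=-1).clip(min=torch.finfo(x.dtype).tiny)
    normalized = blocks / scale.view(-1, 1)

    # torch.bucketize gives an insertion index; the nearest codebook entry
    # is either that index or the one just below it.
    codes = torch.bucketize(normalized, qmap).clip(max=qmap.numel() - 1)
    lower = (codes - 1).clip(min=0)
    use_lower = (qmap[codes] - normalized) > (normalized - qmap[lower])
    codes = torch.where(use_lower, lower, codes)
    return codes.to(torch.uint8).view(x.shape), scale
```

Dequantization is then just the reverse lookup: `(qmap[codes.long()].reshape(-1, block_size) * scale.view(-1, 1)).view(x.shape)`.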

test/prototype/test_optim_8bit.py

Lines changed: 0 additions & 84 deletions
This file was deleted.

torchao/prototype/README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -10,7 +10,7 @@
 - `galore/kernels` - `triton` kernels that fuse various steps of the `GaLore` algorithm
 - `galore/docs` - implementation notes and discussion of issues faced in kernel design.
 - [`quant_llm`](quant_llm) - FP16 x FPx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112)
-- [`optim_8bit`](optim_8bit) - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).
+- [`low_bit_optim`](low_bit_optim) - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and 4-bit optimizers from [lpmm](https://github.com/thu-ml/low-bit-optimizers).
 
 #### Roadmap
```

torchao/prototype/low_bit_optim/README.md

Lines changed: 48 additions & 0 deletions (new file)

````markdown
# Low-bit optimizers

This folder implements:

- 8-bit optimizers as outlined in https://arxiv.org/abs/2110.02861
- 4-bit optimizers as outlined in https://arxiv.org/abs/2309.01507

The implementation is done fully in Python (using tensor subclasses) and relies on `torch.compile()` to generate efficient fused kernels.

## Usage

This is a drop-in replacement for `torch.optim.Adam`:

```python
from torchao.prototype.low_bit_optim import Adam8bit

model = ...
optim = Adam8bit(model.parameters())
```

To use 4-bit Adam, replace the above with `Adam4bit`. You can also change the quantization block size by passing `block_size=value` to the optimizer. By default, the block size is 2048 for 8-bit optimizers and 128 for 4-bit optimizers.

**Other optimizers**: AdamW is also available as `AdamW8bit` and `AdamW4bit`. Other optimizers can be added based on demand.

NOTE:
- The low-bit optimizers require PyTorch >= 2.3.
- For 4-bit optimizers, we don't implement rank-1 normalization for quantizing the 2nd moment as originally done in the paper.
- **Known issue**: when the learning rate is updated every step (e.g. with a cosine learning rate scheduler), training is slower. This is because we have to convert the learning rate to a CUDA tensor (which incurs an expensive memory-transfer cost), since `torch.compile()` treats a Python float as a constant and triggers a recompile whenever the value changes.

## Benchmarks

A benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on the [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py).

Results for fine-tuning ViT-H (630M params) with BF16 AMP, batch size 4, 1 epoch, on a 4070Ti SUPER:

Adam impl  | max memory (GB) | time taken | accuracy
-----------|-----------------|------------|---------
PyTorch    | 12.98           | 10m 08s    | 87.70
bnb 8-bit  | 8.31            | 8m 38s     | 86.22
ao 8-bit   | 8.32            | 10m 54s    | 86.67
lpmm 4-bit | 7.72            | 7m 48s     | 84.70
ao 4-bit   | 7.72            | 9m 17s     | 85.60

NOTE: time taken includes validation time, and compile time for torchao optimizers.

## Credits

Credits to Tim Dettmers for creating the wonderful [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library, and to the [lpmm](https://github.com/thu-ml/low-bit-optimizers) authors for their work on 4-bit optimizers.
````
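The "Known issue" in the README above comes down to `torch.compile()` specializing on Python scalars. A standalone illustration of the behavior it describes (not torchao code; the function here is a made-up example):

```python
import torch

@torch.compile
def update(x, lr):
    return x - lr * x

x = torch.randn(16)

update(x, 0.1)  # lr is a Python float: baked into the graph as a constant
update(x, 0.2)  # different constant -> triggers a recompile

lr = torch.tensor(0.1)  # a 0-dim tensor is a runtime input instead
update(x, lr)           # compiles once
lr.fill_(0.05)          # change the value in place...
update(x, lr)           # ...no recompile; on GPU the cost moves to copying
                        # the new value into the device tensor each step
```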
torchao/prototype/low_bit_optim/__init__.py

Lines changed: 2 additions & 0 deletions (new file)

```python
from .adam import Adam8bit, Adam4bit
from .adamw import AdamW8bit, AdamW4bit
```
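Taken together with the README above, swapping a full-precision optimizer for a low-bit one is a one-line change. A minimal sketch (assumes a CUDA device and PyTorch >= 2.3, per the notes above; model and hyperparameters are placeholders):

```python
import torch
from torch import nn
from torchao.prototype.low_bit_optim import AdamW8bit

model = nn.Linear(128, 128).cuda()
optim = AdamW8bit(model.parameters(), lr=3e-4)  # drop-in for torch.optim.AdamW

loss = model(torch.randn(8, 128, device="cuda")).sum()
loss.backward()
optim.step()
optim.zero_grad()
```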
