Commit f23e5a3

Fix for weights-only load
stack-info: PR: #1228, branch: drisspg/stack/19
1 parent 6fd77d5 commit f23e5a3

9 files changed: +218 -68 lines changed

ruff.toml

Lines changed: 3 additions & 1 deletion

@@ -11,6 +11,8 @@ include = [
     "torchao/quantization/linear_activation_weight_observer.py",
     "test/quantization/test_observer.py",
     "test/dtypes/test_affine_quantized_float.py",
-    "torchao/quantization/weight_tensor_linear_activation_quantization.py"
+    "torchao/quantization/weight_tensor_linear_activation_quantization.py",
+    "torchao/prototype/low_bit_optim/**.py",
+    "test/prototype/low_bit_optim/**.py",
 
 ]

test/prototype/test_low_bit_optim.py

Lines changed: 76 additions & 24 deletions

@@ -19,7 +19,11 @@
     quantize_4bit_with_qmap,
     _fp32_to_bf16_sr,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_3,
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_6,
+)
 
 try:
     import bitsandbytes as bnb
@@ -85,7 +89,9 @@ def test_bf16_stochastic_round(self, device, compile):
         x_rep = x.view(-1, 1).repeat(1, 100_000)
 
         if compile:
-            x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(x_rep)
+            x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(
+                x_rep
+            )
         else:
             x_rep_bf16 = _fp32_to_bf16_sr(x_rep)
 
@@ -96,8 +102,13 @@ def test_bf16_stochastic_round(self, device, compile):
 
 
 class TestOptim(TestCase):
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
-    @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"])
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
+    )
+    @parametrize(
+        "optim_name",
+        ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"],
+    )
     @parametrize("dtype", [torch.float32, torch.bfloat16])
     @parametrize("device", _DEVICES)
     def test_optim_smoke(self, optim_name, dtype, device):
@@ -141,19 +152,28 @@ def test_optim_smoke(self, optim_name, dtype, device):
             torch.testing.assert_close(p2, p1)
 
     @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(),
+        reason="bitsandbytes 8-bit Adam only works for CUDA",
+    )
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
+    )
     @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
     def test_optim_8bit_correctness(self, optim_name):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
         model2 = copy.deepcopy(model1)
 
         # https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0
         block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048
 
         optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
-        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
+        optim2 = getattr(low_bit_optim, optim_name)(
+            model2.parameters(), block_size=block_size
+        )
 
         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -173,12 +193,18 @@ def test_optim_8bit_correctness(self, optim_name):
 
     # this will not run in CI because we can't install lpmm
     @pytest.mark.skipif(lpmm is None, reason="lpmm is not available")
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA")
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA"
+    )
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
+    )
     @parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
     def test_optim_4bit_correctness(self, optim_name):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
         model2 = copy.deepcopy(model1)
 
         # lpmm doesn't have Adam. use AdamW with no weight decay instead.
@@ -206,17 +232,25 @@ def test_optim_4bit_correctness(self, optim_name):
         for p1, p2 in zip(model1.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
+    )
     @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
     def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
-        model1[0].requires_grad_(False) # make sure it can work in the presence of non-trainable params
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
+        model1[0].requires_grad_(
+            False
+        )  # make sure it can work in the presence of non-trainable params
         model2 = copy.deepcopy(model1)
 
         optim1 = torch.optim.AdamW(model1.parameters())
         optim2 = low_bit_optim.CPUOffloadOptimizer(
-            model2.parameters(), torch.optim.AdamW, offload_gradients=offload_grad,
+            model2.parameters(),
+            torch.optim.AdamW,
+            offload_gradients=offload_grad,
         )
 
         for _ in range(2):
@@ -234,11 +268,17 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
         for p1, p2 in zip(model1.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1)
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
+    )
     def test_optim_cpu_offload_save_load(self):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
-        optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
+        optim1 = low_bit_optim.CPUOffloadOptimizer(
+            model1.parameters(), torch.optim.AdamW
+        )
 
         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -253,7 +293,9 @@ def test_optim_cpu_offload_save_load(self):
 
         # resume training
         model2 = copy.deepcopy(model1)
-        optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
+        optim2 = low_bit_optim.CPUOffloadOptimizer(
+            model2.parameters(), torch.optim.AdamW
+        )
         optim2.load_state_dict(state_dict)
 
         for _ in range(2):
@@ -273,13 +315,17 @@ def test_optim_cpu_offload_save_load(self):
     def test_optim_bf16_stochastic_round_correctness(self):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         torch.manual_seed(2024)
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
         model2 = copy.deepcopy(model1).bfloat16()
 
         # small LR so that weight update is small
         # when bf16_stochastic_round=False, the test will fail after 1 iteration
         optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5)
-        optim2 = low_bit_optim._AdamW(model2.parameters(), lr=1e-5, bf16_stochastic_round=True)
+        optim2 = low_bit_optim._AdamW(
+            model2.parameters(), lr=1e-5, bf16_stochastic_round=True
+        )
 
         # overfit on this sample
         x = torch.randn(4, 32, device=device)
@@ -299,15 +345,19 @@ def test_optim_bf16_stochastic_round_correctness(self):
             optim2.step()
             optim2.zero_grad()
 
-            torch.testing.assert_close(loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}")
+            torch.testing.assert_close(
+                loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}"
+            )
 
 
 class TestFSDP2(FSDPTest):
     @property
     def world_size(self) -> int:
         return 2
 
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required.")
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required."
+    )
     @skip_if_lt_x_gpu(2)
     def test_fsdp2(self):
        optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
@@ -363,7 +413,9 @@ def _test_fsdp2(self, optim_cls):
             base_loss.backward()
             for param in base_model.parameters():
                 if param.grad is not None:
-                    torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG)
+                    torch.distributed.all_reduce(
+                        param.grad, op=torch.distributed.ReduceOp.AVG
+                    )
             base_optim.step()
             self.assertEqual(fsdp_loss, base_loss)
 
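
Note: the tests above all follow the same shape: build a tiny model, swap in a low-bit optimizer, and run a couple of steps. A minimal sketch of that pattern (illustrative only; the model size and step count are arbitrary, and PyTorch >= 2.3 is assumed per the skip markers above):

# Illustrative sketch of the test pattern above; not part of this commit.
import torch
import torch.nn as nn

from torchao.prototype import low_bit_optim

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
optim = low_bit_optim.AdamW8bit(model.parameters())  # optimizer state stored in 8 bits

for _ in range(2):
    x = torch.randn(4, 32, device=device)
    model(x).sum().backward()
    optim.step()
    optim.zero_grad()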

torchao/prototype/low_bit_optim/__init__.py

Lines changed: 11 additions & 0 deletions

@@ -1,2 +1,13 @@
 from .adam import Adam4bit, Adam8bit, AdamFp8, AdamW4bit, AdamW8bit, AdamWFp8, _AdamW
 from .cpu_offload import CPUOffloadOptimizer
+
+__all__ = [
+    "Adam4bit",
+    "Adam8bit",
+    "AdamFp8",
+    "AdamW4bit",
+    "AdamW8bit",
+    "AdamWFp8",
+    "_AdamW",
+    "CPUOffloadOptimizer",
+]
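
Note: with `__all__` defined, the prototype package's public names are explicit. A quick sanity check of the exported surface (illustrative only, not part of this commit):

# Illustrative import check; these are exactly the names now listed in __all__.
from torchao.prototype.low_bit_optim import (
    Adam4bit,
    Adam8bit,
    AdamFp8,
    AdamW4bit,
    AdamW8bit,
    AdamWFp8,
    CPUOffloadOptimizer,
    _AdamW,
)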

torchao/prototype/low_bit_optim/adam.py

Lines changed: 24 additions & 6 deletions

@@ -2,18 +2,28 @@
 
 import torch
 from torch import Tensor
-from torch.optim import Optimizer
 from torch.distributed._tensor import DTensor
+from torch.optim import Optimizer
 
-from .subclass_8bit import OptimState8bit
+from .quant_utils import _fp32_to_bf16_sr
 from .subclass_4bit import OptimState4bit
+from .subclass_8bit import OptimState8bit
 from .subclass_fp8 import OptimStateFp8
-from .quant_utils import _fp32_to_bf16_sr
 
 
 class _AdamBase(Optimizer):
     def __init__(
-        self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size, bf16_stochastic_round, is_adamw
+        self,
+        params,
+        lr,
+        betas,
+        eps,
+        weight_decay,
+        amsgrad,
+        *,
+        block_size,
+        bf16_stochastic_round,
+        is_adamw,
     ) -> None:
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -23,7 +33,13 @@ def __init__(
             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
         if not 0.0 <= betas[1] < 1.0:
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
-        defaults = dict(lr=torch.tensor(lr), betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
+        defaults = dict(
+            lr=torch.tensor(lr),
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsgrad=amsgrad,
+        )
         super().__init__(params, defaults)
         self.block_size = block_size
         self.bf16_stochastic_round = bf16_stochastic_round
@@ -45,7 +61,9 @@ def _new_buffer(self, p: Tensor, signed: bool):
         if p.numel() >= 4096 and p.numel() % self.block_size == 0:
             if isinstance(p, DTensor):
                 out = DTensor.from_local(
-                    local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size),
+                    local_tensor=self._subclass_zeros(
+                        p.to_local(), signed, self.block_size
+                    ),
                     device_mesh=p.device_mesh,
                     placements=p.placements,
                     run_check=False,
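
Note: the `_new_buffer` hunk above only creates quantized optimizer state for parameters that are large enough and block-aligned; DTensor parameters get the subclass wrapped back via `DTensor.from_local`. A small illustration of that size gate (the plain high-precision fallback for small or unaligned parameters is assumed; that branch is not shown in this diff):

# Illustrative only; mirrors the size gate visible in _new_buffer above.
import torch

def uses_quantized_state(p: torch.Tensor, block_size: int = 256) -> bool:
    # low-bit optimizer state is only allocated for large, block-aligned params
    return p.numel() >= 4096 and p.numel() % block_size == 0

print(uses_quantized_state(torch.empty(1024, 32)))  # True: 32768 elements, divisible by 256
print(uses_quantized_state(torch.empty(10)))        # False: assumed to fall back to regular state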

torchao/prototype/low_bit_optim/cpu_offload.py

Lines changed: 8 additions & 2 deletions

@@ -25,7 +25,11 @@ def __init__(
             kwargs: other keyword arguments to be passed to the base optimizer e.g. `lr`, `weight_decay`.
         """
         # default to fused CPU AdamW
-        if optimizer_class is torch.optim.AdamW and TORCH_VERSION_AT_LEAST_2_4 and "fused" not in kwargs:
+        if (
+            optimizer_class is torch.optim.AdamW
+            and TORCH_VERSION_AT_LEAST_2_4
+            and "fused" not in kwargs
+        ):
             kwargs.update(fused=True)
 
         param_groups = list(params)
@@ -77,7 +81,9 @@ def backward_hook(p_cuda):
             self.param_cuda2cpu_map[p_cuda] = p_cpu
 
             p_cuda.register_post_accumulate_grad_hook(backward_hook)
-            self.optim_dict[p_cuda] = optimizer_class([{"params": p_cpu, **param_group}], **kwargs)
+            self.optim_dict[p_cuda] = optimizer_class(
+                [{"params": p_cpu, **param_group}], **kwargs
+            )
 
     @torch.no_grad()
     def step(self, closure=None):
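
Note: per the hunk above, `CPUOffloadOptimizer` defaults to fused CPU AdamW on PyTorch >= 2.4 and builds one base optimizer per CUDA parameter, driven by a post-accumulate-grad hook. A usage sketch matching the tests in this commit (illustrative only; requires a CUDA device):

# Usage sketch matching the tests in this commit; not part of the diff.
import torch
import torch.nn as nn

from torchao.prototype import low_bit_optim

model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).cuda()
optim = low_bit_optim.CPUOffloadOptimizer(
    model.parameters(),
    torch.optim.AdamW,       # base optimizer; defaults to fused CPU AdamW on PyTorch >= 2.4
    offload_gradients=True,  # one of the configurations exercised by the tests above
)

for _ in range(2):
    x = torch.randn(4, 32, device="cuda")
    model(x).sum().backward()
    optim.step()
    optim.zero_grad()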

torchao/prototype/low_bit_optim/quant_utils.py

Lines changed: 8 additions & 5 deletions

@@ -122,14 +122,17 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
     # [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted as uint16
     #
     # we have to use int32 since most arithmetic ops are not implemented for uint32/int16/uint16
-    rand_16bit = torch.randint(0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32)
+    rand_16bit = torch.randint(
+        0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32
+    )
     x_f32_bits = x_f32.view(torch.int32)
-    x_fraction = x_f32_bits & 0xFFFF # lower 16 bits
-    x_bf16_towards_zero = x_f32_bits & 0xFFFF0000 # upper 16 bits
+    x_fraction = x_f32_bits & 0xFFFF  # lower 16 bits
+    x_bf16_towards_zero = x_f32_bits & 0xFFFF0000  # upper 16 bits
 
     x_f32_bits = torch.where(
-        rand_16bit < x_fraction, # this is True with the probability of p_fraction
-        x_bf16_towards_zero + 0x10000, # this might overflow, which will result in UB due to signed integer
+        rand_16bit < x_fraction,  # this is True with the probability of p_fraction
+        x_bf16_towards_zero
+        + 0x10000,  # this might overflow, which will result in UB due to signed integer
         x_bf16_towards_zero,
     )
     # alternative, slightly faster
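
Note: `_fp32_to_bf16_sr` keeps the upper 16 bits of the fp32 pattern and rounds away from zero with probability equal to the discarded 16-bit fraction over 2^16, so the average over many samples tracks the original fp32 value (this is what test_bf16_stochastic_round checks). A self-contained sketch of the same idea (illustrative only; assumes a contiguous float32 input and carries the same near-overflow caveat noted in the comment above):

# Illustrative sketch of the stochastic-rounding idea; not the library function.
import torch

def fp32_to_bf16_stochastic(x_f32: torch.Tensor) -> torch.Tensor:
    # one uniform 16-bit integer per element
    rand_16bit = torch.randint(
        0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32
    )
    bits = x_f32.view(torch.int32)
    fraction = bits & 0xFFFF     # the 16 bits that bf16 would discard
    truncated = bits - fraction  # bf16 bit pattern, rounded toward zero
    # round away from zero with probability fraction / 2^16, else truncate
    rounded = torch.where(rand_16bit < fraction, truncated + 0x10000, truncated)
    return rounded.view(torch.float32).bfloat16()

x = torch.full((100_000,), 1.0 + 1.0 / 3.0)
print(fp32_to_bf16_stochastic(x).float().mean())  # averages back to ~1.3333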
