     quantize_4bit_with_qmap,
     _fp32_to_bf16_sr,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_3,
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_6,
+)

 try:
     import bitsandbytes as bnb
@@ -85,7 +89,9 @@ def test_bf16_stochastic_round(self, device, compile):
         x_rep = x.view(-1, 1).repeat(1, 100_000)

         if compile:
-            x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(x_rep)
+            x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(
+                x_rep
+            )
         else:
             x_rep_bf16 = _fp32_to_bf16_sr(x_rep)

@@ -96,8 +102,13 @@ def test_bf16_stochastic_round(self, device, compile):


 class TestOptim(TestCase):
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
-    @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"])
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
+    )
+    @parametrize(
+        "optim_name",
+        ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"],
+    )
     @parametrize("dtype", [torch.float32, torch.bfloat16])
     @parametrize("device", _DEVICES)
     def test_optim_smoke(self, optim_name, dtype, device):
@@ -141,19 +152,28 @@ def test_optim_smoke(self, optim_name, dtype, device):
             torch.testing.assert_close(p2, p1)

     @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(),
+        reason="bitsandbytes 8-bit Adam only works for CUDA",
+    )
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
+    )
     @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
     def test_optim_8bit_correctness(self, optim_name):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
         model2 = copy.deepcopy(model1)

         # https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0
         block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048

         optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
-        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
+        optim2 = getattr(low_bit_optim, optim_name)(
+            model2.parameters(), block_size=block_size
+        )

         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -173,12 +193,18 @@ def test_optim_8bit_correctness(self, optim_name):

     # this will not run in CI because we can't install lpmm
     @pytest.mark.skipif(lpmm is None, reason="lpmm is not available")
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA")
-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA"
+    )
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3"
+    )
     @parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
     def test_optim_4bit_correctness(self, optim_name):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
         model2 = copy.deepcopy(model1)

         # lpmm doesn't have Adam. use AdamW with no weight decay instead.
@@ -206,17 +232,25 @@ def test_optim_4bit_correctness(self, optim_name):
         for p1, p2 in zip(model1.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
+    )
     @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
     def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
-        model1[0].requires_grad_(False)  # make sure it can work in the presence of non-trainable params
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
+        model1[0].requires_grad_(
+            False
+        )  # make sure it can work in the presence of non-trainable params
         model2 = copy.deepcopy(model1)

         optim1 = torch.optim.AdamW(model1.parameters())
         optim2 = low_bit_optim.CPUOffloadOptimizer(
-            model2.parameters(), torch.optim.AdamW, offload_gradients=offload_grad,
+            model2.parameters(),
+            torch.optim.AdamW,
+            offload_gradients=offload_grad,
         )

         for _ in range(2):
@@ -234,11 +268,17 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
         for p1, p2 in zip(model1.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1)

-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason="optim CPU offload requires CUDA"
+    )
     def test_optim_cpu_offload_save_load(self):
         device = "cuda"
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
-        optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
+        optim1 = low_bit_optim.CPUOffloadOptimizer(
+            model1.parameters(), torch.optim.AdamW
+        )

         for _ in range(2):
             x = torch.randn(4, 32, device=device)
@@ -253,7 +293,9 @@ def test_optim_cpu_offload_save_load(self):

         # resume training
         model2 = copy.deepcopy(model1)
-        optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
+        optim2 = low_bit_optim.CPUOffloadOptimizer(
+            model2.parameters(), torch.optim.AdamW
+        )
         optim2.load_state_dict(state_dict)

         for _ in range(2):
@@ -273,13 +315,17 @@ def test_optim_cpu_offload_save_load(self):
     def test_optim_bf16_stochastic_round_correctness(self):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         torch.manual_seed(2024)
-        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(
+            device
+        )
         model2 = copy.deepcopy(model1).bfloat16()

         # small LR so that weight update is small
         # when bf16_stochastic_round=False, the test will fail after 1 iteration
         optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5)
-        optim2 = low_bit_optim._AdamW(model2.parameters(), lr=1e-5, bf16_stochastic_round=True)
+        optim2 = low_bit_optim._AdamW(
+            model2.parameters(), lr=1e-5, bf16_stochastic_round=True
+        )

         # overfit on this sample
         x = torch.randn(4, 32, device=device)
@@ -299,15 +345,19 @@ def test_optim_bf16_stochastic_round_correctness(self):
             optim2.step()
             optim2.zero_grad()

-            torch.testing.assert_close(loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}")
+            torch.testing.assert_close(
+                loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}"
+            )


 class TestFSDP2(FSDPTest):
     @property
     def world_size(self) -> int:
         return 2

-    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required.")
+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required."
+    )
     @skip_if_lt_x_gpu(2)
     def test_fsdp2(self):
         optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
@@ -363,7 +413,9 @@ def _test_fsdp2(self, optim_cls):
             base_loss.backward()
             for param in base_model.parameters():
                 if param.grad is not None:
-                    torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG)
+                    torch.distributed.all_reduce(
+                        param.grad, op=torch.distributed.ReduceOp.AVG
+                    )
             base_optim.step()
             self.assertEqual(fsdp_loss, base_loss)
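For reference, the snippet below is a minimal usage sketch of the optimizer API these tests exercise. It is illustrative only and not part of the diff; it assumes low_bit_optim refers to torchao.prototype.low_bit_optim (as used throughout this test file), that torchao is installed, and that a CUDA device is available. The block_size=256 value mirrors what the 8-bit correctness test passes when comparing against bitsandbytes >= 0.44.0.

# Illustrative sketch, not part of the diff above. Assumes torchao's prototype
# low_bit_optim module and a CUDA device.
import torch
import torch.nn as nn
from torchao.prototype import low_bit_optim

# same toy model shape the tests use
model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).cuda()

# 8-bit AdamW with the block size the 8-bit correctness test uses for bnb >= 0.44.0
optim = low_bit_optim.AdamW8bit(model.parameters(), block_size=256)

for _ in range(2):
    x = torch.randn(4, 32, device="cuda")
    model(x).sum().backward()
    optim.step()
    optim.zero_grad()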