    fake_quantize_per_channel_group,
    fake_quantize_per_token,
)
-from torchao.quantization.utils import get_group_qparams_symmetric
+from torchao.quantization.utils import (
+    get_group_qparams_symmetric,
+    get_groupwise_affine_qparams,
+    groupwise_affine_dequantize_tensor_from_qparams,
+    groupwise_affine_quantize_tensor,
+    groupwise_affine_quantize_tensor_from_qparams,
+)
from torchao.utils import TORCH_VERSION_AFTER_2_4


# TODO: put this in a common test utils file
+_CUDA_IS_AVAILABLE = torch.cuda.is_available()
+
class Sub(torch.nn.Module):
    def __init__(self):
        super().__init__()
-        self.linear = torch.nn.Linear(32, 32, bias=False).to(torch.float)
+        self.linear = torch.nn.Linear(256, 256, bias=False).to(torch.float)

    def example_inputs(self):
-        return (torch.randn(1, 32).to(torch.float),)
+        return (torch.randn(1, 256).to(torch.float),)

    def forward(self, x):
        return self.linear(x)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
-        self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
+        self.linear1 = torch.nn.Linear(512, 256, bias=False).to(torch.float)
        self.sub = Sub()
-        self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
+        self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float)

    def example_inputs(self):
-        return (torch.randn(1, 64).to(torch.float),)
+        return (torch.randn(1, 512).to(torch.float),)

    def forward(self, x):
        x = self.linear1(x)
@@ -111,23 +119,46 @@ def test_fake_quantize_per_token(self):

    def _set_ptq_weight(
        self,
-        ptq_linear: "Int8DynActInt4WeightLinear",
-        fp32_weight: torch.Tensor,
-        group_size: int,
+        ptq_linear: torch.nn.Module,
+        qat_linear: torch.nn.Module,
    ):
        """
        Set the weight to the quantized version of the given fp32 weights,
        for making linear outputs comparable with QAT.
        """
+        from torchao.quantization.GPTQ import (
+            Int8DynActInt4WeightLinear,
+            WeightOnlyInt4Linear,
+        )
+        from torchao.quantization.prototype.qat import (
+            Int8DynActInt4WeightQATLinear,
+            Int4WeightOnlyQATLinear,
+        )
        n_bit = 4
        (qmin, qmax) = self._get_qmin_qmax(n_bit)
-        (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
-        q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
-            fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
-        )
-        ptq_linear.weight = q_weight
-        ptq_linear.scales = s
-        ptq_linear.zeros = zp
+        if isinstance(ptq_linear, Int8DynActInt4WeightLinear):
+            assert isinstance(qat_linear, Int8DynActInt4WeightQATLinear)
+            fp32_weight = qat_linear.weight
+            group_size = qat_linear.groupsize
+            (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
+            q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
+                fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
+            )
+            ptq_linear.weight = q_weight
+            ptq_linear.scales = s
+            ptq_linear.zeros = zp
+        elif isinstance(ptq_linear, WeightOnlyInt4Linear):
+            assert isinstance(qat_linear, Int4WeightOnlyQATLinear)
+            (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
+                qat_linear.weight, n_bit, qat_linear.groupsize,
+            )
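+            # Pack the int4 values into the layout expected by
+            # torch.ops.aten._weight_int4pack_mm; this branch assumes a CUDA
+            # device is available (the 4w tests below are CUDA-only).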
+            q_weight = torch.ops.aten._convert_weight_to_int4pack(
+                q_weight.to("cuda"), qat_linear.inner_k_tiles,
+            )
+            ptq_linear.weight = q_weight
+            ptq_linear.scales_and_zeros = scales_and_zeros
+        else:
+            raise ValueError("Unknown ptq_linear type: %s" % type(ptq_linear))

    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
    def test_qat_8da4w_linear(self):
@@ -144,7 +175,7 @@ def test_qat_8da4w_linear(self):
        )

        # Force the weights to be the same
-        self._set_ptq_weight(ptq_linear, qat_linear.weight, group_size)
+        self._set_ptq_weight(ptq_linear, qat_linear)

        # Compare linear values
        torch.manual_seed(self.SEED)
@@ -280,7 +311,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
        loss_fn1 = torch.nn.CrossEntropyLoss()
        loss_fn2 = torch.nn.CrossEntropyLoss()
        example_inputs = nn_model.example_inputs()
-        target = torch.randn(1, 64).float()
+        target = torch.randn(1, 512).float()
        output1 = nn_model(*example_inputs)
        output2 = qat_model(*example_inputs)
        torch.testing.assert_close(output1, output2, atol=0, rtol=0)
@@ -322,6 +353,123 @@ def test_qat_generic_fake_quantize(self):
        torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
        torch.testing.assert_close(py_input.grad, ao_input.grad, atol=0, rtol=0)

+    def _assert_close_4w(self, val, ref):
+        # Note: for int4 weight-only quantization, we do not expect exact match
+        # because torch._weight_int4pack_mm and torch.mm do not match exactly.
+        # Here we use the same error bar as PyTorch core to determine closeness:
+        # https://github.com/pytorch/pytorch/blob/6079c5091091d872b8dafbaa4e31a5b6194647ad/test/test_linalg.py#L6079
+        mean_err = ((val - ref) / ref).mean().abs()
+        print(mean_err)
+        self.assertTrue(mean_err < 0.05)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_qat_4w_primitives(self):
+        n_bit = 4
+        group_size = 32
+        inner_k_tiles = 8
+        scales_precision = torch.bfloat16
+        device = torch.device("cuda")
+        dtype = torch.bfloat16
+        torch.manual_seed(self.SEED)
+        x = torch.randn(100, 256, dtype=dtype, device=device)
+        weight = torch.randn(512, 256, dtype=dtype, device=device)
+
+        # PTQ
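+        # Reference path: quantize the weight groupwise, pack it, and call the
+        # packed int4 matmul kernel directly on the packed weight.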
+        (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
+            weight, n_bit, group_size, scales_precision,
+        )
+        q_weight = torch.ops.aten._convert_weight_to_int4pack(
+            q_weight.to(device), inner_k_tiles,
+        )
+        ptq_out = torch.ops.aten._weight_int4pack_mm(
+            x, q_weight, group_size, scales_and_zeros
+        )
+
+        # QAT
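+        # Simulated path: fake-quantize the same weight (quantize, then
+        # immediately dequantize with the same qparams) and run an ordinary
+        # floating-point linear, mirroring what QAT does during training.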
+        scales, zero_points = get_groupwise_affine_qparams(
+            weight, n_bit, group_size, scales_precision,
+        )
+        w_q = groupwise_affine_quantize_tensor_from_qparams(
+            weight, scales, zero_points, n_bit, group_size, cast_dtypes=False,
+        )
+        w_dq = groupwise_affine_dequantize_tensor_from_qparams(
+            w_q, scales, zero_points, n_bit, group_size, cast_dtypes=False,
+        )
+        qat_out = torch.nn.functional.linear(x, w_dq)
+
+        self._assert_close_4w(qat_out, ptq_out)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_qat_4w_linear(self):
+        from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear
+        from torchao.quantization.GPTQ import WeightOnlyInt4Linear
+
+        group_size = 128
+        device = torch.device("cuda")
+        dtype = torch.bfloat16
+        torch.manual_seed(self.SEED)
+        qat_linear = Int4WeightOnlyQATLinear(
+            256, 688, bias=False, groupsize=group_size, device=device,
+        )
+        ptq_linear = WeightOnlyInt4Linear(
+            256, 688, bias=False, groupsize=group_size, device=device,
+        )
+
+        # Force the weights to be the same
+        self._set_ptq_weight(ptq_linear, qat_linear)
+
+        # Compare linear values
+        torch.manual_seed(self.SEED)
+        x = torch.randn(100, 256, dtype=dtype, device=device)
+        x2 = copy.deepcopy(x)
+        qat_out = qat_linear(x)
+        ptq_out = ptq_linear(x2)
+        self._assert_close_4w(qat_out, ptq_out)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_qat_4w_quantizer(self):
+        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
+
+        group_size = 32
+        inner_k_tiles = 8
+        device = torch.device("cuda")
+        dtype = torch.bfloat16
+        torch.manual_seed(self.SEED)
+        m = M().to(device).to(dtype)
+        m2 = copy.deepcopy(m)
+        qat_quantizer = Int4WeightOnlyQATQuantizer(
+            groupsize=group_size, inner_k_tiles=inner_k_tiles,
+        )
+        ptq_quantizer = Int4WeightOnlyQuantizer(
+            groupsize=group_size, inner_k_tiles=inner_k_tiles,
+        )
+        qat_model = qat_quantizer.prepare(m)
+        ptq_model = ptq_quantizer.quantize(m2)
+
+        # Compare model values
+        torch.manual_seed(self.SEED)
+        x = [i.to(device).to(dtype) for i in m.example_inputs()]
+        x2 = copy.deepcopy(x)
+        qat_out = qat_model(*x)
+        ptq_out = ptq_model(*x2)
+        self._assert_close_4w(qat_out, ptq_out)
+
+        # Convert QAT model and compare model values
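+        # After convert, the fake-quantized QAT linears should be replaced with
+        # real int4 weight-only linears, so the converted model is expected to
+        # match the PTQ model exactly (hence atol=0, rtol=0 below).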
+        converted_model = qat_quantizer.convert(qat_model)
+        converted_out = converted_model(*x)
+        torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)
+
+        # Compare converted state dict
+        ptq_state_dict = ptq_model.state_dict()
+        converted_state_dict = converted_model.state_dict()
+        self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
+        for k in ptq_state_dict.keys():
+            torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
+

if __name__ == "__main__":
    unittest.main()