import unittest
import torch
import os
-from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
    prepare_pt2e,
    convert_pt2e,
def dynamic_quant(model, example_inputs):
-    m = capture_pre_autograd_graph(model, example_inputs)
+    m = torch.export.export(model, example_inputs).module()
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
    m = prepare_pt2e(m, quantizer)
    m = convert_pt2e(m)
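For readers unfamiliar with the PT2E flow being migrated here: `capture_pre_autograd_graph` is deprecated, and `torch.export.export(...).module()` yields the equivalent aten-level `GraphModule` that `prepare_pt2e` expects. Below is a minimal, self-contained sketch of the same dynamic-quantization flow on a standalone linear layer; it is illustrative only, and assumes torch >= 2.4 with the `XNNPACKQuantizer` import path shown (the file under diff presumably imports it elsewhere):

import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

linear = torch.nn.Linear(64, 32, bias=False).eval()
example_inputs = (torch.randn(1, 64),)  # example inputs must be a tuple

m = torch.export.export(linear, example_inputs).module()  # aten-level GraphModule
quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_dynamic=True)
)
m = prepare_pt2e(m, quantizer)  # insert observers
m(*example_inputs)              # run once so observers see real data
m = convert_pt2e(m)             # rewrite to quantize/dequantize ops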
@@ -50,14 +49,14 @@ def _apply_dynamic_quant(model):
    """
    _replace_with_custom_fn_if_matches_filter(
        model,
-        lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features))),
+        lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features),)),
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )
    return model


def capture_and_prepare(model, example_inputs):
-    m = capture_pre_autograd_graph(model, example_inputs)
+    m = torch.export.export(model, example_inputs)
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
    m = prepare_pt2e(m, quantizer)
    # TODO: we can run the weight observer in convert_pt2e so that users don't need to run this
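One thing worth flagging in this hunk (an editorial observation, not part of the commit): unlike `dynamic_quant` above, the `ExportedProgram` returned by `torch.export.export` is passed to `prepare_pt2e` without unlifting, while `prepare_pt2e` generally expects a `torch.fx.GraphModule`. If this path is exercised, the call would presumably need the same treatment:

# Assumption, mirroring the dynamic_quant change above:
m = torch.export.export(model, example_inputs).module()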
@@ -88,13 +87,13 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
        return model

class ToyLinearModel(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, m=64, n=32, k=64):
        super().__init__()
-        self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
-        self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
+        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
+        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)

    def example_inputs(self):
-        return (torch.randn(1, 64).to(torch.float),)
+        return (torch.randn(1, self.linear1.in_features).to(torch.float),)

    def forward(self, x):
        x = self.linear1(x)
@@ -104,8 +103,9 @@ def forward(self, x):
class TestQuantFlow(unittest.TestCase):
    def test_dynamic_quant_gpu_singleline(self):
        m = ToyLinearModel().eval()
+        example_inputs = m.example_inputs()
        m = _apply_dynamic_quant(m)
-        quantized = m(*m.example_inputs())
+        quantized = m(*example_inputs)
        # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
        # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
        # m = torch.compile(m, mode="max-autotune")
@@ -442,7 +442,94 @@ def get_per_token_block_size(x):
        ref = m_copy(*example_inputs)
        self.assertTrue(torch.equal(res, ref))

+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quantized_tensor_subclass_int4(self):
+        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+        from torchao.quantization.quant_primitives import ZeroPointDomain
+        import copy
+
+        # weight settings
+        groupsize = 32
+        mapping_type = MappingType.ASYMMETRIC
+        block_size = (1, groupsize)
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        eps = 1e-6
+        preserve_zero = False
+        zero_point_dtype = torch.bfloat16
+
+        # weight only quantization
+        input_quant_func = None
+
+        # use 1024 so that we don't need padding
+        m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
+        m_copy = copy.deepcopy(m)
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
+
+        def to_quantized(weight):
+            return AffineQuantizedTensor.from_float(
+                weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
+                zero_point_dtype=zero_point_dtype,
+                preserve_zero=preserve_zero,
+                zero_point_domain=ZeroPointDomain.FLOAT,
+                input_quant_func=input_quant_func,
+            )
+
+        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
+        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
+        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
+        change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)
+
+        self.assertTrue(torch.equal(res, ref))
+
+
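To make the int4 settings above concrete: `quant_min=0, quant_max=15` is the uint4 code range (carried in an int32 container), and `block_size=(1, 32)` gives each group of 32 consecutive weights within a row its own scale and float-domain zero point. Below is a simplified sketch of that per-group asymmetric affine mapping; it is illustrative only, and torchao's actual `ZeroPointDomain.FLOAT` (tinygemm) convention differs in details such as how the zero point is centered:

import torch

def fake_quantize_group(w_group, quant_min=0, quant_max=15, eps=1e-6):
    # One (scale, zero_point) per group; the zero point lives in the float domain.
    scale = (w_group.max() - w_group.min()).clamp(min=eps) / (quant_max - quant_min)
    zero_point = w_group.min()
    q = torch.clamp(torch.round((w_group - zero_point) / scale), quant_min, quant_max)
    return q.to(torch.int32), q * scale + zero_point  # int codes, dequantized values

codes, deq = fake_quantize_group(torch.randn(32))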
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quantized_tensor_subclass_int8(self):
+        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+        import copy
+
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+
+        # weight only quantization
+        input_quant_func = None
+
+        m = ToyLinearModel().eval().to(torch.bfloat16)
+        m_copy = copy.deepcopy(m)
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
+
+        def to_quantized(weight):
+            block_size = (1, weight.shape[1])
+            return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, input_quant_func=input_quant_func)
+
+        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
+        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
+        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
+        change_linear_weights_to_int8_woqtensors(m_copy)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)

+        torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)
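Two details in this hunk worth noting (editorial observations): `block_size = (1, weight.shape[1])` makes the quantization per output channel, one symmetric scale per weight row; and the comparison uses `torch.testing.assert_close` with a tolerance rather than `torch.equal`, since the subclass path and the `change_linear_weights_to_int8_woqtensors` reference need only agree numerically, not bit-exactly. A minimal sketch of that per-channel symmetric int8 mapping (illustrative only, not torchao's implementation):

import torch

def quantize_per_channel_symmetric(w, eps=torch.finfo(torch.float32).eps):
    # One scale per output row; symmetric, so the zero point is always 0.
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=eps) / 127.0
    q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return q, scale  # dequantize with q.float() * scale

q, scale = quantize_per_channel_symmetric(torch.randn(32, 64))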
if __name__ == "__main__" :