@@ -9,7 +9,6 @@
 import unittest
 import torch
 import os
-from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import (
     prepare_pt2e,
     convert_pt2e,
@@ -36,7 +35,7 @@
 
 
 def dynamic_quant(model, example_inputs):
-    m = capture_pre_autograd_graph(model, example_inputs)
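+    # torch.export.export returns an ExportedProgram; .module() recovers the
+    # GraphModule that prepare_pt2e expects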
+    m = torch.export.export(model, example_inputs).module()
     quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
     m = prepare_pt2e(m, quantizer)
     m = convert_pt2e(m)
@@ -50,14 +49,14 @@ def _apply_dynamic_quant(model):
     """
     _replace_with_custom_fn_if_matches_filter(
         model,
-        lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features))),
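+        # note the trailing comma: dynamic_quant expects example_inputs as a tuple,
+        # and (tensor) without it is just a parenthesized tensor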
+        lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features),)),
         lambda mod, fqn: isinstance(mod, torch.nn.Linear),
     )
     return model
 
 
 def capture_and_prepare(model, example_inputs):
-    m = capture_pre_autograd_graph(model, example_inputs)
+    m = torch.export.export(model, example_inputs)
     quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
     m = prepare_pt2e(m, quantizer)
     # TODO: we can run the weight observer in convert_pt2e so that users don't need to run this
@@ -88,13 +87,13 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
         return model
 
 class ToyLinearModel(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, m=64, n=32, k=64):
         super().__init__()
-        self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
-        self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
+        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
+        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
 
     def example_inputs(self):
-        return (torch.randn(1, 64).to(torch.float),)
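+        # deriving the size from linear1 keeps example_inputs in sync with the
+        # m/n/k constructor arguments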
+        return (torch.randn(1, self.linear1.in_features).to(torch.float),)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -104,8 +103,9 @@ def forward(self, x):
 class TestQuantFlow(unittest.TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = ToyLinearModel().eval()
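+        # capture example inputs before quantization: afterwards linear1 is swapped
+        # for a quantized module that may no longer expose in_features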
+        example_inputs = m.example_inputs()
         m = _apply_dynamic_quant(m)
-        quantized = m(*m.example_inputs())
+        quantized = m(*example_inputs)
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
         # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
         # m = torch.compile(m, mode="max-autotune")
@@ -442,7 +442,92 @@ def get_per_token_block_size(x):
         ref = m_copy(*example_inputs)
         self.assertTrue(torch.equal(res, ref))
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    def test_quantized_tensor_subclass_int4(self):
+        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+        from torchao.quantization.quant_primitives import ZeroPointDomain
+        import copy
+
+        # weight settings
+        groupsize = 32
+        mapping_type = MappingType.ASYMMETRIC
+        block_size = (1, groupsize)
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        eps = 1e-6
+        preserve_zero = False
+        zero_point_dtype = torch.bfloat16
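+        # quant_min=0, quant_max=15 give the uint4 value range (stored in an int32
+        # container); preserve_zero=False with a bfloat16 zero point follows the
+        # tinygemm-style int4 convention (ZeroPointDomain.FLOAT below)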
+
+        # weight only quantization
+        input_quant_func = None
+
+        # use 1024 so that we don't need padding
+        m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
+        m_copy = copy.deepcopy(m)
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
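+        # the int4 weight-only path targets CUDA with bfloat16 activations, hence the casts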
+
+        def to_quantized(weight):
+            return AffineQuantizedTensor.from_float(
+                weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
+                zero_point_dtype=zero_point_dtype,
+                preserve_zero=preserve_zero,
+                zero_point_domain=ZeroPointDomain.FLOAT,
+                input_quant_func=input_quant_func,
+            )
+
+        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
+        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
+        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
+        change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)
+
+        self.assertTrue(torch.equal(res, ref))
+
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    def test_quantized_tensor_subclass_int8(self):
+        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+        import copy
+
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+
+        # weight only quantization
+        input_quant_func = None
+
+        m = ToyLinearModel().eval().to(torch.bfloat16)
+        m_copy = copy.deepcopy(m)
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
+
+        def to_quantized(weight):
+            block_size = (1, weight.shape[1])
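+            # one block spanning each weight row, i.e. one scale per output channel
+            # (per-channel symmetric quantization)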
+            return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, input_quant_func=input_quant_func)
+
+        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
+        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
+        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
+        change_linear_weights_to_int8_woqtensors(m_copy)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)
 
+        torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)
 
 
 if __name__ == "__main__":