 import pytest
 import torch
 
-from neural_compressor.torch.quantization import MXQuantConfig, get_default_mx_config, quantize
+from neural_compressor.torch.quantization import MXQuantConfig, convert, get_default_mx_config, prepare
 
 
 def build_simple_torch_model():
@@ -40,20 +40,35 @@ def teardown_class(self):
     def test_mx_quant_default(self):
         fp32_model = copy.deepcopy(self.fp32_model)
         quant_config = get_default_mx_config()
-        q_model = quantize(fp32_model, quant_config=quant_config)
+        fp32_model = prepare(model=fp32_model, quant_config=quant_config)
+        q_model = convert(model=fp32_model)
         assert q_model is not None, "Quantization failed!"
 
     @pytest.mark.parametrize(
-        "w_dtype, weight_only",
+        "w_dtype, weight_only, round_method, out_dtype",
         [
-            ("fp4", True),
-            ("fp8_e5m2", False),
+            ("fp4", True, "dither", "float32"),
+            ("fp8_e5m2", False, "floor", "bfloat16"),
+            ("int8", False, "even", "float16"),
+            ("int4", False, "nearest", "float32"),
+            ("int2", False, "dither", "bfloat16"),
+            ("fp8_e4m3", False, "floor", "float16"),
+            ("fp6_e3m2", False, "even", "float32"),
+            ("fp6_e2m3", False, "nearest", "bfloat16"),
+            ("float16", False, "dither", "float16"),
+            ("bfloat16", False, "floor", "float32"),
         ],
     )
-    def test_mx_quant_params(self, w_dtype, weight_only):
+    def test_mx_quant_params(self, w_dtype, weight_only, round_method, out_dtype):
         fp32_model = copy.deepcopy(self.fp32_model)
-        quant_config = MXQuantConfig(w_dtype=w_dtype, weight_only=weight_only)
-        q_model = quantize(fp32_model, quant_config=quant_config)
+        quant_config = MXQuantConfig(
+            w_dtype=w_dtype,
+            weight_only=weight_only,
+            round_method=round_method,
+            out_dtype=out_dtype,
+        )
+        fp32_model = prepare(model=fp32_model, quant_config=quant_config)
+        q_model = convert(model=fp32_model)
         assert q_model is not None, "Quantization failed!"
 
     def test_mx_quant_accuracy(self):
@@ -72,8 +87,10 @@ def forward(self, x):
         fp32_model = copy.deepcopy(model)
         fp32_model.linear.weight = torch.nn.Parameter(torch.tensor([[0.0, 1.0], [1.0, 0.0]]))
         example_inputs = torch.zeros(3, 2)
+
         quant_config = MXQuantConfig()
-        q_model = quantize(fp32_model, quant_config=quant_config)
+        fp32_model = prepare(model=fp32_model, quant_config=quant_config)
+        q_model = convert(model=fp32_model)
         output1 = fp32_model(example_inputs)
         output2 = q_model(example_inputs)
         # set a big atol to avoid random issue
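The diff above moves the tests from the one-shot `quantize()` entry point to the two-step `prepare()`/`convert()` flow. A minimal sketch of that flow, assuming only the `neural_compressor.torch.quantization` names imported in this diff (the toy model, shapes, and chosen dtypes are illustrative, not part of the test fixture):

```python
import torch

from neural_compressor.torch.quantization import MXQuantConfig, convert, prepare

# Toy fp32 model standing in for the test fixture (illustrative only).
model = torch.nn.Sequential(
    torch.nn.Linear(30, 50),
    torch.nn.ReLU(),
    torch.nn.Linear(50, 30),
)

# MX quantization config; w_dtype/weight_only/round_method/out_dtype mirror one
# of the parametrized cases exercised above.
quant_config = MXQuantConfig(
    w_dtype="fp4",
    weight_only=True,
    round_method="dither",
    out_dtype="float32",
)

# prepare() attaches the config to the model, convert() returns the quantized model.
model = prepare(model=model, quant_config=quant_config)
q_model = convert(model=model)

# The quantized model is called like the original one.
output = q_model(torch.randn(3, 30))
```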