
Commit 31d8bb9

mengniwang95 authored and Eran Geva committed
Add UT and remove unused code for torch MX quant (#1854)
* Add UT and remove unused code for torch MX quant

---------

Change-Id: I2727aa716fa99467fa2d63b966de4d88470e4bb3
Signed-off-by: Mengni Wang <[email protected]>
1 parent 647905a · commit 31d8bb9

File tree: 3 files changed (+48 / −12 lines)

neural_compressor/torch/algorithms/mx_quant/mx.py

Lines changed: 0 additions & 3 deletions
@@ -62,9 +62,6 @@ def apply_mx_specs(self):
             axes=[-1],
         )
 
-    def append_name(self, postfix):
-        self.name += postfix
-
     def forward(self, input):
         if self.mx_none:
             return super().forward(input)
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+import pytest
+import torch
+
+from neural_compressor.torch.algorithms.mx_quant import utils
+
+
+def test_mx_quant_utility():
+    tensor = torch.rand((1, 30))
+    assert torch.equal(tensor, utils.quantize_mx_op(tensor, None, "nearest", 32))
+    assert torch.equal(tensor, utils._quantize_fp(tensor))
+    assert torch.equal(tensor, utils._quantize_bfloat(tensor, 0))
+    assert torch.equal(tensor, utils._quantize_mx(tensor, 8, None))
+
+    assert not torch.equal(utils._shared_exponents(tensor, "none"), utils._shared_exponents(tensor))
+    with pytest.raises(Exception):
+        utils._shared_exponents(tensor, None)
+    with pytest.raises(Exception):
+        utils._reshape_to_blocks(tensor, None, 32)
+    with pytest.raises(Exception):
+        utils.quantize_elemwise_op(tensor, "test")
+    with pytest.raises(Exception):
+        utils._round_mantissa(tensor, 3, "test")
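
The new test pins down only the passthrough behavior (elem_format=None returns the input unchanged) and the error paths. As a rough sketch of a non-passthrough call, assuming quantize_mx_op keeps the positional signature the test uses (tensor, element format, rounding mode, block size) and accepts an element format string such as "fp4" (an illustrative value, not one taken from this commit):

import torch

from neural_compressor.torch.algorithms.mx_quant import utils

# Hypothetical sketch: quantize a tensor to an MX element format.
# The positional arguments mirror the call in the unit test above;
# "fp4" and the block size of 32 are illustrative assumptions.
tensor = torch.rand((1, 30))
mx_tensor = utils.quantize_mx_op(tensor, "fp4", "nearest", 32)

# With a real element format the values are snapped onto the MX grid,
# so the output generally differs from the input, unlike the
# elem_format=None passthrough exercised by the test.
print(torch.equal(tensor, mx_tensor))  # typically False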

test/3x/torch/quantization/test_mx_quant.py

Lines changed: 26 additions & 9 deletions
@@ -3,7 +3,7 @@
 import pytest
 import torch
 
-from neural_compressor.torch.quantization import MXQuantConfig, get_default_mx_config, quantize
+from neural_compressor.torch.quantization import MXQuantConfig, convert, get_default_mx_config, prepare
 
 
 def build_simple_torch_model():
@@ -40,20 +40,35 @@ def teardown_class(self):
     def test_mx_quant_default(self):
         fp32_model = copy.deepcopy(self.fp32_model)
         quant_config = get_default_mx_config()
-        q_model = quantize(fp32_model, quant_config=quant_config)
+        fp32_model = prepare(model=fp32_model, quant_config=quant_config)
+        q_model = convert(model=fp32_model)
         assert q_model is not None, "Quantization failed!"
 
     @pytest.mark.parametrize(
-        "w_dtype, weight_only",
+        "w_dtype, weight_only, round_method, out_dtype",
         [
-            ("fp4", True),
-            ("fp8_e5m2", False),
+            ("fp4", True, "dither", "float32"),
+            ("fp8_e5m2", False, "floor", "bfloat16"),
+            ("int8", False, "even", "float16"),
+            ("int4", False, "nearest", "float32"),
+            ("int2", False, "dither", "bfloat16"),
+            ("fp8_e4m3", False, "floor", "float16"),
+            ("fp6_e3m2", False, "even", "float32"),
+            ("fp6_e2m3", False, "nearest", "bfloat16"),
+            ("float16", False, "dither", "float16"),
+            ("bfloat16", False, "floor", "float32"),
         ],
     )
-    def test_mx_quant_params(self, w_dtype, weight_only):
+    def test_mx_quant_params(self, w_dtype, weight_only, round_method, out_dtype):
         fp32_model = copy.deepcopy(self.fp32_model)
-        quant_config = MXQuantConfig(w_dtype=w_dtype, weight_only=weight_only)
-        q_model = quantize(fp32_model, quant_config=quant_config)
+        quant_config = MXQuantConfig(
+            w_dtype=w_dtype,
+            weight_only=weight_only,
+            round_method=round_method,
+            out_dtype=out_dtype,
+        )
+        fp32_model = prepare(model=fp32_model, quant_config=quant_config)
+        q_model = convert(model=fp32_model)
         assert q_model is not None, "Quantization failed!"
 
     def test_mx_quant_accuracy(self):
@@ -72,8 +87,10 @@ def forward(self, x):
         fp32_model = copy.deepcopy(model)
         fp32_model.linear.weight = torch.nn.Parameter(torch.tensor([[0.0, 1.0], [1.0, 0.0]]))
         example_inputs = torch.zeros(3, 2)
+
         quant_config = MXQuantConfig()
-        q_model = quantize(fp32_model, quant_config=quant_config)
+        fp32_model = prepare(model=fp32_model, quant_config=quant_config)
+        q_model = convert(model=fp32_model)
         output1 = fp32_model(example_inputs)
         output2 = q_model(example_inputs)
         # set a big atol to avoid random issue
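
Taken together, the test changes migrate from the one-shot quantize() entry point to the two-step prepare()/convert() flow. A minimal end-to-end sketch of that flow, using an illustrative torch.nn model in place of the test fixture's build_simple_torch_model() (the model definition is an assumption; the API calls mirror the diff):

import torch

from neural_compressor.torch.quantization import MXQuantConfig, convert, prepare

# Illustrative stand-in for the test fixture's model; any torch.nn.Module works here.
model = torch.nn.Sequential(
    torch.nn.Linear(30, 50),
    torch.nn.ReLU(),
    torch.nn.Linear(50, 30),
)

# Parameter values mirror one row of the new parametrized test.
quant_config = MXQuantConfig(
    w_dtype="fp8_e5m2",
    weight_only=False,
    round_method="floor",
    out_dtype="bfloat16",
)

# Two-step flow used by the updated tests: prepare() readies the model for
# MX quantization, convert() produces the quantized model.
model = prepare(model=model, quant_config=quant_config)
q_model = convert(model=model)

output = q_model(torch.zeros(3, 30))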
