Skip to content

Commit b8d98eb

Browse files
authored
Support double quant tuning (#1591)
Signed-off-by: yiliu30 <[email protected]>
1 parent e7b3478 commit b8d98eb

File tree

4 files changed

+99
-7
lines changed

4 files changed

+99
-7
lines changed

neural_compressor/torch/quantization/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@
2626
)
2727

2828
# TODO(Yi): move config to config.py
29-
from neural_compressor.torch.quantization.autotune import autotune, TuningConfig, get_all_config_set
29+
from neural_compressor.torch.quantization.autotune import (
30+
autotune,
31+
TuningConfig,
32+
get_all_config_set,
33+
get_rtn_double_quant_config_set,
34+
)
3035

3136
### Quantization Function Registration ###
3237
import neural_compressor.torch.quantization.algorithm_entry

neural_compressor/torch/quantization/autotune.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,26 @@
1717

1818
import torch
1919

20-
from neural_compressor.common import Logger
2120
from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry
2221
from neural_compressor.common.base_tuning import TuningConfig, evaluator, init_tuning
2322
from neural_compressor.torch.quantization import quantize
24-
from neural_compressor.torch.quantization.config import FRAMEWORK_NAME
25-
from neural_compressor.torch.utils import logger
23+
from neural_compressor.torch.quantization.config import FRAMEWORK_NAME, RTNConfig
24+
from neural_compressor.torch.utils import constants, logger
2625

2726
__all__ = [
2827
"autotune",
2928
"get_all_config_set",
29+
"get_rtn_double_quant_config_set",
3030
]
3131

3232

33+
def get_rtn_double_quant_config_set() -> List[RTNConfig]:
    """Build the RTN config set for double-quant tuning.

    Creates one `RTNConfig` per predefined double-quant scheme in
    `constants.DOUBLE_QUANT_CONFIGS`. The constant is read at call time,
    so patching it (as the tests do) is reflected in the result.

    Returns:
        A list of `RTNConfig` instances, one per double-quant scheme.
    """
    # Only the config dicts are needed; the scheme names (dict keys) are ignored.
    return [
        RTNConfig.from_dict(double_quant_config)
        for double_quant_config in constants.DOUBLE_QUANT_CONFIGS.values()
    ]
38+
39+
3340
def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
    """Return every config registered for this framework.

    Looks up the global config registry under `FRAMEWORK_NAME` and
    returns whatever is registered there.
    """
    all_registered = get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)
    return all_registered
3542

@@ -52,7 +59,7 @@ def autotune(
5259
for trial_index, quant_config in enumerate(config_loader):
5360
tuning_logger.trial_start(trial_index=trial_index)
5461
tuning_logger.quantization_start()
55-
logger.info(f"quant config: {quant_config}")
62+
logger.info(quant_config.to_dict())
5663
# !!! Make sure to use deepcopy only when inplace is set to `True`.
5764
q_model = quantize(deepcopy(model), quant_config=quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
5865
tuning_logger.quantization_end()
@@ -62,6 +69,7 @@ def autotune(
6269
tuning_monitor.add_trial_result(trial_index, eval_result, quant_config)
6370
tuning_logger.trial_end(trial_index)
6471
if tuning_monitor.need_stop():
72+
logger.info("Stopped tuning.")
6573
best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
6674
# !!! Make sure to use deepcopy only when inplace is set to `True`.
6775
quantize(deepcopy(model), quant_config=best_quant_config, run_fn=run_fn, run_args=run_args, inplace=True)

neural_compressor/torch/quantization/quantize.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ def quantize(
5555
assert isinstance(
5656
quant_config, BaseConfig
5757
), f"Please pass a dict or config instance as the quantization configuration, but got {type(quant_config)}."
58-
logger.info(f"Quantize model with config: \n {quant_config.to_json_string()} \n")
58+
logger.info("Quantize model with config:")
59+
logger.info(quant_config.to_dict())
5960
# select quantization algo according to config
6061

6162
model_info = quant_config.get_model_info(model=q_model)

test/3x/torch/test_autotune.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,37 @@
11
import unittest
22
from functools import wraps
3+
from unittest.mock import patch
34

45
import torch
56
import transformers
67

78
from neural_compressor.torch.algorithms.weight_only.gptq import DataloaderPreprocessor
89
from neural_compressor.torch.quantization import RTNConfig, TuningConfig, autotune, get_all_config_set
9-
from neural_compressor.torch.utils import logger
10+
from neural_compressor.torch.utils import constants, logger
11+
12+
# Two stand-in double-quant schemes used to patch
# `neural_compressor.torch.utils.constants.DOUBLE_QUANT_CONFIGS` in
# test_rtn_double_quant_config_set3, so that test is independent of the
# real (and possibly changing) set of predefined schemes.
FAKE_DOUBLE_QUANT_CONFIGS = {
    "BNB_NF4": {
        "dtype": "nf4",
        "bits": 4,
        "group_size": 32,
        "use_double_quant": True,
        "double_quant_bits": 8,
        "double_quant_dtype": "int",
        "double_quant_use_sym": False,
        "double_quant_group_size": 256,
    },
    "GGML_TYPE_Q4_K": {
        "dtype": "int",
        "bits": 4,
        "use_sym": False,
        "group_size": 32,
        "use_double_quant": True,
        "double_quant_bits": 6,
        "double_quant_dtype": "int",
        "double_quant_use_sym": True,
        "double_quant_group_size": 8,
    },
}
1035

1136

1237
def reset_tuning_target(test_func):
@@ -239,6 +264,59 @@ def eval_acc_fn(model):
239264
best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=eval_acc_fn)
240265
self.assertIsNone(best_model)
241266

267+
@reset_tuning_target
268+
def test_rtn_double_quant_config_set(self) -> None:
269+
from neural_compressor.torch.quantization import TuningConfig, autotune, get_rtn_double_quant_config_set
270+
from neural_compressor.torch.utils.constants import DOUBLE_QUANT_CONFIGS
271+
272+
rtn_double_quant_config_set = get_rtn_double_quant_config_set()
273+
self.assertEqual(len(rtn_double_quant_config_set), len(DOUBLE_QUANT_CONFIGS))
274+
275+
def eval_acc_fn(model) -> float:
276+
return 1.0
277+
278+
custom_tune_config = TuningConfig(config_set=get_rtn_double_quant_config_set(), max_trials=10)
279+
best_model = autotune(
280+
model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}]
281+
)
282+
self.assertIsNotNone(best_model)
283+
284+
@reset_tuning_target
285+
def test_rtn_double_quant_config_set2(self) -> None:
286+
from neural_compressor.torch.quantization import TuningConfig, autotune, get_rtn_double_quant_config_set
287+
from neural_compressor.torch.utils.constants import DOUBLE_QUANT_CONFIGS
288+
289+
rtn_double_quant_config_set = get_rtn_double_quant_config_set()
290+
self.assertEqual(len(rtn_double_quant_config_set), len(DOUBLE_QUANT_CONFIGS))
291+
292+
def eval_acc_fn(model) -> float:
293+
return 1.0
294+
295+
custom_tune_config = TuningConfig(
296+
config_set=get_rtn_double_quant_config_set(), max_trials=10, tolerable_loss=-1
297+
)
298+
best_model = autotune(
299+
model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}]
300+
)
301+
self.assertIsNone(best_model)
302+
303+
@patch("neural_compressor.torch.utils.constants.DOUBLE_QUANT_CONFIGS", FAKE_DOUBLE_QUANT_CONFIGS)
304+
def test_rtn_double_quant_config_set3(self) -> None:
305+
from neural_compressor.torch.quantization import get_rtn_double_quant_config_set
306+
307+
rtn_double_quant_config_set = get_rtn_double_quant_config_set()
308+
print(len(rtn_double_quant_config_set))
309+
self.assertEqual(len(constants.DOUBLE_QUANT_CONFIGS), len(FAKE_DOUBLE_QUANT_CONFIGS))
310+
311+
def eval_acc_fn(model) -> float:
312+
return 1.0
313+
314+
custom_tune_config = TuningConfig(config_set=get_rtn_double_quant_config_set(), tolerable_loss=-1)
315+
best_model = autotune(
316+
model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}]
317+
)
318+
self.assertIsNone(best_model)
319+
242320

243321
if __name__ == "__main__":
244322
unittest.main()

0 commit comments

Comments
 (0)