import unittest
from functools import wraps
+from unittest.mock import patch

import torch
import transformers

from neural_compressor.torch.algorithms.weight_only.gptq import DataloaderPreprocessor
from neural_compressor.torch.quantization import RTNConfig, TuningConfig, autotune, get_all_config_set
-from neural_compressor.torch.utils import logger
+from neural_compressor.torch.utils import constants, logger
+
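+# Minimal stand-in for constants.DOUBLE_QUANT_CONFIGS; patched in via unittest.mock in test_rtn_double_quant_config_set3 below.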
+FAKE_DOUBLE_QUANT_CONFIGS = {
+    "BNB_NF4": {
+        "dtype": "nf4",
+        "bits": 4,
+        "group_size": 32,
+        "use_double_quant": True,
+        "double_quant_bits": 8,
+        "double_quant_dtype": "int",
+        "double_quant_use_sym": False,
+        "double_quant_group_size": 256,
+    },
+    "GGML_TYPE_Q4_K": {
+        "dtype": "int",
+        "bits": 4,
+        "use_sym": False,
+        "group_size": 32,
+        "use_double_quant": True,
+        "double_quant_bits": 6,
+        "double_quant_dtype": "int",
+        "double_quant_use_sym": True,
+        "double_quant_group_size": 8,
+    },
+}


def reset_tuning_target(test_func):
@@ -239,6 +264,59 @@ def eval_acc_fn(model):
        best_model = autotune(model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=eval_acc_fn)
        self.assertIsNone(best_model)

+    @reset_tuning_target
+    def test_rtn_double_quant_config_set(self) -> None:
+        from neural_compressor.torch.quantization import TuningConfig, autotune, get_rtn_double_quant_config_set
+        from neural_compressor.torch.utils.constants import DOUBLE_QUANT_CONFIGS
+
+        rtn_double_quant_config_set = get_rtn_double_quant_config_set()
+        self.assertEqual(len(rtn_double_quant_config_set), len(DOUBLE_QUANT_CONFIGS))
+
+        def eval_acc_fn(model) -> float:
+            return 1.0
+
+        custom_tune_config = TuningConfig(config_set=get_rtn_double_quant_config_set(), max_trials=10)
+        best_model = autotune(
+            model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}]
+        )
+        self.assertIsNotNone(best_model)
+
+    @reset_tuning_target
+    def test_rtn_double_quant_config_set2(self) -> None:
+        from neural_compressor.torch.quantization import TuningConfig, autotune, get_rtn_double_quant_config_set
+        from neural_compressor.torch.utils.constants import DOUBLE_QUANT_CONFIGS
+
+        rtn_double_quant_config_set = get_rtn_double_quant_config_set()
+        self.assertEqual(len(rtn_double_quant_config_set), len(DOUBLE_QUANT_CONFIGS))
+
+        def eval_acc_fn(model) -> float:
+            return 1.0
+
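+        # tolerable_loss=-1 makes the accuracy goal unreachable, so autotune is expected to return None here.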
+        custom_tune_config = TuningConfig(
+            config_set=get_rtn_double_quant_config_set(), max_trials=10, tolerable_loss=-1
+        )
+        best_model = autotune(
+            model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}]
+        )
+        self.assertIsNone(best_model)
+
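+    # Swap the real double-quant table for the two fake entries defined at the top of this file.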
+    @patch("neural_compressor.torch.utils.constants.DOUBLE_QUANT_CONFIGS", FAKE_DOUBLE_QUANT_CONFIGS)
+    def test_rtn_double_quant_config_set3(self) -> None:
+        from neural_compressor.torch.quantization import get_rtn_double_quant_config_set
+
+        rtn_double_quant_config_set = get_rtn_double_quant_config_set()
+        print(len(rtn_double_quant_config_set))
+        self.assertEqual(len(constants.DOUBLE_QUANT_CONFIGS), len(FAKE_DOUBLE_QUANT_CONFIGS))
+
+        def eval_acc_fn(model) -> float:
+            return 1.0
+
+        custom_tune_config = TuningConfig(config_set=get_rtn_double_quant_config_set(), tolerable_loss=-1)
+        best_model = autotune(
+            model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}]
+        )
+        self.assertIsNone(best_model)
+

if __name__ == "__main__":
    unittest.main()