diff --git a/neural_compressor/torch/__init__.py b/neural_compressor/torch/__init__.py
index e60a2a2c2ec..c96fdddb993 100644
--- a/neural_compressor/torch/__init__.py
+++ b/neural_compressor/torch/__init__.py
@@ -28,4 +28,4 @@
 )
 
 from neural_compressor.common.base_tuning import TuningConfig
-from neural_compressor.torch.quantization.autotune import autotune, get_all_config_set
+from neural_compressor.torch.quantization.autotune import autotune, get_all_config_set, get_rtn_double_quant_config_set
diff --git a/neural_compressor/torch/quantization/__init__.py b/neural_compressor/torch/quantization/__init__.py
index b287c5ae2d6..5406705e0e7 100644
--- a/neural_compressor/torch/quantization/__init__.py
+++ b/neural_compressor/torch/quantization/__init__.py
@@ -23,3 +23,4 @@
     SmoothQuantConfig,
     get_default_sq_config,
 )
+from neural_compressor.torch.quantization.autotune import get_all_config_set, get_rtn_double_quant_config_set
diff --git a/neural_compressor/torch/quantization/autotune.py b/neural_compressor/torch/quantization/autotune.py
index bb48f0685c6..794f7bb5b3c 100644
--- a/neural_compressor/torch/quantization/autotune.py
+++ b/neural_compressor/torch/quantization/autotune.py
@@ -21,21 +21,27 @@
 from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry
 from neural_compressor.common.base_tuning import TuningConfig, evaluator, init_tuning
 from neural_compressor.torch import quantize
-from neural_compressor.torch.quantization.config import FRAMEWORK_NAME
+from neural_compressor.torch.quantization.config import FRAMEWORK_NAME, RTNConfig
+from neural_compressor.torch.utils.constants import DOUBLE_QUANT_CONFIGS
 
 logger = Logger().get_logger()
 
 
-__all__ = [
-    "autotune",
-    "get_all_config_set",
-]
+__all__ = ["autotune", "get_all_config_set", "get_rtn_double_quant_config_set"]
 
 
 def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
     return get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)
 
 
+def get_rtn_double_quant_config_set() -> List[RTNConfig]:
+    # Build one RTNConfig per predefined double-quant recipe; the recipe
+    # name (dict key) is not needed, so iterate values() only.
+    rtn_double_quant_config_set = []
+    for double_quant_config in DOUBLE_QUANT_CONFIGS.values():
+        rtn_double_quant_config_set.append(RTNConfig.from_dict(double_quant_config))
+    return rtn_double_quant_config_set
+
+
 def autotune(
     model: torch.nn.Module,
     tune_config: TuningConfig,
diff --git a/test/3x/torch/test_autotune.py b/test/3x/torch/test_autotune.py
index e1b717e3163..bb2115aae85 100644
--- a/test/3x/torch/test_autotune.py
+++ b/test/3x/torch/test_autotune.py
@@ -207,6 +207,23 @@ def test_autotune_not_eval_func(self):
             str(context.exception), "Please ensure that you register at least one evaluation metric for auto-tune."
         )
 
+    @reset_tuning_target
+    def test_rtn_double_quant_config_set(self) -> None:
+        from neural_compressor.torch import RTNConfig, TuningConfig, autotune, get_rtn_double_quant_config_set
+        from neural_compressor.torch.utils.constants import DOUBLE_QUANT_CONFIGS
+
+        rtn_double_quant_config_set = get_rtn_double_quant_config_set()
+        self.assertEqual(len(rtn_double_quant_config_set), len(DOUBLE_QUANT_CONFIGS))
+
+        def eval_acc_fn(model) -> float:
+            return 1.0
+
+        custom_tune_config = TuningConfig(config_set=get_rtn_double_quant_config_set(), max_trials=2)
+        best_model = autotune(
+            model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}]
+        )
+        self.assertIsNotNone(best_model)
+
 
 if __name__ == "__main__":
     unittest.main()