2 changes: 1 addition & 1 deletion neural_compressor/torch/__init__.py
@@ -28,4 +28,4 @@
 )

 from neural_compressor.common.base_tuning import TuningConfig
-from neural_compressor.torch.quantization.autotune import autotune, get_all_config_set
+from neural_compressor.torch.quantization.autotune import autotune, get_all_config_set, get_rtn_double_quant_config_set
1 change: 1 addition & 0 deletions neural_compressor/torch/quantization/__init__.py
@@ -23,3 +23,4 @@
     SmoothQuantConfig,
     get_default_sq_config,
 )
+from neural_compressor.torch.quantization.autotune import get_rtn_double_quant_config_set, get_all_config_set
15 changes: 10 additions & 5 deletions neural_compressor/torch/quantization/autotune.py
@@ -21,21 +21,26 @@
 from neural_compressor.common.base_config import BaseConfig, get_all_config_set_from_config_registry
 from neural_compressor.common.base_tuning import TuningConfig, evaluator, init_tuning
 from neural_compressor.torch import quantize
-from neural_compressor.torch.quantization.config import FRAMEWORK_NAME
+from neural_compressor.torch.quantization.config import FRAMEWORK_NAME, RTNConfig
+from neural_compressor.torch.utils.constants import DOUBLE_QUANT_CONFIGS

 logger = Logger().get_logger()


-__all__ = [
-    "autotune",
-    "get_all_config_set",
-]
+__all__ = ["autotune", "get_all_config_set", "get_rtn_double_quant_config_set"]


 def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
     return get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)


+def get_rtn_double_quant_config_set() -> List[RTNConfig]:
+    rtn_double_quant_config_set = []
+    for double_quant_type, double_quant_config in DOUBLE_QUANT_CONFIGS.items():
+        rtn_double_quant_config_set.append(RTNConfig.from_dict(double_quant_config))
+    return rtn_double_quant_config_set
+
+
 def autotune(
     model: torch.nn.Module,
     tune_config: TuningConfig,
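The helper added above converts each predefined entry in DOUBLE_QUANT_CONFIGS into an RTNConfig, so the whole double-quant search space can be handed to autotune as a ready-made config set. A minimal usage sketch, mirroring the PR's own test below (the toy model and the constant eval function are illustrative placeholders, not part of this PR):

import torch

from neural_compressor.torch import TuningConfig, autotune, get_rtn_double_quant_config_set

# Illustrative stand-in; substitute the model you actually want to quantize.
toy_model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU())


def eval_acc_fn(model) -> float:
    # Placeholder metric; return a real accuracy measurement in practice.
    return 1.0


# One candidate per predefined double-quant configuration, capped at two trials.
custom_tune_config = TuningConfig(config_set=get_rtn_double_quant_config_set(), max_trials=2)
best_model = autotune(model=toy_model, tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}])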
17 changes: 17 additions & 0 deletions test/3x/torch/test_autotune.py
@@ -207,6 +207,23 @@ def test_autotune_not_eval_func(self):
             str(context.exception), "Please ensure that you register at least one evaluation metric for auto-tune."
         )

+    @reset_tuning_target
+    def test_rtn_double_quant_config_set(self) -> None:
+        from neural_compressor.torch import RTNConfig, TuningConfig, autotune, get_rtn_double_quant_config_set
+        from neural_compressor.torch.utils.constants import DOUBLE_QUANT_CONFIGS
+
+        rtn_double_quant_config_set = get_rtn_double_quant_config_set()
+        self.assertEqual(len(rtn_double_quant_config_set), len(DOUBLE_QUANT_CONFIGS))
+
+        def eval_acc_fn(model) -> float:
+            return 1.0
+
+        custom_tune_config = TuningConfig(config_set=get_rtn_double_quant_config_set(), max_trials=2)
+        best_model = autotune(
+            model=build_simple_torch_model(), tune_config=custom_tune_config, eval_fns=[{"eval_fn": eval_acc_fn}]
+        )
+        self.assertIsNotNone(best_model)
+

 if __name__ == "__main__":
     unittest.main()
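Assuming a standard pytest setup for this repo, the new case can be run in isolation with: python -m pytest test/3x/torch/test_autotune.py -k test_rtn_double_quant_config_set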