diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py
index 71b01353d5a..5e794a81bd1 100644
--- a/neural_compressor/torch/quantization/config.py
+++ b/neural_compressor/torch/quantization/config.py
@@ -71,9 +71,26 @@ class OperatorConfig(NamedTuple):
     valid_func_list: List[Callable] = []
 
 
+class TorchBaseConfig(BaseConfig):
+    # Override _get_op_name_op_type_config so that string keys can also fall back by op_type,
+    # because the IPEX backend has some special fused op_types: `Linear&Relu`, `Linear&add`, ...
+    def _get_op_name_op_type_config(self):
+        op_type_config_dict = dict()
+        op_name_config_dict = dict()
+        for name, config in self.local_config.items():
+            if self._is_op_type(name):
+                # Convert the Callable to String.
+                new_name = self._op_type_to_str(name)
+                op_type_config_dict[new_name] = config
+            else:
+                op_name_config_dict[name] = config
+                op_type_config_dict[name] = config
+        return op_type_config_dict, op_name_config_dict
+
+
 ######################## RTN Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN, priority=PRIORITY_RTN)
-class RTNConfig(BaseConfig):
+class RTNConfig(TorchBaseConfig):
     """Config class for round-to-nearest weight-only quantization."""
 
     name = RTN
@@ -238,7 +255,7 @@ def get_default_double_quant_config(type="BNB_NF4"):
 
 ######################## GPTQ Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=GPTQ, priority=PRIORITY_GPTQ)
-class GPTQConfig(BaseConfig):
+class GPTQConfig(TorchBaseConfig):
     """Config class for GPTQ.
 
     GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers.
@@ -390,7 +407,7 @@ def get_default_gptq_config() -> GPTQConfig:
 
 ######################## AWQ Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=AWQ, priority=PRIORITY_AWQ)
-class AWQConfig(BaseConfig):
+class AWQConfig(TorchBaseConfig):
     """Config class for AWQ.
 
     AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration.
@@ -532,7 +549,7 @@ def get_default_awq_config() -> AWQConfig:
 
 ######################## TEQ Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=TEQ, priority=PRIORITY_TEQ)
-class TEQConfig(BaseConfig):
+class TEQConfig(TorchBaseConfig):
     """Config class for TEQ.
 
     TEQ: Activation-aware Weight Quantization for LLM Compression and Acceleration.
@@ -670,7 +687,7 @@ def get_default_teq_config() -> TEQConfig:
 
 ######################## AUTOROUND Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=AUTOROUND, priority=PRIORITY_AUTOROUND)
-class AutoRoundConfig(BaseConfig):
+class AutoRoundConfig(TorchBaseConfig):
     """Config class for AUTOROUND.
 
     AUTOROUND: Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs.
@@ -815,7 +832,7 @@ def get_default_AutoRound_config() -> AutoRoundConfig:
 
 ######################## MX Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=MX_QUANT)
-class MXQuantConfig(BaseConfig):
+class MXQuantConfig(TorchBaseConfig):
     """Config class for MX quantization."""
 
     supported_configs: List[OperatorConfig] = []
@@ -928,7 +945,7 @@ def get_default_mx_config() -> MXQuantConfig:
 
 ######################## Dynamic Quant Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=PT2E_DYNAMIC_QUANT)
-class DynamicQuantConfig(BaseConfig):
+class DynamicQuantConfig(TorchBaseConfig):
     """Config class for dynamic quantization."""
 
     name = PT2E_DYNAMIC_QUANT
@@ -1002,7 +1019,7 @@ def get_default_dynamic_config() -> DynamicQuantConfig:
 
 ######################## Static Quant Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=STATIC_QUANT)
-class StaticQuantConfig(BaseConfig):
+class StaticQuantConfig(TorchBaseConfig):
     """Config class for static quantization."""
 
     name = STATIC_QUANT
@@ -1091,7 +1108,7 @@ def get_default_static_config() -> StaticQuantConfig:
 
 ######################## Smooth Quant Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=SMOOTH_QUANT)
-class SmoothQuantConfig(BaseConfig):
+class SmoothQuantConfig(TorchBaseConfig):
     """Config class for smooth quantization."""
 
     name = SMOOTH_QUANT
@@ -1205,7 +1222,7 @@ def get_default_sq_config() -> SmoothQuantConfig:
 
 ######################## HQQ Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=HQQ, priority=PRIORITY_HQQ)
-class HQQConfig(BaseConfig):
+class HQQConfig(TorchBaseConfig):
     # Half-Quadratic Quantization (HQQ), more details:
     # Blog: https://mobiusml.github.io/hqq_blog/
     # Code: https://github.com/mobiusml/hqq
@@ -1286,7 +1303,7 @@ def get_default_hqq_config() -> HQQConfig:
 
 ######################## FP8 Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=FP8_QUANT)
-class FP8Config(BaseConfig):
+class FP8Config(TorchBaseConfig):
     """Config class for FP8 quantization."""
 
     name = FP8_QUANT
@@ -1381,7 +1398,7 @@ def get_default_fp8_config_set() -> FP8Config:
 
 ######################## MixPrecision Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=MIX_PRECISION)
-class MixPrecisionConfig(BaseConfig):
+class MixPrecisionConfig(TorchBaseConfig):
     """Config class for mix-precision."""
 
     name = MIX_PRECISION
diff --git a/test/3x/torch/quantization/test_static_quant.py b/test/3x/torch/quantization/test_static_quant.py
index 46e791aa52f..cae22ff79ab 100644
--- a/test/3x/torch/quantization/test_static_quant.py
+++ b/test/3x/torch/quantization/test_static_quant.py
@@ -76,7 +76,7 @@ def test_static_quant_fallback(self):
         quant_config = get_default_static_config()
         example_inputs = self.input
         # fallback by op_type
-        quant_config.set_local(torch.nn.Linear, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        quant_config.set_local([torch.nn.Linear, "Linear&add"], StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
         run_fn(prepared_model)
         q_model = convert(prepared_model)
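
The key behavioral change is in how `_get_op_name_op_type_config` routes the keys of `local_config`: callable op types (e.g. `torch.nn.Linear`) are normalized to strings as before, while string keys are now mirrored into both the op-name dict and the op-type dict, so an IPEX fused op_type such as "Linear&add" can still match by op type. The following is a minimal standalone sketch of that routing; it is not the library code, and `is_op_type` / `op_type_to_str` here are simplified stand-ins for the `BaseConfig._is_op_type` / `BaseConfig._op_type_to_str` helpers used in the patch:

from typing import Callable, Dict, Tuple, Union

import torch


def is_op_type(name: Union[str, Callable]) -> bool:
    # Stand-in for BaseConfig._is_op_type: callables such as torch.nn.Linear count as op types.
    return not isinstance(name, str)


def op_type_to_str(op_type: Callable) -> str:
    # Stand-in for BaseConfig._op_type_to_str: key by the class name.
    return op_type.__name__


def split_local_config(local_config: Dict) -> Tuple[Dict, Dict]:
    # Mirrors the routing in TorchBaseConfig._get_op_name_op_type_config.
    op_type_config_dict, op_name_config_dict = {}, {}
    for name, config in local_config.items():
        if is_op_type(name):
            # Callable op type -> normalized to its string name, op-type dict only.
            op_type_config_dict[op_type_to_str(name)] = config
        else:
            # String key: kept as an op name and also mirrored as an op type,
            # so IPEX fused patterns like "Linear&add" still match by op type.
            op_name_config_dict[name] = config
            op_type_config_dict[name] = config
    return op_type_config_dict, op_name_config_dict


fallback = {"w_dtype": "fp32", "act_dtype": "fp32"}
op_type_cfg, op_name_cfg = split_local_config({torch.nn.Linear: fallback, "Linear&add": fallback})
print(sorted(op_type_cfg))  # ['Linear', 'Linear&add']
print(sorted(op_name_cfg))  # ['Linear&add']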
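
On the user-facing side, the updated test exercises exactly this: `set_local` is passed a list mixing a module class with an IPEX fused op-type string, so both plain Linear modules and fused Linear&add patterns fall back to fp32. A minimal usage sketch along the lines of the test, assuming an environment where the IPEX static-quant path is available; the toy model and calibration function below are stand-ins for the test's `fp32_model` / `self.input` / `run_fn` fixtures:

import torch

from neural_compressor.torch.quantization import (
    StaticQuantConfig,
    convert,
    get_default_static_config,
    prepare,
)

# Toy stand-ins for the test fixtures.
fp32_model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
example_inputs = torch.randn(1, 8)


def run_fn(model):
    model(example_inputs)  # single calibration pass


quant_config = get_default_static_config()
# Fall back both the Linear op type and the IPEX fused "Linear&add" pattern to fp32.
quant_config.set_local([torch.nn.Linear, "Linear&add"], StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))

prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(prepared_model)
q_model = convert(prepared_model)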