Commit 55e1387

[SW-186675] Update default configuration of 'allowlist'

Authored by Tiefen-boop and committed by Eran Geva.

Defined the default allowlist types to be empty, which allows quantization of all models. Refactored the parse function into more dynamic, consistent code.

Change-Id: I6c8a14cb7ca6830927e5c5b7476e4b03335456aa
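For context, a minimal sketch (not part of this commit) of one plausible reading of the allowlist change: an empty "types" tuple is treated as "no type restriction", so every module type becomes eligible for quantization. The helper is_type_allowed and the module names below are hypothetical; the commit's actual filtering logic lives elsewhere in fp8_quant.

# Hypothetical illustration, not repository code.
def is_type_allowed(module_type: str, allowlist_types: tuple) -> bool:
    # An empty tuple disables type filtering.
    return not allowlist_types or module_type in allowlist_types

old_default = ("torch.nn.Linear", "torch.nn.Conv2d", "BMM")  # previous default
new_default = ()                                             # new default

print(is_type_allowed("torch.nn.LayerNorm", old_default))  # False
print(is_type_allowed("torch.nn.LayerNorm", new_default))  # True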
1 parent: 3f1d5c0

1 file changed: neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py (21 additions, 66 deletions)
@@ -44,6 +44,14 @@ class MeasureExclude(Flag):
     PARAMS = auto()
     ALL = auto()
 
+class SupportedFp8(Enum):
+    E4M3 = torch.float8_e4m3fn
+    E5M2 = torch.float8_e5m2
+
+class HpDtype(Enum):
+    BF16 = torch.bfloat16
+    FP16 = torch.float16
+    FP32 = torch.float32
 
 
 class ScaleMethod(Enum):
     MAX = 1
@@ -69,6 +77,13 @@ def set_hqt_config(mod, config):
     mod.__hqt_config__ = config
 
 
+def _get_enum_from_string(EnumClass, str, key):
+    if not hasattr(EnumClass, str.upper()):
+        raise ValueError(
+            f"Invalid '{key}' value in custom config ('{str}'). Enter one of {[m.name for m in EnumClass]}")
+    return EnumClass[str.upper()]
+
+
 @dataclass
 class Fp8cfg:
     cfg: Mapping[str, Any]
@@ -84,7 +99,7 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg:
         }, # types and names to not be quantized
         "allowlist": {
             "names": [],
-            "types": ("torch.nn.Linear", "torch.nn.Conv2d", "BMM"),
+            "types": (),
         }, # types and names to be quantized. Allowlist by names is not yet implemented
         "mode": QuantMode.QUANTIZE, # Quantize or Measure
         "scale_method": ScaleMethod.UNIT_SCALE, # Method to quantize with
@@ -104,79 +119,19 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg:
     # go over all user-defined keys from json, handle various cases
     for keys in custom_config:
         if keys == "mode":
-            if custom_config[keys] == "NONE":
-                custom_config[keys] = QuantMode.NONE
-            elif custom_config[keys] == "QUANTIZE":
-                custom_config[keys] = QuantMode.QUANTIZE
-            elif custom_config[keys] == "MEASURE":
-                custom_config[keys] = QuantMode.MEASURE
-            elif custom_config[keys] == "SHAPE":
-                custom_config[keys] = QuantMode.SHAPE
-            else:
-                raise ValueError("invalid mode in custom config. Enter Quantize or Measure")
+            custom_config[keys] = _get_enum_from_string(QuantMode, custom_config[keys], keys)
 
         if keys == "measure_exclude":
-            if custom_config[keys] == "NONE":
-                custom_config[keys] = MeasureExclude.NONE
-            elif custom_config[keys] == "OUTPUT":
-                custom_config[keys] = MeasureExclude.OUTPUT
-            elif custom_config[keys] == "INPUT":
-                custom_config[keys] = MeasureExclude.INPUT
-            elif custom_config[keys] == "ALL":
-                custom_config[keys] = MeasureExclude.ALL
-            else:
-                raise ValueError("invalid measure exclude value in custom config. Enter OUTPUT or NONE")
+            custom_config[keys] = _get_enum_from_string(MeasureExclude, custom_config[keys], keys)
 
         if keys == "fp8_config":
-            if custom_config[keys].lower() == "e4m3":
-                custom_config[keys] = torch.float8_e4m3fn
-
-            elif custom_config[keys].lower() == "e5m2":
-                custom_config[keys] = torch.float8_e5m2
-            else:
-                raise ValueError("invalid fp8_config in custom config. Enter E4M3 or E5M2")
+            custom_config[keys] = _get_enum_from_string(SupportedFp8, custom_config[keys], keys).value
 
         if keys == "hp_dtype":
-            if custom_config[keys].lower() == "bf16":
-                custom_config[keys] = torch.bfloat16
-            elif custom_config[keys].lower() == "fp16":
-                custom_config[keys] = torch.float16
-            elif custom_config[keys].lower() == "fp32":
-                custom_config[keys] = torch.float32
-            else:
-                raise ValueError("invalid hp_dtype in custom config. Enter bf16, fp16 or fp32")
+            custom_config[keys] = _get_enum_from_string(HpDtype, custom_config[keys], keys).value
 
         if keys == "scale_method":
-            if custom_config[keys].lower() == "unit_scale":
-                custom_config[keys] = ScaleMethod.UNIT_SCALE
-            elif custom_config[keys].lower() == "max":
-                custom_config[keys] = ScaleMethod.MAX
-            elif custom_config[keys].lower() == "maxabs_hw":
-                custom_config[keys] = ScaleMethod.MAXABS_HW
-            elif custom_config[keys].lower() == "maxabs_pow2":
-                custom_config[keys] = ScaleMethod.MAXABS_POW2
-            elif custom_config[keys].lower() == "maxabs_hw_opt_weight":
-                custom_config[keys] = ScaleMethod.MAXABS_HW_OPT_WEIGHT
-            elif custom_config[keys].lower() == "maxabs_pow2_opt_weight":
-                custom_config[keys] = ScaleMethod.MAXABS_POW2_OPT_WEIGHT
-            elif custom_config[keys].lower() == "smoothquant_weights_output_channel_maxabs_pow2":
-                custom_config[keys] = ScaleMethod.SMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2
-            elif custom_config[keys].lower() == "weaksmoothquant_weights_output_channel_maxabs_pow2":
-                custom_config[keys] = ScaleMethod.WEAKSMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2
-            elif custom_config[keys].lower() == "act_maxabs_hw_weights_pcs_maxabs_pow2":
-                custom_config[keys] = ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2
-            elif custom_config[keys].lower() == "act_maxabs_hw_weights_pcs_opt_pow2":
-                custom_config[keys] = ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_OPT_POW2
-            elif custom_config[keys].lower() == "act_maxabs_pow2_weights_pcs_maxabs_pow2":
-                custom_config[keys] = ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_MAXABS_POW2
-            elif custom_config[keys].lower() == "act_maxabs_pow2_weights_pcs_opt_pow2":
-                custom_config[keys] = ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2
-            elif custom_config[keys].lower() == "smoothquant_opt":
-                custom_config[keys] = ScaleMethod.SMOOTHQUANT_OPT
-            else:
-                raise ValueError(
-                    f'Invalid fp8_config in custom config ({custom_config[keys]}). should be in ["max", "unit_scale", "maxabs_hw", "maxabs_pow2", "maxabs_per_channel_pow2", "smoothquant_opt"]'
-                )
+            custom_config[keys] = _get_enum_from_string(ScaleMethod, custom_config[keys], keys)
 
         if keys == "ignore_modules_wo_measures":
             custom_config[keys] = custom_config[keys].lower() == "true"
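Below is a minimal, self-contained sketch (not part of the commit) of the string-to-enum lookup pattern that the refactored parse function now relies on. QuantMode here is a stand-in with the same member names, and the "measure" value is a hypothetical user input; the real enums and helper live in quant_config.py.

from enum import Enum, auto

class QuantMode(Enum):  # stand-in for the enum defined in quant_config.py
    NONE = auto()
    QUANTIZE = auto()
    MEASURE = auto()
    SHAPE = auto()

def _get_enum_from_string(EnumClass, str, key):
    # Mirrors the helper added in this commit: case-insensitive lookup by
    # member name, raising a ValueError that lists the accepted names.
    if not hasattr(EnumClass, str.upper()):
        raise ValueError(
            f"Invalid '{key}' value in custom config ('{str}'). Enter one of {[m.name for m in EnumClass]}")
    return EnumClass[str.upper()]

custom_config = {"mode": "measure"}  # hypothetical user-supplied config
custom_config["mode"] = _get_enum_from_string(QuantMode, custom_config["mode"], "mode")
print(custom_config["mode"])         # QuantMode.MEASURE

The same one-liner pattern replaces each of the if/elif chains for mode, measure_exclude, fp8_config, hp_dtype, and scale_method in the diff above.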
