@@ -44,6 +44,14 @@ class MeasureExclude(Flag):
     PARAMS = auto()
     ALL = auto()

+class SupportedFp8(Enum):
+    E4M3 = torch.float8_e4m3fn
+    E5M2 = torch.float8_e5m2
+
+class HpDtype(Enum):
+    BF16 = torch.bfloat16
+    FP16 = torch.float16
+    FP32 = torch.float32

 class ScaleMethod(Enum):
     MAX = 1
@@ -69,6 +77,13 @@ def set_hqt_config(mod, config):
     mod.__hqt_config__ = config


+def _get_enum_from_string(EnumClass, str, key):
+    if not hasattr(EnumClass, str.upper()):
+        raise ValueError(
+            f"Invalid '{key}' value in custom config ('{str}'). Enter one of {[m.name for m in EnumClass]}")
+    return EnumClass[str.upper()]
+
+
 @dataclass
 class Fp8cfg:
     cfg: Mapping[str, Any]
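For reference, a minimal standalone sketch of how the new helper behaves. The enum and helper bodies are copied verbatim from the hunks above (including the parameter named `str`, which shadows the builtin); the imports, sample strings, and asserts are illustrative only and assume a PyTorch build that exposes the float8 dtypes.

import torch
from enum import Enum

# Copied from this diff: maps config strings to FP8 torch dtypes.
class SupportedFp8(Enum):
    E4M3 = torch.float8_e4m3fn
    E5M2 = torch.float8_e5m2

# Copied from this diff: generic string -> enum-member lookup.
def _get_enum_from_string(EnumClass, str, key):
    if not hasattr(EnumClass, str.upper()):
        raise ValueError(
            f"Invalid '{key}' value in custom config ('{str}'). Enter one of {[m.name for m in EnumClass]}")
    return EnumClass[str.upper()]

# Lookup is case-insensitive thanks to .upper().
assert _get_enum_from_string(SupportedFp8, "e4m3", "fp8_config") is SupportedFp8.E4M3

# An unknown value raises and lists the valid member names.
try:
    _get_enum_from_string(SupportedFp8, "e3m4", "fp8_config")
except ValueError as e:
    print(e)  # Invalid 'fp8_config' value in custom config ('e3m4'). Enter one of ['E4M3', 'E5M2']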
@@ -84,7 +99,7 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg:
         }, # types and names to not be quantized
         "allowlist": {
             "names": [],
-            "types": ("torch.nn.Linear", "torch.nn.Conv2d", "BMM"),
+            "types": (),
         }, # types and names to be quantized. Allowlist by names is not yet implemented
         "mode": QuantMode.QUANTIZE, # Quantize or Measure
         "scale_method": ScaleMethod.UNIT_SCALE, # Method to quantize with
@@ -104,79 +119,19 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg:
     # go over all user-defined keys from json, handle various cases
     for keys in custom_config:
         if keys == "mode":
-            if custom_config[keys] == "NONE":
-                custom_config[keys] = QuantMode.NONE
-            elif custom_config[keys] == "QUANTIZE":
-                custom_config[keys] = QuantMode.QUANTIZE
-            elif custom_config[keys] == "MEASURE":
-                custom_config[keys] = QuantMode.MEASURE
-            elif custom_config[keys] == "SHAPE":
-                custom_config[keys] = QuantMode.SHAPE
-            else:
-                raise ValueError("invalid mode in custom config. Enter Quantize or Measure")
+            custom_config[keys] = _get_enum_from_string(QuantMode, custom_config[keys], keys)

         if keys == "measure_exclude":
-            if custom_config[keys] == "NONE":
-                custom_config[keys] = MeasureExclude.NONE
-            elif custom_config[keys] == "OUTPUT":
-                custom_config[keys] = MeasureExclude.OUTPUT
-            elif custom_config[keys] == "INPUT":
-                custom_config[keys] = MeasureExclude.INPUT
-            elif custom_config[keys] == "ALL":
-                custom_config[keys] = MeasureExclude.ALL
-            else:
-                raise ValueError("invalid measure exclude value in custom config. Enter OUTPUT or NONE")
+            custom_config[keys] = _get_enum_from_string(MeasureExclude, custom_config[keys], keys)

         if keys == "fp8_config":
-            if custom_config[keys].lower() == "e4m3":
-                custom_config[keys] = torch.float8_e4m3fn
-
-            elif custom_config[keys].lower() == "e5m2":
-                custom_config[keys] = torch.float8_e5m2
-            else:
-                raise ValueError("invalid fp8_config in custom config. Enter E4M3 or E5M2")
+            custom_config[keys] = _get_enum_from_string(SupportedFp8, custom_config[keys], keys).value

         if keys == "hp_dtype":
-            if custom_config[keys].lower() == "bf16":
-                custom_config[keys] = torch.bfloat16
-            elif custom_config[keys].lower() == "fp16":
-                custom_config[keys] = torch.float16
-            elif custom_config[keys].lower() == "fp32":
-                custom_config[keys] = torch.float32
-            else:
-                raise ValueError("invalid hp_dtype in custom config. Enter bf16, fp16 or fp32")
+            custom_config[keys] = _get_enum_from_string(HpDtype, custom_config[keys], keys).value

         if keys == "scale_method":
-            if custom_config[keys].lower() == "unit_scale":
-                custom_config[keys] = ScaleMethod.UNIT_SCALE
-            elif custom_config[keys].lower() == "max":
-                custom_config[keys] = ScaleMethod.MAX
-            elif custom_config[keys].lower() == "maxabs_hw":
-                custom_config[keys] = ScaleMethod.MAXABS_HW
-            elif custom_config[keys].lower() == "maxabs_pow2":
-                custom_config[keys] = ScaleMethod.MAXABS_POW2
-            elif custom_config[keys].lower() == "maxabs_hw_opt_weight":
-                custom_config[keys] = ScaleMethod.MAXABS_HW_OPT_WEIGHT
-            elif custom_config[keys].lower() == "maxabs_pow2_opt_weight":
-                custom_config[keys] = ScaleMethod.MAXABS_POW2_OPT_WEIGHT
-            elif custom_config[keys].lower() == "smoothquant_weights_output_channel_maxabs_pow2":
-                custom_config[keys] = ScaleMethod.SMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2
-            elif custom_config[keys].lower() == "weaksmoothquant_weights_output_channel_maxabs_pow2":
-                custom_config[keys] = ScaleMethod.WEAKSMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2
-            elif custom_config[keys].lower() == "act_maxabs_hw_weights_pcs_maxabs_pow2":
-                custom_config[keys] = ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2
-            elif custom_config[keys].lower() == "act_maxabs_hw_weights_pcs_opt_pow2":
-                custom_config[keys] = ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_OPT_POW2
-            elif custom_config[keys].lower() == "act_maxabs_pow2_weights_pcs_maxabs_pow2":
-                custom_config[keys] = ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_MAXABS_POW2
-            elif custom_config[keys].lower() == "act_maxabs_pow2_weights_pcs_opt_pow2":
-                custom_config[keys] = ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2
-            elif custom_config[keys].lower() == "smoothquant_opt":
-                custom_config[keys] = ScaleMethod.SMOOTHQUANT_OPT
-            else:
-                raise ValueError(
-                    f'Invalid fp8_config in custom config ({custom_config[keys]}). should be in ["max", "unit_scale", "maxabs_hw", "maxabs_pow2", "maxabs_per_channel_pow2", "smoothquant_opt"]'
-                )
+            custom_config[keys] = _get_enum_from_string(ScaleMethod, custom_config[keys], keys)

         if keys == "ignore_modules_wo_measures":
             custom_config[keys] = custom_config[keys].lower() == "true"