    StaticQuantConfig,
    TEQConfig,
)
-from neural_compressor.torch.utils import Mode, is_ipex_imported, logger, register_algo
+from neural_compressor.torch.utils import (
+    Mode,
+    get_quantizer,
+    is_ipex_imported,
+    logger,
+    postprocess_model,
+    register_algo,
+)

from neural_compressor.torch.utils.constants import PT2E_STATIC_QUANT
@@ -69,17 +76,9 @@ def rtn_entry(
        "double_quant_group_size": quant_config.double_quant_group_size,
    }

-    if getattr(model, "quantizer", False):
-        quantizer = model.quantizer
-    else:
-        quantizer = RTNQuantizer(quant_config=weight_config)
-
+    quantizer = get_quantizer(model, quantizer_cls=RTNQuantizer, quant_config=weight_config)
    model = quantizer.execute(model, mode=mode)
-
-    if getattr(model, "quantizer", False):
-        del model.quantizer
-    else:
-        model.quantizer = quantizer
+    postprocess_model(model, mode, quantizer)
    return model
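The call sites above pin down the helper signatures, and the deleted branches show what they replace. Below is a minimal sketch of how the two shared helpers plausibly behave, assuming they simply factor out the removed inline logic; the actual implementations live in `neural_compressor.torch.utils` and may additionally branch on `mode`:

```python
# Sketch only: inferred from the call sites and the removed code in this diff,
# not copied from neural_compressor.torch.utils.
def get_quantizer(model, quantizer_cls, quant_config=None, **kwargs):
    """Reuse a quantizer cached on the model (e.g. from an earlier prepare
    step); otherwise construct a fresh one from the given class and config."""
    if getattr(model, "quantizer", False):
        return model.quantizer
    return quantizer_cls(quant_config=quant_config, **kwargs)


def postprocess_model(model, mode, quantizer):
    """Attach or detach the quantizer after execution, mirroring the deleted
    if/else blocks; the real helper presumably consults `mode` (PREPARE vs.
    CONVERT/QUANTIZE) to decide which branch applies."""
    if getattr(model, "quantizer", False):
        del model.quantizer
    else:
        model.quantizer = quantizer
```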
@@ -126,15 +125,11 @@ def gptq_entry(
    )
    kwargs.pop("example_inputs")
    logger.warning("lm_head in transformer model is skipped by GPTQ")
-    if getattr(model, "quantizer", False):
-        quantizer = model.quantizer
-    else:
-        quantizer = GPTQuantizer(quant_config=weight_config)
+
+    quantizer = get_quantizer(model, quantizer_cls=GPTQuantizer, quant_config=weight_config)
    model = quantizer.execute(model, mode=mode, *args, **kwargs)
-    if getattr(model, "quantizer", False):
-        del model.quantizer
-    else:
-        model.quantizer = quantizer
+    postprocess_model(model, mode, quantizer)
+
    return model
@@ -180,17 +175,10 @@ def static_quant_entry(
    inplace = kwargs.get("inplace", True)
    assert example_inputs is not None, "Please provide example_inputs for static quantization."

-    if getattr(model, "quantizer", False):
-        quantizer = model.quantizer
-    else:
-        quantizer = StaticQuantQuantizer(quant_config=quant_config_mapping)
-
+    quantizer = get_quantizer(model, quantizer_cls=StaticQuantQuantizer, quant_config=quant_config_mapping)
    model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace)
+    postprocess_model(model, mode, quantizer)

-    if getattr(model, "quantizer", False):
-        del model.quantizer
-    else:
-        model.quantizer = quantizer
    return model
@@ -323,11 +311,7 @@ def awq_quantize_entry(
    example_inputs = kwargs.get("example_inputs", None)
    assert example_inputs is not None, "Please provide example_inputs for AWQ quantization."

-    if getattr(model, "quantizer", False):
-        quantizer = model.quantizer
-    else:
-        quantizer = AWQQuantizer(quant_config=weight_config)
-
+    quantizer = get_quantizer(model, quantizer_cls=AWQQuantizer, quant_config=weight_config)
    model = quantizer.execute(
        model,
        mode=mode,
@@ -340,11 +324,8 @@ def awq_quantize_entry(
        return_int=return_int,
        use_full_range=use_full_range,
    )
+    postprocess_model(model, mode, quantizer)

-    if getattr(model, "quantizer", False):
-        del model.quantizer
-    else:
-        model.quantizer = quantizer
    return model
@@ -386,10 +367,18 @@ def teq_quantize_entry(
    absorb_to_layer = quant_config.absorb_to_layer
    folding = quant_config.folding
    assert isinstance(model, torch.nn.Module), "only support torch module"
-    quantizer = TEQuantizer(
-        quant_config=weight_config, folding=folding, absorb_to_layer=absorb_to_layer, example_inputs=example_inputs
+
+    quantizer = get_quantizer(
+        model,
+        quantizer_cls=TEQuantizer,
+        quant_config=weight_config,
+        folding=folding,
+        absorb_to_layer=absorb_to_layer,
+        example_inputs=example_inputs,
    )
    model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace)
+    postprocess_model(model, mode, quantizer)
+
    return model
@@ -436,35 +425,33 @@ def autoround_quantize_entry(
    scale_dtype = quant_config.scale_dtype

    kwargs.pop("example_inputs")
-    if getattr(model, "quantizer", False):
-        quantizer = model.quantizer
-    else:
-        quantizer = AutoRoundQuantizer(
-            weight_config=weight_config,
-            enable_full_range=enable_full_range,
-            batch_size=batch_size,
-            lr_scheduler=lr_scheduler,
-            use_quant_input=use_quant_input,
-            enable_minmax_tuning=enable_minmax_tuning,
-            lr=lr,
-            minmax_lr=minmax_lr,
-            low_gpu_mem_usage=low_gpu_mem_usage,
-            iters=iters,
-            seqlen=seqlen,
-            n_samples=n_samples,
-            sampler=sampler,
-            seed=seed,
-            n_blocks=n_blocks,
-            gradient_accumulate_steps=gradient_accumulate_steps,
-            not_use_best_mse=not_use_best_mse,
-            dynamic_max_gap=dynamic_max_gap,
-            scale_dtype=scale_dtype,
-        )
+
+    quantizer = get_quantizer(
+        model,
+        quantizer_cls=AutoRoundQuantizer,
+        quant_config=weight_config,
+        enable_full_range=enable_full_range,
+        batch_size=batch_size,
+        lr_scheduler=lr_scheduler,
+        use_quant_input=use_quant_input,
+        enable_minmax_tuning=enable_minmax_tuning,
+        lr=lr,
+        minmax_lr=minmax_lr,
+        low_gpu_mem_usage=low_gpu_mem_usage,
+        iters=iters,
+        seqlen=seqlen,
+        n_samples=n_samples,
+        sampler=sampler,
+        seed=seed,
+        n_blocks=n_blocks,
+        gradient_accumulate_steps=gradient_accumulate_steps,
+        not_use_best_mse=not_use_best_mse,
+        dynamic_max_gap=dynamic_max_gap,
+        scale_dtype=scale_dtype,
+    )
    model = quantizer.execute(model=model, mode=mode, *args, **kwargs)
-    if getattr(model, "quantizer", False):
-        del model.quantizer
-    else:
-        model.quantizer = quantizer
+    postprocess_model(model, mode, quantizer)
+
    logger.info("AutoRound quantization done.")
    return model
@@ -482,17 +469,11 @@ def hqq_entry(
    from neural_compressor.torch.algorithms.weight_only.hqq import HQQuantizer

    logger.info("Quantize model with the HQQ algorithm.")
-    if getattr(model, "quantizer", False):
-        quantizer = model.quantizer
-    else:
-        quantizer = HQQuantizer(quant_config=configs_mapping)

+    quantizer = get_quantizer(model, quantizer_cls=HQQuantizer, quant_config=configs_mapping)
    model = quantizer.execute(model, mode=mode)
+    postprocess_model(model, mode, quantizer)

-    if getattr(model, "quantizer", False):
-        del model.quantizer
-    else:
-        model.quantizer = quantizer
    return model
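As a usage-level note, the caching these helpers encapsulate exists for the two-step prepare/convert flow, where the quantizer created during `prepare` has to be picked up again by `convert`. The snippet below is a hypothetical illustration only; `prepare`, `convert`, and `RTNConfig` come from `neural_compressor.torch.quantization`, but whether RTN goes through the two-step path and the exact call signatures are assumptions here:

```python
import torch
from neural_compressor.torch.quantization import RTNConfig, convert, prepare

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

# prepare() ends with postprocess_model() caching the quantizer on the model;
# convert() then recovers that same instance through get_quantizer().
prepared = prepare(model, RTNConfig())
quantized = convert(prepared)
```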