Commit ec49a29

Enhance 3.x torch algorithm entry (#1779)
Enhance 3.x torch algorithm entry

---------

Signed-off-by: yuwenzho <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 43c3580 commit ec49a29

7 files changed: +111 -90 lines changed

neural_compressor/torch/algorithms/base_algorithm.py

Lines changed: 1 addition & 2 deletions
@@ -99,14 +99,13 @@ def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
 
         return model
 
-    def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any):  # pragma: no cover
+    def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any):
         """Execute according to mode.
 
         Args:
             model (torch.nn.Module): The model to be executed.
             mode (Mode): The mode of current phase, including 'prepare', 'convert' and 'quantize'.
         """
-        # TODO: remove '# pragma: no cover' once CI test can cover this function
         if mode == Mode.PREPARE:
             model = self.prepare(model, *args, **kwargs)
         elif mode == Mode.CONVERT:

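For context, execute() is the single dispatch point every algorithm entry calls. Below is a minimal, hedged sketch of that dispatch, assuming the base class is importable from the base_algorithm module shown above; the toy subclass and its no-op bodies are illustrative only and not part of this commit.

# Illustrative sketch: a toy Quantizer subclass to show how execute() routes work
# to prepare()/convert() based on Mode. Real quantizers do the heavy lifting here.
import torch

from neural_compressor.torch.algorithms.base_algorithm import Quantizer
from neural_compressor.torch.utils import Mode


class NoOpQuantizer(Quantizer):
    def prepare(self, model, *args, **kwargs):
        model.prepared = True  # real quantizers insert observers/hooks here
        return model

    def convert(self, model, *args, **kwargs):
        model.converted = True  # real quantizers swap in quantized modules here
        return model


model = torch.nn.Linear(4, 4)
quantizer = NoOpQuantizer({})  # config dict is arbitrary for this sketch
model = quantizer.execute(model, mode=Mode.PREPARE)  # dispatches to prepare()
model = quantizer.execute(model, mode=Mode.CONVERT)  # dispatches to convert()
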
neural_compressor/torch/algorithms/weight_only/autoround.py

Lines changed: 5 additions & 6 deletions
@@ -26,7 +26,7 @@
 class AutoRoundQuantizer(Quantizer):
     def __init__(
         self,
-        weight_config: dict = {},
+        quant_config: dict = None,
         enable_full_range: bool = False,
         batch_size: int = 8,
         amp: bool = True,
@@ -51,8 +51,8 @@ def __init__(
         """Init a AutQRoundQuantizer object.
 
         Args:
-            weight_config (dict): Configuration for weight quantization (default is an empty dictionary).
-            weight_config={
+            quant_config (dict): Configuration for weight quantization (default is None).
+            quant_config={
                 'layer1':##layer_name
                     {
                         'data_type': 'int',
@@ -89,9 +89,8 @@ def __init__(
             scale_dtype (str): The data type of quantization scale to be used (default is "float32"), different kernels
                 have different choices.
         """
-        super().__init__(weight_config)
+        super().__init__(quant_config)
         self.tokenizer = None
-        self.weight_config = weight_config
         self.enable_full_range = enable_full_range
         self.batch_size = batch_size
         self.amp = amp
@@ -125,7 +124,7 @@ def prepare(self, model: torch.nn.Module, *args, **kwargs):
         self.rounder = AutoRoundProcessor(
             model=model,
             tokenizer=None,
-            weight_config=self.weight_config,
+            weight_config=self.quant_config or {},
             enable_full_range=self.enable_full_range,
             batch_size=self.batch_size,
             amp=self.amp,

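The rename from weight_config to quant_config also shows up in the updated unit test below. A hedged sketch of constructing the quantizer with the new argument name; the layer name and the key set are illustrative (only 'data_type' appears in the docstring hunk and 'sym' in the test hunk).

# Sketch only: a per-layer dict passed through the renamed quant_config argument.
from neural_compressor.torch.algorithms.weight_only.autoround import AutoRoundQuantizer

quant_config = {
    "layer1": {  # layer name, as in the docstring example above
        "data_type": "int",
        "sym": False,
    }
}
quantizer = AutoRoundQuantizer(quant_config=quant_config)
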
neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 56 additions & 75 deletions
@@ -30,7 +30,14 @@
     StaticQuantConfig,
     TEQConfig,
 )
-from neural_compressor.torch.utils import Mode, is_ipex_imported, logger, register_algo
+from neural_compressor.torch.utils import (
+    Mode,
+    get_quantizer,
+    is_ipex_imported,
+    logger,
+    postprocess_model,
+    register_algo,
+)
 from neural_compressor.torch.utils.constants import PT2E_STATIC_QUANT
 
 
@@ -69,17 +76,9 @@ def rtn_entry(
         "double_quant_group_size": quant_config.double_quant_group_size,
     }
 
-    if getattr(model, "quantizer", False):
-        quantizer = model.quantizer
-    else:
-        quantizer = RTNQuantizer(quant_config=weight_config)
-
+    quantizer = get_quantizer(model, quantizer_cls=RTNQuantizer, quant_config=weight_config)
     model = quantizer.execute(model, mode=mode)
-
-    if getattr(model, "quantizer", False):
-        del model.quantizer
-    else:
-        model.quantizer = quantizer
+    postprocess_model(model, mode, quantizer)
     return model
 
 
@@ -126,15 +125,11 @@ def gptq_entry(
     )
     kwargs.pop("example_inputs")
     logger.warning("lm_head in transformer model is skipped by GPTQ")
-    if getattr(model, "quantizer", False):
-        quantizer = model.quantizer
-    else:
-        quantizer = GPTQuantizer(quant_config=weight_config)
+
+    quantizer = get_quantizer(model, quantizer_cls=GPTQuantizer, quant_config=weight_config)
     model = quantizer.execute(model, mode=mode, *args, **kwargs)
-    if getattr(model, "quantizer", False):
-        del model.quantizer
-    else:
-        model.quantizer = quantizer
+    postprocess_model(model, mode, quantizer)
+
     return model
 
 
@@ -180,17 +175,10 @@ def static_quant_entry(
     inplace = kwargs.get("inplace", True)
     assert example_inputs is not None, "Please provide example_inputs for static quantization."
 
-    if getattr(model, "quantizer", False):
-        quantizer = model.quantizer
-    else:
-        quantizer = StaticQuantQuantizer(quant_config=quant_config_mapping)
-
+    quantizer = get_quantizer(model, quantizer_cls=StaticQuantQuantizer, quant_config=quant_config_mapping)
     model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace)
+    postprocess_model(model, mode, quantizer)
 
-    if getattr(model, "quantizer", False):
-        del model.quantizer
-    else:
-        model.quantizer = quantizer
     return model
 
 
@@ -323,11 +311,7 @@ def awq_quantize_entry(
     example_inputs = kwargs.get("example_inputs", None)
    assert example_inputs is not None, "Please provide example_inputs for AWQ quantization."
 
-    if getattr(model, "quantizer", False):
-        quantizer = model.quantizer
-    else:
-        quantizer = AWQQuantizer(quant_config=weight_config)
-
+    quantizer = get_quantizer(model, quantizer_cls=AWQQuantizer, quant_config=weight_config)
     model = quantizer.execute(
         model,
         mode=mode,
@@ -340,11 +324,8 @@ def awq_quantize_entry(
         return_int=return_int,
         use_full_range=use_full_range,
     )
+    postprocess_model(model, mode, quantizer)
 
-    if getattr(model, "quantizer", False):
-        del model.quantizer
-    else:
-        model.quantizer = quantizer
     return model
 
 
@@ -386,10 +367,18 @@ def teq_quantize_entry(
     absorb_to_layer = quant_config.absorb_to_layer
     folding = quant_config.folding
     assert isinstance(model, torch.nn.Module), "only support torch module"
-    quantizer = TEQuantizer(
-        quant_config=weight_config, folding=folding, absorb_to_layer=absorb_to_layer, example_inputs=example_inputs
+
+    quantizer = get_quantizer(
+        model,
+        quantizer_cls=TEQuantizer,
+        quant_config=weight_config,
+        folding=folding,
+        absorb_to_layer=absorb_to_layer,
+        example_inputs=example_inputs,
     )
     model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace)
+    postprocess_model(model, mode, quantizer)
+
     return model
 
 
@@ -436,35 +425,33 @@ def autoround_quantize_entry(
     scale_dtype = quant_config.scale_dtype
 
     kwargs.pop("example_inputs")
-    if getattr(model, "quantizer", False):
-        quantizer = model.quantizer
-    else:
-        quantizer = AutoRoundQuantizer(
-            weight_config=weight_config,
-            enable_full_range=enable_full_range,
-            batch_size=batch_size,
-            lr_scheduler=lr_scheduler,
-            use_quant_input=use_quant_input,
-            enable_minmax_tuning=enable_minmax_tuning,
-            lr=lr,
-            minmax_lr=minmax_lr,
-            low_gpu_mem_usage=low_gpu_mem_usage,
-            iters=iters,
-            seqlen=seqlen,
-            n_samples=n_samples,
-            sampler=sampler,
-            seed=seed,
-            n_blocks=n_blocks,
-            gradient_accumulate_steps=gradient_accumulate_steps,
-            not_use_best_mse=not_use_best_mse,
-            dynamic_max_gap=dynamic_max_gap,
-            scale_dtype=scale_dtype,
-        )
+
+    quantizer = get_quantizer(
+        model,
+        quantizer_cls=AutoRoundQuantizer,
+        quant_config=weight_config,
+        enable_full_range=enable_full_range,
+        batch_size=batch_size,
+        lr_scheduler=lr_scheduler,
+        use_quant_input=use_quant_input,
+        enable_minmax_tuning=enable_minmax_tuning,
+        lr=lr,
+        minmax_lr=minmax_lr,
+        low_gpu_mem_usage=low_gpu_mem_usage,
+        iters=iters,
+        seqlen=seqlen,
+        n_samples=n_samples,
+        sampler=sampler,
+        seed=seed,
+        n_blocks=n_blocks,
+        gradient_accumulate_steps=gradient_accumulate_steps,
+        not_use_best_mse=not_use_best_mse,
+        dynamic_max_gap=dynamic_max_gap,
+        scale_dtype=scale_dtype,
+    )
     model = quantizer.execute(model=model, mode=mode, *args, **kwargs)
-    if getattr(model, "quantizer", False):
-        del model.quantizer
-    else:
-        model.quantizer = quantizer
+    postprocess_model(model, mode, quantizer)
+
     logger.info("AutoRound quantization done.")
     return model
 
@@ -482,17 +469,11 @@ def hqq_entry(
     from neural_compressor.torch.algorithms.weight_only.hqq import HQQuantizer
 
     logger.info("Quantize model with the HQQ algorithm.")
-    if getattr(model, "quantizer", False):
-        quantizer = model.quantizer
-    else:
-        quantizer = HQQuantizer(quant_config=configs_mapping)
 
+    quantizer = get_quantizer(model, quantizer_cls=HQQuantizer, quant_config=configs_mapping)
     model = quantizer.execute(model, mode=mode)
+    postprocess_model(model, mode, quantizer)
 
-    if getattr(model, "quantizer", False):
-        del model.quantizer
-    else:
-        model.quantizer = quantizer
     return model
 
 
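Every entry above now reduces to the same three-step shape. A condensed sketch of that shape, not a verbatim excerpt from any single entry; the registration decorator and config handling are omitted, and quantizer_cls stands in for RTNQuantizer, GPTQuantizer, AWQQuantizer, and so on.

# Sketch of the recurring pattern introduced by this commit.
from neural_compressor.torch.utils import Mode, get_quantizer, postprocess_model


def some_algo_entry(model, weight_config, quantizer_cls, mode=Mode.QUANTIZE, *args, **kwargs):
    # Reuse the quantizer cached on the model by a previous prepare() call,
    # or build a fresh one from the config mapping.
    quantizer = get_quantizer(model, quantizer_cls=quantizer_cls, quant_config=weight_config)
    model = quantizer.execute(model, mode=mode, *args, **kwargs)
    # Cache the quantizer on the model in prepare mode; drop it in convert/quantize mode.
    postprocess_model(model, mode, quantizer)
    return model
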
neural_compressor/torch/quantization/quantize.py

Lines changed: 2 additions & 4 deletions
@@ -91,7 +91,7 @@ def prepare(
     quant_config: BaseConfig,
     inplace: bool = True,
     example_inputs: Any = None,
-):  # pragma: no cover
+):
     """Prepare the model for calibration.
 
     Insert observers into the model so that it can monitor the input and output tensors during calibration.
@@ -105,7 +105,6 @@
     Returns:
         prepared and calibrated module.
     """
-    # TODO: remove '# pragma: no cover' once CI test can cover this function
     prepared_model = model if inplace else copy.deepcopy(model)
     registered_configs = config_registry.get_cls_configs()
     if isinstance(quant_config, dict):
@@ -148,7 +147,7 @@ def convert(
     model: torch.nn.Module,
     quant_config: BaseConfig = None,
     inplace: bool = True,
-):  # pragma: no cover
+):
     """Convert the prepared model to a quantized model.
 
     Args:
@@ -159,7 +158,6 @@
     Returns:
         The quantized model.
     """
-    # TODO: remove '# pragma: no cover' once CI test can cover this function
     q_model = model if inplace else copy.deepcopy(model)
 
     # TODO: Optimize the check for prepared flag after adding HQT FP8 Quant

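With the pragma markers gone, prepare() and convert() are exercised by CI. A hedged sketch of the user-level flow the two functions implement; RTNConfig and the toy model are illustrative stand-ins, assuming both functions and the config class are exported from neural_compressor.torch.quantization as in the 3.x API (RTN itself needs no calibration data, but the call order is the same for algorithms that do).

# Sketch of the prepare -> (calibrate) -> convert flow.
import torch
from neural_compressor.torch.quantization import RTNConfig, convert, prepare

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))

prepared = prepare(model, quant_config=RTNConfig())  # insert observers / attach the quantizer
# ... feed representative calibration data through 'prepared' here when the algorithm requires it ...
quantized = convert(prepared)  # produce the quantized model, dropping the cached quantizer
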
neural_compressor/torch/utils/utility.py

Lines changed: 41 additions & 0 deletions
@@ -137,6 +137,47 @@ class Mode(Enum):
     QUANTIZE = "quantize"
 
 
+def get_quantizer(model, quantizer_cls, quant_config=None, *args, **kwargs):
+    """Get the quantizer.
+
+    Initialize a quantizer or get `quantizer` attribute from model.
+
+    Args:
+        model (torch.nn.Module): pytorch model.
+        quantizer_cls (Quantizer): quantizer class of a specific algorithm.
+        quant_config (dict, optional): Specifies how to apply the algorithm on the given model.
+            Defaults to None.
+
+    Returns:
+        quantizer object.
+    """
+    if not hasattr(model, "quantizer"):
+        quantizer = quantizer_cls(quant_config=quant_config, *args, **kwargs)
+        return quantizer
+    else:
+        return model.quantizer
+
+
+def postprocess_model(model, mode, quantizer):
+    """Process `quantizer` attribute of model according to current phase.
+
+    In `prepare` phase, the `quantizer` is set as an attribute of the model
+    to avoid redundant initialization during `convert` phase.
+
+    In 'convert' or 'quantize' phase, the unused `quantizer` attribute is removed.
+
+    Args:
+        model (torch.nn.Module): pytorch model.
+        mode (Mode): The mode of current phase, including 'prepare', 'convert' and 'quantize'.
+        quantizer (Quantizer): quantizer object.
+    """
+    if mode == Mode.PREPARE:
+        model.quantizer = quantizer
+    elif mode == Mode.CONVERT or mode == Mode.QUANTIZE:
+        if getattr(model, "quantizer", False):
+            del model.quantizer
+
+
 def create_quant_spec_from_config(dtype, sym, granularity, algo) -> QuantizationSpec:
     dtype_mapping: Dict[str, torch.dtype] = {"int8": torch.int8, "uint8": torch.uint8}
     qscheme_mapping = {

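A hedged, self-contained sketch of how the two new helpers cooperate across phases; DummyQuantizer exists only for illustration, its config dict is arbitrary, and the import path follows the one used in algorithm_entry.py above.

# Sketch only: DummyQuantizer stands in for any algorithm-specific Quantizer subclass.
import torch

from neural_compressor.torch.utils import Mode, get_quantizer, postprocess_model


class DummyQuantizer:
    def __init__(self, quant_config=None):
        self.quant_config = quant_config


model = torch.nn.Linear(4, 4)

# Prepare phase: a fresh quantizer is created and cached on the model.
q1 = get_quantizer(model, quantizer_cls=DummyQuantizer, quant_config={"bits": 4})
postprocess_model(model, Mode.PREPARE, q1)
assert model.quantizer is q1

# Convert phase: the cached quantizer is reused, then the attribute is cleaned up.
q2 = get_quantizer(model, quantizer_cls=DummyQuantizer)
assert q2 is q1
postprocess_model(model, Mode.CONVERT, q2)
assert not hasattr(model, "quantizer")
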
test/3x/torch/quantization/weight_only/test_autoround.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def test_quantizer(self):
                 "sym": False,
             }
         }
-        quantizer = AutoRoundQuantizer(weight_config=weight_config)
+        quantizer = AutoRoundQuantizer(quant_config=weight_config)
         fp32_model = gpt_j_model
 
         # quantizer execute

test/3x/torch/quantization/weight_only/test_mixed_algos.py

Lines changed: 5 additions & 2 deletions
@@ -10,8 +10,11 @@
 
 
 def run_fn(model):
-    model(torch.tensor([[10, 20, 30]], dtype=torch.long))
-    model(torch.tensor([[40, 50, 60]], dtype=torch.long))
+    # GPTQ uses ValueError to reduce computation when collecting input data of the first block
+    # It's special for UTs, no need to add this wrapper in examples.
+    with pytest.raises(ValueError):
+        model(torch.tensor([[10, 20, 30]], dtype=torch.long))
+        model(torch.tensor([[40, 50, 60]], dtype=torch.long))
 
 
 class TestMixedTwoAlgo:

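The pytest.raises wrapper is a unit-test convenience for GPTQ's early-exit ValueError; in user code the calibration function simply feeds samples, as the removed lines did:

# What a user-facing calibration function looks like without the test-only wrapper
# (this mirrors the two removed lines above).
import torch

def run_fn(model):
    model(torch.tensor([[10, 20, 30]], dtype=torch.long))
    model(torch.tensor([[40, 50, 60]], dtype=torch.long))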