
Commit 84d7055

Gptq refactor (#1770)
* refactor gptq with prepare and convert API
* fix bug
* update quantizer and model relationship
* fix bug
* add UT for quantize API

Signed-off-by: xin3he <[email protected]>
1 parent 5f3f388 commit 84d7055

File tree

3 files changed: +139 additions, -91 deletions


neural_compressor/torch/algorithms/weight_only/gptq.py

Lines changed: 61 additions & 71 deletions
@@ -183,7 +183,7 @@ def quantize(x, scale, zero, maxq):
     return scale * (q - zero)


-class GPTQuantizer(object):
+class RAWGPTQuantizer(object):
     """Main API for GPTQ algorithm.

     Please refer to:
@@ -195,15 +195,14 @@ def __init__(
         self,
         model,
         weight_config={},
-        dataloader=None,
         nsamples=128,
         use_max_length=True,
         max_seq_length=2048,
         device=None,
         export_compressed_model=False,
         use_layer_wise=False,
         model_path="",
-        run_fn=None,
+        dataloader=None,
         *args,
         **kwargs,
     ):
@@ -226,7 +225,6 @@ def __init__(
             export_compressed_model (bool, optional): Choose return fp32 or int32 model. Defaults to False.
             use_layer_wise (bool): Enables quantize model per layer. Defaults to False.
             model_path (str): Model path that is used to load state_dict per layer.
-            run_fn: a function to run model inference for collecting input information.
             device: cpu or cuda
         """
         # model
@@ -271,9 +269,7 @@ def __init__(
         self.dataloader_original = dataloader
         self.dataloader = []
         self.nsamples = nsamples
-        self.run_fn = run_fn
-        self.run_args = kwargs.get("run_args", None)
-        if run_fn is None:
+        if dataloader is not None:
             self.prepare_dataloader()

     def prepare_dataloader(self):
@@ -489,7 +485,7 @@ def track_hidden_states(self, data):
         return data[0]

     @torch.no_grad()
-    def pre_quantization(self):
+    def prepare_for_calibration(self):
         """Prepare input calibration data and other attributes which are critical for gptq execution."""
         try:
             self.cache_key_arguments = {
@@ -532,34 +528,13 @@ def forward(layer, *args, **kwargs):
         # Step2: modify the first transformer block's forward function to obtain inputs for calibration
         if not self.use_layer_wise:
             self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device)
-        forward_cache = self.gptq_related_blocks["transformers"][0].forward
+        self.forward_cache = self.gptq_related_blocks["transformers"][0].forward
         self.gptq_related_blocks["transformers"][0].forward = partial(
             forward, self.gptq_related_blocks["transformers"][0]
         )

-        # Step3: run forward to obtain calibration datasets
-        logger.info("Collecting calibration inputs...")
-        logger.info("Collecting calibration inputs by running the run_fn provided by user.")
-        if self.run_fn:
-            if self.run_args:
-                self.run_fn(self.model, *self.run_args)
-                accelerator.mark_step()
-            else:
-                self.run_fn(self.model)
-                accelerator.mark_step()
-        else:
-            for batch in tqdm(self.dataloader):
-                if not self.use_layer_wise:
-                    batch = move_input_to_device(batch, self.device)
-                try:
-                    if isinstance(batch, tuple) or isinstance(batch, list):
-                        self.model(batch[0])
-                    elif isinstance(batch, dict):
-                        self.model(**batch)
-                    else:
-                        self.model(batch)
-                except ValueError:
-                    pass
+    @torch.no_grad()
+    def remove_prepare_for_calibration(self):
         # output inp data shape
         logger.info("All calibration data's shape =>")
         # check all hidden_states shape
@@ -571,7 +546,7 @@ def forward(layer, *args, **kwargs):
         logger.info("Done.")

         # Step 4: restore original forward function, relocate layers back to cpu.
-        self.gptq_related_blocks["transformers"][0].forward = forward_cache
+        self.gptq_related_blocks["transformers"][0].forward = self.forward_cache
         if not self.use_layer_wise:
             self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
         for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
@@ -606,7 +581,6 @@ def execute_quantization(self, means=None, stds=None):
         # Step1: prepare quantization (calibration datasets)

         logger.info("Begin ====>")
-        self.pre_quantization()
         model_path = self.model_path

         # Step2: run gptq quantization in a transformer block-wise manner.
@@ -1144,41 +1118,57 @@ def ready(self):
         return torch.all(self.scale != 0)


-def gptq_quantize(
-    model,
-    weight_config={},
-    dataloader=None,
-    nsamples=128,
-    max_seq_length=2048,
-    use_max_length=True,
-    device=None,
-    export_compressed_model=False,
-    use_layer_wise=False,
-    model_path=None,
-    run_fn=None,
-    run_args=None,
-):
-    """Run weight-only quantization with."""
-    # TODO: unify weight_config keys, add docstring, and support default config
-    assert isinstance(model, torch.nn.Module), "only support torch module"
-    if use_layer_wise:
-        assert model_path is not None, "model_path should not be None when use layer wise mode"
-    from .gptq import GPTQuantizer
-
-    gptq_quantizer = GPTQuantizer(
+from neural_compressor.torch.algorithms import Quantizer as INCQuantizer
+
+
+class GPTQuantizer(INCQuantizer):
+    def __init__(self, quant_config={}):
+        """Init a RTNQuantizer object.
+
+        Args:
+            quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
+        """
+        super().__init__(quant_config)
+
+    @torch.no_grad()
+    def prepare(
+        self,
         model,
-        weight_config,
-        dataloader,
-        nsamples,
-        use_max_length,
-        max_seq_length,
-        device,
-        export_compressed_model=export_compressed_model,
-        use_layer_wise=use_layer_wise,
-        model_path=model_path,
-        run_fn=run_fn,
-        run_args=run_args,
-    )
-    fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization()
-    logger.info("GPTQ quantizing done.")
-    return fp32_modified_model, gptq_config
+        nsamples=128,
+        max_seq_length=2048,
+        use_max_length=True,
+        device=None,
+        export_compressed_model=False,
+        use_layer_wise=False,
+        model_path=None,
+        *args,
+        **kwargs,
+    ):
+        """Run weight-only quantization with."""
+        # TODO: unify weight_config keys, add docstring, and support default config
+        assert isinstance(model, torch.nn.Module), "only support torch module"
+        if use_layer_wise:
+            assert model_path is not None, "model_path should not be None when use layer wise mode"
+
+        self.gptq_quantizer = RAWGPTQuantizer(
+            model,
+            weight_config=self.quant_config,
+            nsamples=nsamples,
+            use_max_length=use_max_length,
+            max_seq_length=max_seq_length,
+            device=device,
+            export_compressed_model=export_compressed_model,
+            use_layer_wise=use_layer_wise,
+            model_path=model_path,
+        )
+        self.gptq_quantizer.prepare_for_calibration()
+        return self.gptq_quantizer.model
+
+    @torch.no_grad()
+    def convert(self, model, *args, **kwargs):
+        self.gptq_quantizer.model = model
+        self.gptq_quantizer.remove_prepare_for_calibration()
+        q_model, gptq_config = self.gptq_quantizer.execute_quantization()
+        q_model.gptq_config = gptq_config
+        logger.info("GPTQ quantizing done.")
+        return q_model
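
The last hunk above replaces the one-shot gptq_quantize() helper with a GPTQuantizer class exposing prepare() and convert(). A minimal sketch of driving this class directly is shown below; user_model, weight_config, and calib_dataloader are illustrative placeholders, and the try/except ValueError guard mirrors the internal calibration loop removed earlier in this diff (the hooked first-block forward interrupts execution once its inputs are cached).

import torch

from neural_compressor.torch.algorithms.weight_only.gptq import GPTQuantizer

user_model = ...        # placeholder: an fp32 torch.nn.Module (e.g. a causal LM)
weight_config = {}      # placeholder: per-op GPTQ settings, as built by gptq_entry below
calib_dataloader = []   # placeholder: iterable of calibration batches (dicts of tensors)

quantizer = GPTQuantizer(quant_config=weight_config)

# prepare(): wraps the model in a RAWGPTQuantizer and hooks the first transformer
# block's forward so calibration inputs can be captured; returns the hooked model.
prepared_model = quantizer.prepare(user_model, nsamples=128, max_seq_length=2048)

# Calibration is now driven by the caller rather than an internal run_fn/dataloader.
with torch.no_grad():
    for batch in calib_dataloader:
        try:
            prepared_model(**batch)
        except ValueError:
            pass  # the hooked forward stops execution after caching inputs

# convert(): restores the original forward and runs block-wise GPTQ quantization;
# the collected metadata is attached to the result as q_model.gptq_config.
q_model = quantizer.convert(prepared_model)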

neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 17 additions & 9 deletions
@@ -72,10 +72,14 @@ def rtn_entry(
 @register_algo(GPTQ)
 @torch.no_grad()
 def gptq_entry(
-    model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], GPTQConfig], *args, **kwargs
+    model: torch.nn.Module,
+    configs_mapping: Dict[Tuple[str, callable], GPTQConfig],
+    mode: Mode = Mode.QUANTIZE,
+    *args,
+    **kwargs,
 ) -> torch.nn.Module:
     logger.info("Quantize model with the GPTQ algorithm.")
-    from neural_compressor.torch.algorithms.weight_only.gptq import gptq_quantize
+    from neural_compressor.torch.algorithms.weight_only.gptq import GPTQuantizer

     # rebuild weight_config for gptq_quantize function
     weight_config = {}
@@ -106,12 +110,16 @@ def gptq_entry(
         }
     )
     kwargs.pop("example_inputs")
-    kwargs.pop("mode")  # TODO: will be removed after GPTQ refactoring
-
     logger.warning("lm_head in transformer model is skipped by GPTQ")
-    model, quantization_perm = gptq_quantize(model=model, weight_config=weight_config, *args, **kwargs)
-    # Assign the gptq config as an attribute of model
-    model._gptq_quantization_perm = quantization_perm
+    if getattr(model, "quantizer", False):
+        quantizer = model.quantizer
+    else:
+        quantizer = GPTQuantizer(quant_config=weight_config)
+    model = quantizer.execute(model, mode=mode, *args, **kwargs)
+    if getattr(model, "quantizer", False):
+        del model.quantizer
+    else:
+        model.quantizer = quantizer
     return model


@@ -123,7 +131,7 @@ def static_quant_entry(
     configs_mapping: Dict[Tuple[str, callable], StaticQuantConfig],
     mode: Mode = Mode.QUANTIZE,
     *args,
-    **kwargs
+    **kwargs,
 ) -> torch.nn.Module:
     logger.info("Quantize model with the static quant algorithm.")
     from neural_compressor.torch.algorithms.static_quant import StaticQuantQuantizer
@@ -333,7 +341,7 @@ def autoround_quantize_entry(
     configs_mapping: Dict[Tuple[str, callable], AutoRoundConfig],
     mode: Mode = Mode.QUANTIZE,
     *args,
-    **kwargs
+    **kwargs,
 ) -> torch.nn.Module:
     from neural_compressor.torch.algorithms.weight_only.autoround import AutoRoundQuantizer
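
gptq_entry now threads a Mode through quantizer.execute() and parks the GPTQuantizer instance on model.quantizer between the prepare and convert phases, deleting it once conversion is done. Assuming the framework-level prepare/convert helpers and GPTQConfig exported from neural_compressor.torch.quantization route into this entry with the corresponding modes (an assumption, not shown in this diff), the end-to-end flow would look roughly like the sketch below; user_model and calib_batches are placeholders.

import torch

from neural_compressor.torch.quantization import GPTQConfig, convert, prepare

user_model = ...             # placeholder: fp32 torch.nn.Module
calib_batches = []           # placeholder: calibration data (dicts of tensors)
quant_config = GPTQConfig()  # expanded into the per-op weight_config by gptq_entry

# Prepare phase: a GPTQuantizer is created, the model is hooked for calibration,
# and the quantizer is cached on model.quantizer for the later convert call.
prepared_model = prepare(user_model, quant_config)

# Caller-driven calibration between the two phases (replaces the old run_fn argument).
with torch.no_grad():
    for batch in calib_batches:
        try:
            prepared_model(**batch)
        except ValueError:
            pass  # hooked forward interrupts after caching inputs

# Convert phase: gptq_entry picks up model.quantizer, runs convert(), attaches
# gptq_config to the quantized model, and removes the cached quantizer attribute.
q_model = convert(prepared_model)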

0 commit comments
