
Commit 1c426a0

Support static_groups options in GPTQ API (#1478)
Signed-off-by: YIYANGCAI <[email protected]>
1 parent ab72037 commit 1c426a0

File tree

5 files changed: +38 −6 lines changed


docs/source/quantization_weight_only.md

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@ Notes:
 | pad_max_length | 2048 | Whether to align calibration data to a fixed length. This value should not exceed model's acceptable sequence length. Please refer to model's config json to find out this value.|
 | use_max_length | False | Whether to align all calibration data to fixed length, which equals to pad_max_length. |
 | block_size | 128 | Execute GPTQ quantization per block, block shape = [$C_{out}$, block_size] |
+| static_groups | False | Whether to calculate group-wise quantization parameters in advance. This option mitigates the extra computational cost introduced by act_order. |

 **Note:** Neural compressor provides `Unsigned integer for asymmetric quantization` and `Signed integer for symmetric quantization`. Please follow the below section to compress the low bit data type for saving.
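For orientation, here is a minimal sketch of how the new option could be enabled through the weight-only PTQ API that this document describes. Everything except the `gptq_args.static_groups` recipe key (model, dataloader, bit width, group size, op pattern) is an illustrative assumption, not part of this commit:

from neural_compressor import PostTrainingQuantConfig, quantization

# Hypothetical weight-only GPTQ config; only the static_groups recipe key is new in this commit.
conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={
        ".*": {  # placeholder pattern for all supported linear/matmul ops
            "weight": {"bits": 4, "group_size": 128, "scheme": "asym", "algorithm": "GPTQ"},
        },
    },
    recipes={
        "gptq_args": {
            "act_order": True,      # activation-order rearrangement
            "static_groups": True,  # precompute group-wise quantization parameters
        },
    },
)
q_model = quantization.fit(model, conf, calib_dataloader=dataloader)  # model/dataloader assumed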

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py

Lines changed: 5 additions & 1 deletion
@@ -78,6 +78,7 @@
 this should align with your model config, \
 and your dataset builder args: args.pad_max_length')
 parser.add_argument('--gptq_debug', action='store_true', help='Whether to use debug model ')
+parser.add_argument('--gptq_static_groups', action='store_true', help='Use determined group to do quantization')
 # ==============code generation args===========
 parser.add_argument("--code_generation", action="store_true")
 parser.add_argument("--n_samples", default=200, type=int)
@@ -277,7 +278,8 @@ def calib_func(prepared_model):
 'block_size': args.gptq_block_size,
 'nsamples': args.gptq_nsamples,
 'use_max_length': args.gptq_use_max_length,
-'pad_max_length': args.gptq_pad_max_length
+'pad_max_length': args.gptq_pad_max_length,
+'static_groups': args.gptq_static_groups,
 }
 # GPTQ: use assistive functions to modify calib_dataloader and calib_func
 # TEQ: set calib_func=None, use default training func as calib_func
@@ -293,6 +295,7 @@ def calib_func(prepared_model):
 
 # for test on various models, keep the code of directly call gptq_quantize
 if args.gptq_debug:
+
 from neural_compressor.adaptor.torch_utils.weight_only import gptq_quantize
 
 gptq_conf = {
@@ -301,6 +304,7 @@ def calib_func(prepared_model):
 'group_size': args.woq_group_size, # -1 (per-channel)
 'sym': (args.woq_scheme == "sym"),
 'act_order': args.gptq_actorder,
+'static_groups': args.gptq_static_groups,
 }
 }
 q_model_gptq_debug, gptq_config = gptq_quantize(
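For reference, a sketch of the debug path that the `--gptq_static_groups` flag feeds into: calling `gptq_quantize` directly with a per-layer weight config. The layer-name pattern, the bit width, and any keyword names beyond those visible in this diff are assumptions:

from neural_compressor.adaptor.torch_utils.weight_only import gptq_quantize

# Hypothetical direct call mirroring the gptq_debug branch above.
weight_config = {
    ".*": {                     # placeholder layer-name pattern
        "wbits": 4,
        "group_size": 128,
        "sym": False,
        "act_order": True,
        "static_groups": True,  # option introduced by this commit
    },
}
q_model_gptq_debug, gptq_config = gptq_quantize(
    model,                        # assumed FP32 model
    weight_config=weight_config,
    dataloader=calib_dataloader,  # assumed calibration dataloader
)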

neural_compressor/adaptor/pytorch.py

Lines changed: 1 addition & 0 deletions
@@ -4709,6 +4709,7 @@ def gptq_quantize(self, model, tune_cfg, dataloader):
 "percdamp": self.recipes["gptq_args"].get("percdamp", 0.01),
 "act_order": self.recipes["gptq_args"].get("act_order", False),
 "block_size": self.recipes["gptq_args"].get("block_size", True),
+"static_groups": self.recipes["gptq_args"].get("static_groups", False),
 }
 nsamples = self.recipes["gptq_args"].get("nsamples", 128)
 use_max_length = self.recipes["gptq_args"].get("use_max_length", False)

neural_compressor/adaptor/torch_utils/gptq.py

Lines changed: 30 additions & 5 deletions
@@ -232,6 +232,7 @@ def __init__(
 self.percdamp_default = 0.01
 self.sym_default = False
 self.act_order_default = False
+self.static_groups_default = False
 self.perchannel_default = True
 self.mse_default = False
 self.check_layer_config()
@@ -406,6 +407,9 @@ def check_layer_config(self):
 tmp_weight_config[name]["percdamp"] = self.weight_config.get("pecdamp", self.percdamp_default)
 tmp_weight_config[name]["sym"] = self.weight_config.get("sym", self.sym_default)
 tmp_weight_config[name]["act_order"] = self.weight_config.get("act_order", self.act_order_default)
+tmp_weight_config[name]["static_groups"] = self.weight_config.get(
+    "static_groups", self.static_groups_default
+)
 tmp_weight_config[name]["perchannel"] = self.weight_config.get("perchannel", self.perchannel_default)
 tmp_weight_config[name]["mse"] = self.weight_config.get("mse", self.mse_default)
 self.weight_config = tmp_weight_config
@@ -417,6 +421,9 @@ def check_layer_config(self):
 self.weight_config[layer_name]["percdamp"] = config.get("pecdamp", self.percdamp_default)
 self.weight_config[layer_name]["sym"] = config.get("sym", self.sym_default)
 self.weight_config[layer_name]["act_order"] = config.get("act_order", self.act_order_default)
+self.weight_config[layer_name]["static_groups"] = config.get(
+    "static_groups", self.static_groups_default
+)
 self.weight_config[layer_name]["perchannel"] = config.get("perchannel", self.perchannel_default)
 self.weight_config[layer_name]["mse"] = config.get("mse", self.mse_default)

@@ -631,6 +638,7 @@ def tmp(_, inp, out):
 percdamp=weight_config_this_layer["percdamp"],
 groupsize=weight_config_this_layer["group_size"],
 act_order=weight_config_this_layer["act_order"],
+static_groups=weight_config_this_layer["static_groups"],
 )
 if self.layer_wise:
 from ..torch_utils.layer_wise_quant.utils import (
@@ -745,7 +753,7 @@ def add_batch(self, inp, out):
 # self.H += 2 / self.nsamples * inp.matmul(inp.t())
 self.H += inp.matmul(inp.t()) # H = X*X, which should be a sysm matrix
 
-def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False):
+def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, static_groups=False):
 # W = self.layer.weight.data.clone()
 weight_shape, weight_dtype = W.shape, W.data.dtype
 if isinstance(self.layer, nn.Conv2d):
@@ -765,6 +773,17 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False):
 H[dead, dead] = 1
 W[:, dead] = 0 # such channel makes no contribution to quantization computation
 
+# enable static_groups
+# calculate the quantization parameters for original group in advance.
+if static_groups:
+    import copy
+
+    groups = []
+    for i in range(0, self.columns, groupsize):
+        quantizer = copy.deepcopy(self.quantizer)
+        quantizer.find_params(W[:, i : (i + groupsize)], weight=True)
+        groups.append(quantizer)
+
 # rearrange considering the diag's value
 if act_order:
 perm = torch.argsort(torch.diag(H), descending=True)
@@ -801,10 +820,16 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False):
 d = Hinv1[i, i]
 
 if groupsize != -1:
-    if (i1 + i) % groupsize == 0:
-        self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + groupsize)], weight=True)
-        scale.append(self.quantizer.scale)
-        zero.append(self.quantizer.zero)
+    if not static_groups:
+        if (i1 + i) % groupsize == 0:
+            self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + groupsize)], weight=True)
+            scale.append(self.quantizer.scale)
+            zero.append(self.quantizer.zero)
+    else:
+        idx = i1 + i
+        if act_order:
+            idx = perm[idx]
+        self.quantizer = groups[idx // groupsize]
 
 q = quantize(w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq).flatten()
 Q1[:, i] = q
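To make the intent of the two hunks above easier to see in isolation: with static_groups, group-wise parameters are computed once on the original column order, and when act_order permutes the columns, the inner loop only looks up the precomputed group for the original column index. A self-contained toy sketch of that lookup (the helper names and the max-abs find_params stand-in are assumptions for illustration, not the library code):

import torch

def static_group_lookup(W, perm, groupsize, find_params):
    # Precompute per-group parameters on the ORIGINAL column order, exactly once,
    # as the added block before the act_order reordering does.
    groups = [find_params(W[:, i : i + groupsize]) for i in range(0, W.shape[1], groupsize)]
    # In the permuted column-wise loop, column idx corresponds to original
    # column perm[idx], so its group is groups[perm[idx] // groupsize].
    return [groups[int(perm[idx]) // groupsize] for idx in range(W.shape[1])]

# Toy usage: 16 columns, groups of 4, ordering by diag(W^T W) as a Hessian-diagonal stand-in.
W = torch.randn(8, 16)
perm = torch.argsort(torch.diag(W.t() @ W), descending=True)
per_column_params = static_group_lookup(W, perm, groupsize=4, find_params=lambda g: g.abs().max())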

test/quantization/test_weight_only_quantization.py

Lines changed: 1 addition & 0 deletions
@@ -147,6 +147,7 @@ def __iter__(self):
 "sym": False,
 "percdamp": 0.01,
 "act_order": True,
+"static_groups": True,
 },
 "transformer.h.2.attn.k_proj": {
 "wbits": 3,
