Commit fd96851

Integrate AutoRound v0.3 to 2x (#1926)
Signed-off-by: Kaihui-intel <[email protected]>
1 parent bfa27e4 commit fd96851

File tree: 7 files changed (+64 −33 lines)

.azure-pipelines/scripts/ut/env_setup.sh

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ elif [[ $(echo "${test_case}" | grep -c "tf pruning") != 0 ]]; then
 fi

 if [[ $(echo "${test_case}" | grep -c "api") != 0 ]] || [[ $(echo "${test_case}" | grep -c "adaptor") != 0 ]]; then
-    pip install auto-round
+    pip install git+https://github.com/intel/auto-round.git@24b2e74070f2b4e6f26ff069ec75af74cf5b177c
 fi

 # test deps

neural_compressor/adaptor/pytorch.py

Lines changed: 17 additions & 5 deletions
@@ -4905,13 +4905,13 @@ def autoround_quantize(self, model, tune_cfg, dataloader):
         enable_minmax_tuning = self.recipes["autoround_args"].get("enable_minmax_tuning", True)
         lr = self.recipes["autoround_args"].get("lr", None)
         minmax_lr = self.recipes["autoround_args"].get("minmax_lr", None)
-        low_gpu_mem_usage = self.recipes["autoround_args"].get("low_gpu_mem_usage", True)
+        low_gpu_mem_usage = self.recipes["autoround_args"].get("low_gpu_mem_usage", False)
         iters = self.recipes["autoround_args"].get("iters", 200)
         seqlen = self.recipes["autoround_args"].get("seqlen", 2048)
-        n_samples = self.recipes["autoround_args"].get("n_samples", 512)
+        nsamples = self.recipes["autoround_args"].get("nsamples", 128)
         sampler = self.recipes["autoround_args"].get("sampler", "rand")
         seed = self.recipes["autoround_args"].get("seed", 42)
-        n_blocks = self.recipes["autoround_args"].get("n_blocks", 1)
+        nblocks = self.recipes["autoround_args"].get("nblocks", 1)
         gradient_accumulate_steps = self.recipes["autoround_args"].get("gradient_accumulate_steps", 1)
         not_use_best_mse = self.recipes["autoround_args"].get("not_use_best_mse", False)
         dynamic_max_gap = self.recipes["autoround_args"].get("dynamic_max_gap", -1)
@@ -4922,6 +4922,12 @@ def autoround_quantize(self, model, tune_cfg, dataloader):
         bits = self.recipes["autoround_args"].get("bits", 4)
         group_size = self.recipes["autoround_args"].get("group_size", 128)
         sym = self.recipes["autoround_args"].get("scheme", "asym") == "sym"
+        act_bits = self.recipes["autoround_args"].get("act_bits", 32)
+        act_group_size = self.recipes["autoround_args"].get("act_group_size", None)
+        act_sym = self.recipes["autoround_args"].get("act_sym", None)
+        act_dynamic = self.recipes["autoround_args"].get("act_dynamic", True)
+        multimodal = self.recipes["autoround_args"].get("multimodal", False)
+        use_layer_wise = self.recipes["autoround_args"].get("use_layer_wise", False)

         if dataloader is not None:
             dataset = dataloader
@@ -4944,15 +4950,21 @@ def autoround_quantize(self, model, tune_cfg, dataloader):
             low_gpu_mem_usage=low_gpu_mem_usage,
             iters=iters,
             seqlen=seqlen,
-            n_samples=n_samples,
+            nsamples=nsamples,
             sampler=sampler,
             seed=seed,
-            n_blocks=n_blocks,
+            nblocks=nblocks,
             gradient_accumulate_steps=gradient_accumulate_steps,
             not_use_best_mse=not_use_best_mse,
             dynamic_max_gap=dynamic_max_gap,
             data_type=data_type,
             scale_dtype=scale_dtype,
+            multimodal=multimodal,
+            act_bits=act_bits,
+            act_group_size=act_group_size,
+            act_sym=act_sym,
+            act_dynamic=act_dynamic,
+            use_layer_wise=use_layer_wise,
         )
         return model, autoround_config
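
For context, a minimal usage sketch (not part of this patch) of how the renamed autoround_args recipe keys are passed through Neural Compressor's 2.x post-training quantization API; the model and dataloader objects and the op-type pattern below are illustrative placeholders:

from neural_compressor import PostTrainingQuantConfig, quantization

# Weight-only PTQ config selecting the AUTOROUND algorithm; the recipe keys below
# use the new spellings introduced by this commit (nsamples, nblocks, act_*, use_layer_wise).
conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={
        ".*": {  # apply to all matching op types (illustrative pattern)
            "weight": {"bits": 4, "group_size": 128, "scheme": "asym", "algorithm": "AUTOROUND"},
        },
    },
    recipes={
        "autoround_args": {
            "nsamples": 128,             # formerly n_samples (old default 512)
            "nblocks": 1,                # formerly n_blocks
            "low_gpu_mem_usage": False,  # new default
            "act_bits": 32,              # activation quantization left at full width
            "act_dynamic": True,
            "use_layer_wise": False,
        },
    },
)
q_model = quantization.fit(fp32_model, conf, calib_dataloader=dataloader)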

neural_compressor/adaptor/torch_utils/auto_round.py

Lines changed: 3 additions & 5 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.


-def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=512):
+def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=128):
     """Generate a DataLoader for calibration using specified parameters.

     Args:
@@ -25,14 +25,12 @@ def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42
         split (str, optional): The data split to use. Defaults to None.
         seed (int, optional): The random seed for reproducibility. Defaults to 42.
         bs (int, optional): The batch size. Defaults to 4.
-        n_samples (int, optional): The total number of samples to include. Defaults to 512.
+        nsamples (int, optional): The total number of samples to include. Defaults to 128.

     Returns:
         DataLoader: The DataLoader for the calibrated dataset.
     """
     from auto_round.calib_dataset import get_dataloader  # pylint: disable=E0401

-    dataloader = get_dataloader(
-        tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=seed, bs=bs, n_samples=n_samples
-    )
+    dataloader = get_dataloader(tokenizer, seqlen, dataset_name=dataset_name, seed=seed, bs=bs, nsamples=nsamples)
     return dataloader
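
A brief usage sketch of the renamed helper, mirroring the updated unit test further down; the tokenizer setup is only for illustration:

from transformers import AutoTokenizer
from neural_compressor.adaptor.torch_utils.auto_round import get_dataloader

tokenizer = AutoTokenizer.from_pretrained(
    "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
)
# The keyword is now `nsamples` (default 128) instead of `n_samples` (default 512),
# and `dataset_name` is actually forwarded instead of being hard-coded.
dataloader = get_dataloader(tokenizer, seqlen=2048, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=128)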

neural_compressor/adaptor/torch_utils/weight_only.py

Lines changed: 40 additions & 19 deletions
@@ -694,21 +694,28 @@ def autoround_quantize(
     enable_minmax_tuning: bool = True,
     lr: float = None,
     minmax_lr: float = None,
-    low_gpu_mem_usage: bool = True,
+    low_gpu_mem_usage: bool = False,
     iters: int = 200,
     seqlen: int = 2048,
-    n_samples: int = 512,
+    nsamples: int = 128,
     sampler: str = "rand",
     seed: int = 42,
-    n_blocks: int = 1,
+    nblocks: int = 1,
     gradient_accumulate_steps: int = 1,
     not_use_best_mse: bool = False,
     dynamic_max_gap: int = -1,
     data_type: str = "int",  ##only support int for now
     scale_dtype: str = "fp16",
+    multimodal: bool = False,
+    act_bits: int = 32,
+    act_group_size: int = None,
+    act_sym: bool = None,
+    act_dynamic: bool = True,
+    use_layer_wise: bool = False,
     **kwargs,
 ):
     """Run autoround weight-only quantization.
+
     Args:
         model: The PyTorch model to be quantized.
         tokenizer: An optional tokenizer for processing input data. If none is provided, a dataloader must be supplied.
@@ -717,15 +724,19 @@ def autoround_quantize(
         sym (bool): Whether symmetric quantization is to be used (default is False).
         weight_config (dict): Configuration for weight quantization (default is an empty dictionary).
         weight_config={
-            'layer1':##layer_name
-            {
-                'data_type': 'int',
-                'bits': 4,
-                'group_size': 32,
-                'sym': False
-            }
-            ...
-        }
+                    'layer1':##layer_name
+                    {
+                        'data_type': 'int',
+                        'bits': 4,
+                        'group_size': 32,
+                        'sym': False,
+                        'act_data_type': None,
+                        'act_bits': 32,
+                        'act_sym': None,
+                        'act_dynamic': True,
+                    }
+                    ...,
+                }
         enable_full_range (bool): Whether to enable full range quantization (default is False).
         batch_size (int): Batch size for training (default is 8).
         amp (bool): Whether to use automatic mixed precision (default is True).
@@ -737,20 +748,24 @@ def autoround_quantize(
         enable_minmax_tuning (bool): Whether to enable weight min-max tuning (default is True).
         lr (float): The learning rate (default is None, will be set to 1.0/iters).
         minmax_lr (float): The learning rate for min-max tuning (default is None, it will be set to lr automatically).
-        low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
+        low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False).
         iters (int): Number of iterations (default is 200).
         seqlen (int): Data length of the sequence for tuning (default is 2048).
-        n_samples (int): Number of samples (default is 512).
+        nsamples (int): Number of samples (default is 128).
         sampler (str): The sampling method (default is "rand").
         seed (int): The random seed (default is 42).
-        n_blocks (int): Number of blocks (default is 1).
+        nblocks (int): Number of blocks (default is 1).
         gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
         not_use_best_mse (bool): Whether to use mean squared error (default is False).
         dynamic_max_gap (int): The dynamic maximum gap (default is -1).
         data_type (str): The data type to be used (default is "int").
         scale_dtype (str): The data type of quantization scale to be used (default is "float32"), different kernels
             have different choices.
-
+        multimodal(bool): Enable multimodal model quantization, (default is "False").
+        act_bits (int): Number of bits for activation quantization. Default is 32.
+        act_group_size (int): Group size for activation quantization. Default is None.
+        act_sym (bool): Whether to use symmetric activation quantization. Default is None.
+        act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
     Returns:
         The quantized model.
     """
@@ -762,7 +777,7 @@ def autoround_quantize(
         bits=bits,
         group_size=group_size,
         sym=sym,
-        weight_config=weight_config,
+        layer_config=weight_config,
         enable_full_range=enable_full_range,  ##for symmetric, TODO support later
         batch_size=batch_size,
         amp=amp,
@@ -776,15 +791,21 @@ def autoround_quantize(
         low_gpu_mem_usage=low_gpu_mem_usage,
         iters=iters,
         seqlen=seqlen,
-        n_samples=n_samples,
+        nsamples=nsamples,
         sampler=sampler,
         seed=seed,
-        n_blocks=n_blocks,
+        nblocks=nblocks,
         gradient_accumulate_steps=gradient_accumulate_steps,
         not_use_best_mse=not_use_best_mse,
         dynamic_max_gap=dynamic_max_gap,
         data_type=data_type,  ## only support data_type
         scale_dtype=scale_dtype,
+        multimodal=multimodal,
+        act_bits=act_bits,
+        act_group_size=act_group_size,
+        act_sym=act_sym,
+        act_dynamic=act_dynamic,
+        low_cpu_mem_usage=use_layer_wise,
     )
     qdq_model, weight_config = rounder.quantize()
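
Spelled out as plain Python, the extended per-layer weight_config documented in the docstring above now carries the activation-quantization fields as well (the layer name is illustrative):

weight_config = {
    "model.layers.0.self_attn.k_proj": {  # keyed by layer name
        "data_type": "int",
        "bits": 4,
        "group_size": 32,
        "sym": False,
        # new activation-quantization fields added in this commit
        "act_data_type": None,
        "act_bits": 32,
        "act_sym": None,
        "act_dynamic": True,
    },
    # ... one entry per layer to quantize
}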

neural_compressor/model/torch_model.py

Lines changed: 1 addition & 1 deletion
@@ -609,7 +609,7 @@ def export_compressed_model(

         self.model = pack_model(
             self.model,
-            weight_config=autoround_config,
+            layer_config=autoround_config,
             enable_full_range=enable_full_range,
             compression_dtype=compression_dtype,
             compression_dim=compression_dim,

test/adaptor/pytorch_adaptor/test_weight_only_adaptor_pytorch.py

Lines changed: 1 addition & 1 deletion
@@ -760,7 +760,7 @@ def test_AutoRound_quant(self):
         tokenizer = transformers.AutoTokenizer.from_pretrained(
             "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
         )
-        dataloader = get_dataloader(tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=20)
+        dataloader = get_dataloader(tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=20)
         fp32_model = copy.deepcopy(self.gptj)
         conf = PostTrainingQuantConfig(
             approach="weight_only",

test/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 --find-links https://download.pytorch.org/whl/torch_stable.html
 accelerate==0.21.0
-auto-round
+auto-round @ git+https://github.com/intel/auto-round.git@24b2e74070f2b4e6f26ff069ec75af74cf5b177c
 dynast==1.6.0rc1
 horovod
 intel-extension-for-pytorch
