
Commit 4261836

add sq doc

Signed-off-by: Cheng, Zixuan <[email protected]>
2 parents 6a828ae + 004af16

File tree: 36 files changed (+845 / -293 lines)

.azure-pipelines/scripts/ut/env_setup.sh

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ elif [[ $(echo "${test_case}" | grep -c "tf pruning") != 0 ]]; then
 fi
 
 if [[ $(echo "${test_case}" | grep -c "api") != 0 ]] || [[ $(echo "${test_case}" | grep -c "adaptor") != 0 ]]; then
-    pip install auto-round
+    pip install git+https://github.com/intel/auto-round.git@ecca5349981044e1278773a251b3fc5c0a11fe7b
 fi
 
 # test deps

docs/source/3x/PT_MXQuant.md

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
Microscaling Quantization
===============

1. [Introduction](#introduction)
2. [Get Started with Microscaling Quantization API](#get-started-with-microscaling-quantization-api)
3. [Examples](#examples)
4. [Reference](#reference)

## Introduction

Numerous breakthroughs have emerged across various fields, such as text analysis, language translation, and chatbot technologies, fueled by the development of large language models (LLMs). Nevertheless, their increasing power comes with the challenge of explosive growth in parameters, which poses obstacles to practical deployment. To balance memory limits against accuracy preservation for AI models, the Microscaling (MX) specification was promoted from the well-known Microsoft Floating Point (MSFP) data type [1, 2]:

<table>
<tr>
<th>Format Name</th>
<th>Element Data type</th>
<th>Element Bits</th>
<th>Scaling Block Size</th>
<th>Scale Data Type</th>
<th>Scale Bits</th>
</tr>
<tr>
<td rowspan="2">MXFP8</td>
<td>FP8 (E5M2)</td>
<td rowspan="2">8</td>
<td rowspan="2">32</td>
<td rowspan="2">E8M0</td>
<td rowspan="2">8</td>
</tr>
<tr>
<td>FP8 (E4M3)</td>
</tr>
<tr>
<td rowspan="2">MXFP6</td>
<td>FP6 (E3M2)</td>
<td rowspan="2">6</td>
<td rowspan="2">32</td>
<td rowspan="2">E8M0</td>
<td rowspan="2">8</td>
</tr>
<tr>
<td>FP6 (E2M3)</td>
</tr>
<tr>
<td>MXFP4</td>
<td>FP4 (E2M1)</td>
<td>4</td>
<td>32</td>
<td>E8M0</td>
<td>8</td>
</tr>
<tr>
<td>MXINT8</td>
<td>INT8</td>
<td>8</td>
<td>32</td>
<td>E8M0</td>
<td>8</td>
</tr>
</table>

At an equivalent accuracy level, the MX data type occupies a smaller area and incurs lower energy costs for multiply-accumulate operations than conventional data types on the same silicon [1].

Neural Compressor seamlessly applies the MX data type to post-training quantization, offering carefully crafted recipes that let users quantize LLMs without sacrificing accuracy. The workflow is shown below.

<a target="_blank" href="../imgs/mx_workflow.png" text-align:left>
  <left>
    <img src="../imgs/mx_workflow.png" alt="Workflow of MX Quant (source [3])" height=120>
  </left>
</a>

The memory and computational limits of LLMs are more severe than those of general neural networks, so our exploration focuses on LLMs first. The following table shows the basic MX quantization recipes in Neural Compressor and enumerates distinctions among various data types. The MX data type replaces the general float scale with powers of two to be more hardware-friendly, and it adopts a granularity between per-channel and per-tensor to balance accuracy and memory consumption.

|             | MX Format | INT8 | FP8 |
|-------------|-----------|------|-----|
| Scale       | $2^{exp}$ | $\frac{MAX}{amax}$ | $\frac{MAX}{amax}$ |
| Zero point  | 0 (None) | $2^{bits - 1}$ or $-min * scale$ | 0 (None) |
| Granularity | per-block (default blocksize is 32) | per-channel or per-tensor | per-tensor |

Here, the exponent (exp) equals torch.floor(torch.log2(amax)), MAX is the maximum representable value of the target data type, amax is the maximum absolute value of the per-block tensor, and min is the minimum value of the tensor used to compute the zero point.
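
For intuition, below is a minimal, illustrative sketch of per-block fake quantization with a shared power-of-two (E8M0-style) scale, in the spirit of MXINT8. This is not Neural Compressor's implementation: the block size of 32, the `2^(exp-6)` element mapping, and the helper name are assumptions for illustration only.

```python
import torch


def mx_int8_like_quant_dequant(x: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    """Toy fake-quantization: each block of `block_size` values shares one power-of-two scale."""
    # Assumes x.numel() is divisible by block_size; padding/reshaping logic is omitted.
    orig_shape = x.shape
    blocks = x.reshape(-1, block_size)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-30)  # per-block max magnitude
    exp = torch.floor(torch.log2(amax))        # shared exponent, cf. the formula above
    scale = torch.pow(2.0, exp - 6)            # 2^(exp-6) maps each block into the int8 range
    q = torch.clamp(torch.round(blocks / scale), -127, 127)  # int8-like element values
    return (q * scale).reshape(orig_shape)     # dequantize so the error can be inspected


w = torch.randn(4, 64)
print((w - mx_int8_like_quant_dequant(w)).abs().max())  # quantization error of the toy scheme
```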

## Get Started with Microscaling Quantization API

To get a model quantized with Microscaling Data Types, users can use the Microscaling Quantization API as follows.

```python
from neural_compressor.torch.quantization import MXQuantConfig, quantize

quant_config = MXQuantConfig(w_dtype=args.w_dtype, act_dtype=args.act_dtype, weight_only=args.woq)
user_model = quantize(model=user_model, quant_config=quant_config)
```

## Examples

- PyTorch [huggingface models](/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/mx)


## Reference

[1]: Darvish Rouhani, Bita, et al. "Pushing the limits of narrow precision inferencing at cloud scale with Microsoft Floating Point." Advances in Neural Information Processing Systems 33 (2020): 10271-10281.

[2]: OCP Microscaling Formats (MX) Specification.

[3]: Rouhani, Bita Darvish, et al. "Microscaling Data Formats for Deep Learning." arXiv preprint arXiv:2310.10537 (2023).

docs/source/3x/PT_SmoothQuant.md

Lines changed: 4 additions & 6 deletions
@@ -341,27 +341,25 @@ To set a fixed alpha for the entire model, users can follow this example:
 ```python
 from neural_compressor.torch.quantization import SmoothQuantConfig, convert, prepare
 
-quant_config = SmoothQuantConfig(alpha=0.5, folding=False)
-example_inputs = torch.zeros([1, 3])
-
 
 def run_fn(model):
     model(example_inputs)
 
 
+quant_config = SmoothQuantConfig(alpha=0.5, folding=False)
 prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
 run_fn(prepared_model)
 q_model = convert(prepared_model)
 ```
 `SmoothQuantConfig` description:
 
-"alpha": a float value. Default is 0.5.
+`alpha`: a float value. Default is 0.5.
 
-"folding": whether to fold mul into the previous layer, where mul is required to update the input distribution during smoothing.
+`folding`: whether to fold mul into the previous layer, where mul is required to update the input distribution during smoothing.
     - True: Fold inserted mul into the previous layer. IPEX will only insert mul for layers can do folding.
     - False: Allow inserting mul to update the input distribution and no folding. IPEX (version>=2.1) can fuse inserted mul automatically. For Stock PyTorch, setting folding=False will convert the model to a QDQ model.
 
-To get more information, please refer to [examples](https://github.com/intel/neural-compressor/blob/master/examples/3.x/pytorch/nlp/huggingface_models/language-modeling/quantization/llm).
+To get more information, please refer to [examples](https://github.com/intel/neural-compressor/blob/master/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm).
 
 
 ## Supported Framework Matrix
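
As a side note on the `alpha` parameter described above: in the SmoothQuant formulation, `alpha` balances how much quantization difficulty is migrated from activations to weights. Below is a minimal sketch of the per-channel smoothing scale from the SmoothQuant paper, not the Neural Compressor/IPEX implementation; the tensor shapes and values are toy placeholders.

```python
import torch

alpha = 0.5
act_amax = torch.rand(768) + 0.1  # per-input-channel activation max |X_j| from calibration (toy values)
w_amax = torch.rand(768) + 0.1    # per-input-channel weight max |W_j| (toy values)

# s_j = |X_j|^alpha / |W_j|^(1 - alpha); activations are divided by s_j and weights are multiplied
# by s_j. This is the "mul" that folding either merges into the previous layer (True) or leaves
# in place for the runtime to fuse (False).
smooth_scale = act_amax.pow(alpha) / w_amax.pow(1 - alpha)
```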

docs/source/quantization_weight_only.md

Lines changed: 2 additions & 0 deletions
@@ -87,6 +87,8 @@ Notes:
 | use_max_length | False | Whether to align all calibration data to fixed length, which equals to pad_max_length. |
 | block_size | 128 | Execute GPTQ quantization per block, block shape = [$C_{out}$, block_size] |
 | static_groups | False | Whether to calculate group wise quantization parameters in advance. This option mitigate actorder's extra computational requirements |
+| true_sequential | False | Whether to quantize layers within a transformer block in their original order. This can lead to higher accuracy but slower overall quantization process. |
+| lm_head | False | Whether to quantize the lm_head (linear layer related to prediction in the end of the language models). |
 
 **Note:** Neural compressor provides `Unsigned integer for asymmetric quantization` and `Signed integer for symmetric quantization`. Please follow the below section to compress the low bit data type for saving.
 
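
To connect the two new options with the example script updated later in this commit, here is a hedged sketch of a GPTQ recipe dict that carries them. Only the key names come from the table above and the `run_clm_no_trainer.py` diff below; the values and the `PostTrainingQuantConfig` wiring are illustrative assumptions.

```python
# Sketch only: key names follow the table above and the example script's weight-only recipes;
# the chosen values and the recipes={"gptq_args": ...} wiring are assumptions for illustration.
gptq_args = {
    "use_max_length": False,
    "static_groups": False,
    "true_sequential": True,  # quantize layers within a transformer block in their original order
    "lm_head": True,          # additionally quantize the final lm_head linear layer
}
# e.g. conf = PostTrainingQuantConfig(approach="weight_only", recipes={"gptq_args": gptq_args})
```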

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py

Lines changed: 3 additions & 1 deletion
@@ -11,7 +11,7 @@
 import datasets
 from torch.nn.functional import pad
 from torch.utils.data import DataLoader
-from transformers import AutoModelForCausalLM, AutoModel, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -377,7 +377,9 @@ def run_fn(model):
 
     from neural_compressor.torch.quantization import load
     tokenizer = AutoTokenizer.from_pretrained(args.model)
+    config = AutoConfig.from_pretrained(args.model)
     user_model = load(os.path.abspath(os.path.expanduser(args.output_dir)))
+    setattr(user_model, "config", config)
 else:
     user_model, tokenizer = get_user_model()
 

examples/onnxrt/nlp/huggingface_model/text_generation/llama/quantization/ptq_static/main.py

Lines changed: 14 additions & 10 deletions
@@ -26,7 +26,6 @@
 import onnxruntime as ort
 from torch.nn.functional import pad
 from torch.utils.data import DataLoader
-from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
 from optimum.onnxruntime import ORTModelForCausalLM
 from transformers import LlamaConfig, LlamaTokenizer
 
@@ -198,28 +197,33 @@ def replace_architectures(json_path):
         json.dump(data, file, indent=4)
 
 def eval_func(model):
+    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
+
     model_dir = model
     if isinstance(model, str) and model.endswith(".onnx"):
         model_dir = os.path.dirname(model)
 
     replace_architectures(os.path.join(model_dir, "config.json"))
 
-    results = evaluate(
-        model="hf-causal",
-        model_args="pretrained=" + model_dir + ",tokenizer="+ args.tokenizer,
+    eval_args = LMEvalParser(
+        model="hf",
+        model_args="pretrained=" + model_dir + ",tokenizer=" + args.tokenizer + ",model_format=onnx",
         batch_size=args.batch_size,
-        tasks=args.tasks,
-        model_format="onnx",
+        tasks=','.join(args.tasks),
+        device="cpu",
     )
+    results = evaluate(eval_args)
 
     eval_acc = 0
     for task_name in args.tasks:
         if task_name == "wikitext":
-            print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]))
-            eval_acc += results["results"][task_name]["word_perplexity"]
+            print("Accuracy for %s is: %s" %
+                  (task_name, results["results"][task_name]["word_perplexity,none"]))
+            eval_acc += results["results"][task_name]["word_perplexity,none"]
         else:
-            print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]))
-            eval_acc += results["results"][task_name]["acc"]
+            print("Accuracy for %s is: %s" %
+                  (task_name, results["results"][task_name]["acc,none"]))
+            eval_acc += results["results"][task_name]["acc,none"]
 
     if len(args.tasks) != 0:
         eval_acc /= len(args.tasks)
examples/onnxrt/nlp/huggingface_model/text_generation/llama/quantization/ptq_static/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ onnxruntime-extensions; python_version < '3.11'
77
datasets
88
optimum
99
evaluate
10-
intel-extension-for-transformers
10+
intel-extension-for-transformers >= 1.4.1
1111
peft
12-
git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2
12+
lm-eval==0.4.2

examples/onnxrt/nlp/huggingface_model/text_generation/llama/quantization/weight_only/main.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import onnxruntime as ort
2828
from torch.nn.functional import pad
2929
from torch.utils.data import DataLoader
30-
from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
3130
from optimum.onnxruntime import ORTModelForCausalLM
3231
from transformers import LlamaConfig, LlamaTokenizer
3332

@@ -135,28 +134,33 @@ def replace_architectures(json_path):
135134
json.dump(data, file, indent=4)
136135

137136
def eval_func(model):
137+
from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
138+
138139
model_dir = model
139140
if isinstance(model, str) and model.endswith(".onnx"):
140141
model_dir = os.path.dirname(model)
141142

142143
replace_architectures(os.path.join(model_dir, "config.json"))
143144

144-
results = evaluate(
145-
model="hf-causal",
146-
model_args="pretrained=" + model_dir + ",tokenizer="+ args.tokenizer,
145+
eval_args = LMEvalParser(
146+
model="hf",
147+
model_args="pretrained=" + model_dir + ",tokenizer=" + args.tokenizer + ",model_format=onnx",
147148
batch_size=args.batch_size,
148-
tasks=args.tasks,
149-
model_format="onnx",
149+
tasks=','.join(args.tasks),
150+
device="cpu",
150151
)
152+
results = evaluate(eval_args)
151153

152154
eval_acc = 0
153155
for task_name in args.tasks:
154156
if task_name == "wikitext":
155-
print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]))
156-
eval_acc += results["results"][task_name]["word_perplexity"]
157+
print("Accuracy for %s is: %s" %
158+
(task_name, results["results"][task_name]["word_perplexity,none"]))
159+
eval_acc += results["results"][task_name]["word_perplexity,none"]
157160
else:
158-
print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]))
159-
eval_acc += results["results"][task_name]["acc"]
161+
print("Accuracy for %s is: %s" %
162+
(task_name, results["results"][task_name]["acc,none"]))
163+
eval_acc += results["results"][task_name]["acc,none"]
160164

161165
if len(args.tasks) != 0:
162166
eval_acc /= len(args.tasks)

examples/onnxrt/nlp/huggingface_model/text_generation/llama/quantization/weight_only/requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -7,6 +7,6 @@ onnxruntime-extensions; python_version < '3.11'
 datasets
 optimum
 evaluate
-intel-extension-for-transformers
+intel-extension-for-transformers >= 1.4.1
 peft
-git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2
+lm-eval==0.4.2

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py

Lines changed: 7 additions & 3 deletions
@@ -77,6 +77,8 @@
                         this should align with your model config, \
                         and your dataset builder args: args.pad_max_length')
 parser.add_argument('--gptq_static_groups', action='store_true', help='Use determined group to do quantization')
+parser.add_argument('--gptq_true_sequential', action='store_true', help="Whether to run in true_sequential model.")
+parser.add_argument('--gptq_lm_head', action='store_true', help="Whether to use GPTQ to quantize the output layer of the LLMs.")
 # ==============code generation args===========
 parser.add_argument("--code_generation", action="store_true")
 parser.add_argument("--n_samples", default=200, type=int)
@@ -278,7 +280,8 @@ def calib_func(prepared_model):
             'use_max_length': args.gptq_use_max_length,
             'pad_max_length': args.gptq_pad_max_length,
             'static_groups': args.gptq_static_groups,
-            "enable_mse_search": args.woq_enable_mse_search,
+            "true_sequential": args.gptq_true_sequential,
+            "lm_head": args.gptq_lm_head,
         }
         # GPTQ: use assistive functions to modify calib_dataloader and calib_func
         # TEQ: set calib_func=None, use default training func as calib_func
@@ -340,12 +343,13 @@ def eval_func(model):
 
     if args.ipex:
         user_model = load(os.path.abspath(os.path.expanduser(args.output_dir)))
+        tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
     else:
-        user_model, _ = get_user_model()
+        user_model, tokenizer = get_user_model()
         kwargs = {'weight_only': True} if args.approach == 'weight_only' else {}
         user_model = load(os.path.abspath(os.path.expanduser(args.output_dir)), user_model, **kwargs)
 else:
-    user_model, _ = get_user_model()
+    user_model, tokenizer = get_user_model()
 
 if args.accuracy:
     user_model.eval()
