
Commit 83a4569

Authored by wenhuach21, yintong-lu, and pre-commit-ci[bot]
WAQ refactor[WIP] (#1496)
* fix conflict

Signed-off-by: Lu, Yintong <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Lu, Yintong <[email protected]>
Co-authored-by: Lu, Yintong <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2a86aea commit 83a4569

18 files changed: +2109 −1680 lines

.azure-pipelines/scripts/codeScan/pylint/pylint.sh

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ pip install torch \
 fvcore \
 pymoo \
 onnxruntime_extensions \
+peft \
 tf_slim \
 transformers \
 accelerate \

docs/source/smooth_quant.md

Lines changed: 1 addition & 1 deletion
@@ -304,7 +304,7 @@ In our experiments, an $\alpha$ range of [0.0, 1.0] with a step_size of 0.1 is f
 *fully automated*: users only need to pass a model and dataloader.

 ```python
-from neural_compressor.adaptor.torch_utils.smooth_quant import TorchSmoothQuant
+from neural_compressor.adaptor.torch_utils.waq import TorchSmoothQuant

 sq = TorchSmoothQuant(model, dataloader)
 alpha = "auto" ##alpha could be a float number to disable auto-tuning and enable fixed-value alpha smoothquant.
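The documentation change above only touches the import path. As a minimal sketch (not part of this commit) of how the relocated class might be driven end to end: the toy model, calibration dataloader, and `calib_iter` value below are illustrative assumptions, while `transform()` and its keyword arguments mirror the adaptor call sites updated later in this diff.

```python
# Minimal sketch: exercising TorchSmoothQuant from its new package,
# neural_compressor.adaptor.torch_utils.waq. Model and dataloader are
# toy placeholders; transform() kwargs follow the adaptor call site.
import torch
from torch.utils.data import DataLoader, TensorDataset

from neural_compressor.adaptor.torch_utils.waq import TorchSmoothQuant

model = torch.nn.Sequential(torch.nn.LayerNorm(16), torch.nn.Linear(16, 4))
calib = TensorDataset(torch.randn(32, 16), torch.zeros(32, dtype=torch.long))
dataloader = DataLoader(calib, batch_size=8)

sq = TorchSmoothQuant(model, dataloader)
alpha = "auto"  # or a fixed float to disable auto-tuning
model = sq.transform(alpha=alpha, calib_iter=2)
```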

neural_compressor/adaptor/pytorch.py

Lines changed: 9 additions & 37 deletions
@@ -20,7 +20,7 @@
 import math
 import os
 import re
-from collections import OrderedDict, UserDict, namedtuple
+from collections import OrderedDict, UserDict
 from functools import partial

 import yaml
@@ -1800,7 +1800,7 @@ def smooth_quant(
         assert folding, "IPEX version >= 2.1 is required for SmoothQuant folding=False."

         if not hasattr(self, "sq") or force_re_smooth:
-            from .torch_utils.smooth_quant import TorchSmoothQuant
+            from neural_compressor.adaptor.torch_utils.waq import TorchSmoothQuant

             self.sq = TorchSmoothQuant(
                 model._model, dataloader=dataloader, example_inputs=self.example_inputs, q_func=self.q_func
@@ -1813,17 +1813,18 @@ def smooth_quant(
             kwargs["percentile"] = percentile
         if scales_per_op is not None:
             kwargs["scales_per_op"] = scales_per_op
+        auto_alpha_args["init_alpha"] = default_alpha
         model._model = self.sq.transform(
             alpha=alpha,
             folding=folding,
             calib_iter=calib_iter,
             weight_clip=weight_clip,
-            default_alpha=default_alpha,
             auto_alpha_args=auto_alpha_args,
             **kwargs,
         )
         if self.sq.record_max_info:
             model.sq_max_info = self.sq.max_value_info
+            model.sq_scale_info = self.sq.sq_scale_info
         return model

     def _apply_pre_optimization(self, model, tune_cfg, recover=False):
@@ -1840,7 +1841,7 @@ def _apply_pre_optimization(self, model, tune_cfg, recover=False):
         q_model = model._model
         sq_max_info = model.sq_max_info
         if sq_max_info:
-            from .torch_utils.smooth_quant import TorchSmoothQuant
+            from neural_compressor.adaptor.torch_utils.waq import TorchSmoothQuant

             tsq = TorchSmoothQuant(q_model, None)
             alpha = tune_cfg["recipe_cfgs"]["smooth_quant_args"]["alpha"]
@@ -1876,8 +1877,9 @@ def qdq_quantize(self, model, tune_cfg):
             model: qdq quantized model.
         """
         q_model = model._model
+        from neural_compressor.adaptor.torch_utils.waq import get_module, set_module
+
         from .torch_utils.model_wrapper import QDQLinear, SQLinearWrapper
-        from .torch_utils.smooth_quant import get_module, set_module

         smoothquant_scale_info = {}
         fallback_op_name_list = []
@@ -3317,37 +3319,7 @@ def qdq_quantize(self, model, q_model, tune_cfg, dataloader, q_func):
         inplace = True if self.performance_only else False

         # fetch SmoothQuant scale info from pre-optimized model
-        sq_max_info = model.sq_max_info
-        if sq_max_info:
-            smoothquant_scale_info = {}
-            from .torch_utils.model_wrapper import SQLinearWrapper
-            from .torch_utils.smooth_quant import get_module
-
-            for _, info in sq_max_info.items():
-                alpha = info["alpha"]
-                absorbed_layer = info["absorbed_layer"]
-                input_minmax = info["input_minmax"]
-                # for peft model,lora_B weights is 0.
-                weight_max = info["weight_max"]
-                if self.sq.weight_clip:
-                    weight_max = weight_max.clamp(min=1e-5)
-                abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1]))
-                input_power = torch.pow(abs_input_max, alpha)
-                weight_power = torch.pow(weight_max, 1 - alpha)
-                scale = torch.clip(input_power / weight_power, min=1e-5)
-                for op_name in absorbed_layer:
-                    module = copy.deepcopy(get_module(q_model._model, op_name))
-                    new_module = SQLinearWrapper(module, 1.0 / scale, input_minmax, alpha)
-                    weight_scale = new_module._get_weight_scale()
-                    smoothquant_scale_info[op_name] = {
-                        "alpha": new_module.alpha,
-                        "input_scale_for_mul": new_module.input_scale,
-                        "input_scale_after_mul": new_module.scale,
-                        "input_zero_point_after_mul": new_module.zero_point,
-                        "input_dtype": new_module.dtype,
-                        "weight_scale_after_mul": weight_scale,
-                    }
-                    logger.debug(f"Current SmoothQuant alpha of {op_name} is {alpha}")
+        smoothquant_scale_info = model.sq_scale_info

         # Check save_qconf_summary part is a workaround for IPEX bug.
         # Sometimes the prepared model from get_op_capablitiy loss this attribute
@@ -4795,7 +4767,7 @@ def teq_quantize(self, model, tune_cfg, dataloader, calib_func):

         supported_layers = ["Linear"]
         if folding: # pragma: no cover
-            from .torch_utils.smooth_quant import GraphTrace
+            from neural_compressor.adaptor.torch_utils.waq import GraphTrace

             tg = GraphTrace()
             absorb_to_layer, _ = tg.get_absorb_to_layer(model, self.example_inputs, supported_layers)
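For reference, the block removed from `qdq_quantize` above computed the per-layer SmoothQuant scale directly from the recorded min/max statistics; that work now happens inside `TorchSmoothQuant` and is handed over via `model.sq_scale_info`. A standalone sketch of the same scale math follows; the helper name and the sample tensors are illustrative only.

```python
# Sketch of the scale math the removed block performed inline:
# scale = clip(|x|_max**alpha / w_max**(1 - alpha), min=1e-5)
import torch

def smoothquant_scale(input_minmax, weight_max, alpha, weight_clip=True):
    # For PEFT models, lora_B weights can be all zero, so optionally clip them.
    if weight_clip:
        weight_max = weight_max.clamp(min=1e-5)
    abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1]))
    input_power = torch.pow(abs_input_max, alpha)
    weight_power = torch.pow(weight_max, 1 - alpha)
    return torch.clip(input_power / weight_power, min=1e-5)

scale = smoothquant_scale(
    input_minmax=(torch.tensor([-3.0, -1.0]), torch.tensor([4.0, 2.0])),
    weight_max=torch.tensor([0.5, 0.0]),
    alpha=0.5,
)
```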

neural_compressor/adaptor/torch_utils/awq.py

Lines changed: 1 addition & 2 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.

 import copy
-from functools import partial

 import torch

@@ -25,10 +24,10 @@
     get_hidden_states,
     get_module_input_output,
 )
+from neural_compressor.adaptor.torch_utils.waq import set_module

 from ...utils import logger
 from .model_wrapper import MulLinear
-from .smooth_quant import model_forward, set_module


 def _get_absorb_per_block(model, example_inputs, folding=False, weight_config={}):

neural_compressor/adaptor/torch_utils/layer_wise_quant/quantize.py

Lines changed: 1 addition & 1 deletion
@@ -24,10 +24,10 @@
 from torch.quantization import convert, prepare
 from tqdm import tqdm

+from neural_compressor.adaptor.torch_utils.waq import TorchSmoothQuant
 from neural_compressor.config import default_workspace

 from ..model_wrapper import QDQLayer
-from ..smooth_quant import TorchSmoothQuant
 from .utils import (
     _get_path,
     clean_module_weight,

neural_compressor/adaptor/torch_utils/model_wrapper.py

Lines changed: 5 additions & 5 deletions
@@ -66,9 +66,9 @@ def forward(self, X):

     def qdq_weight(self):
         # update weight w/ QDQ
-        from .smooth_quant import quant_dequant_w
+        from neural_compressor.adaptor.torch_utils.waq.utils import quant_dequant_w_v1

-        weith_qdq = quant_dequant_w(self.module)
+        weith_qdq = quant_dequant_w_v1(self.module)
         self.module.weight = torch.nn.Parameter(weith_qdq)


@@ -139,7 +139,7 @@ def _calculate_qparams(self, input_scale, input_minmax, dtype=torch.quint8):
         min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
         max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
-        scale = torch.max(scale, torch.tensor([torch.finfo(torch.float32).eps]))
+        scale = torch.max(scale, torch.tensor([torch.finfo(torch.float32).eps], device=scale.device))
         zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
         zero_point = torch.clamp(zero_point, quant_min, quant_max)
         return scale, zero_point
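The one-line change in this hunk creates the eps floor tensor on the same device as `scale`, so the clamp no longer mixes CPU and CUDA tensors when the observed min/max come from a GPU run. A self-contained sketch of the same asymmetric qparams computation is below; the quint8 range (0, 255) is an assumption, and the function name is illustrative.

```python
# Sketch of the asymmetric scale/zero-point math from _calculate_qparams.
import torch

def calc_qparams(min_val, max_val, quant_min=0, quant_max=255):
    # Force the range to include zero so zero is exactly representable.
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
    # eps floor created on scale's device, mirroring the fix above.
    scale = torch.max(scale, torch.tensor([torch.finfo(torch.float32).eps], device=scale.device))
    zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
    zero_point = torch.clamp(zero_point, quant_min, quant_max)
    return scale, zero_point

scale, zp = calc_qparams(torch.tensor([-0.2]), torch.tensor([1.3]))
```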
@@ -181,7 +181,7 @@ def forward(self, X):
             return X

     module_name_list = input_scale_dict.keys()
-    from .smooth_quant import get_module, set_module
+    from neural_compressor.adaptor.torch_utils.waq import get_module, set_module

     for name in module_name_list:
         module = get_module(tmp_model, name)
@@ -193,7 +193,7 @@ def forward(self, X):

 def _wrapper_qdq_linear(tmp_model, module_name_list=[]):
     """Help function to generate a fake QDQ model for loading weights."""
-    from .smooth_quant import get_module, set_module
+    from neural_compressor.adaptor.torch_utils.waq import get_module, set_module

     for name in module_name_list:
         module = get_module(tmp_model, name)

0 commit comments
