55 changes: 54 additions & 1 deletion neural_compressor/torch/algorithms/smooth_quant/utility.py
@@ -26,8 +26,8 @@
 
 from neural_compressor.torch.algorithms.static_quant import (
     CpuInfo,
+    Statistics,
     TransformerBasedModelBlockPatternDetector,
-    dump_model_op_stats,
     generate_activation_observer,
     get_quantizable_ops_from_cfgs,
     ipex_config_path,
@@ -251,6 +251,59 @@ def cfg_to_qconfig(
     return None
 
 
+def dump_model_op_stats(user_cfg):
+    """Dump the quantizable-op statistics of the model for the user.
+
+    Args:
+        user_cfg (dict): quantization config
+    Returns:
+        None
+    """
+    res = dict()
+    for k, v in user_cfg.items():
+        op_type_list = k[-1].split("><")  # fused op-type reprs are joined with "><"
+        op_type = ""
+        for op in op_type_list:
+            if "class" in op:  # e.g. "<class 'torch.nn.modules.linear.Linear'>" -> "Linear"
+                op_type = (
+                    op[op.rfind(".") + 1 : op.rfind("'")]
+                    if op_type == ""
+                    else op_type + "&" + op[op.rfind(".") + 1 : op.rfind("'")]
+                )
+            elif "method" in op:  # e.g. "<method 'add' of 'torch._C._TensorBase' objects>" -> "add"
+                start = op.find("'") + 1
+                if start > 1:
+                    op_type = (
+                        op[start : op.find("'", start)]
+                        if op_type == ""
+                        else op_type + "&" + op[start : op.find("'", start)]
+                    )
+                else:
+                    start = op.find("method") + 7
+                    op_type = (
+                        op[start : op.find(" ", start)]
+                        if op_type == ""
+                        else op_type + "&" + op[start : op.find(" ", start)]
+                    )
+            else:
+                op_type = op if op_type == "" else op_type + "&" + op
+        if op_type not in res.keys():
+            res[op_type] = {"INT8": 0, "BF16": 0, "FP32": 0}
+        if v["weight"]["dtype"] == "int8":
+            res[op_type]["INT8"] += 1
+        elif v["weight"]["dtype"] == "fp32":
+            res[op_type]["FP32"] += 1
+
+    output_data = [
+        [op_type, sum(res[op_type].values()), res[op_type]["INT8"], res[op_type]["BF16"], res[op_type]["FP32"]]
+        for op_type in res.keys()
+    ]
+
+    Statistics(
+        output_data, header="Mixed Precision Statistics", field_names=["Op Type", "Total", "INT8", "BF16", "FP32"]
+    ).print_stat()
+
+
 def get_parent(node, all_parents=False):  # pragma: no cover
     if node.inputs() is None:
         return None
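A side note on the string munging in the function above: the last element of each config key concatenates the Python reprs of the fused ops with `><`, and the branches strip each repr down to a short name, joined by `&`. A minimal standalone sketch of the idea (the sample key is made up for illustration; the ternary accumulation and the built-in-method fallback branch are elided for brevity):

```python
# Simplified sketch of the op-type parsing in dump_model_op_stats above.
# The sample key is hypothetical; real keys come from the IPEX tune config.
def parse_op_type(raw: str) -> str:
    names = []
    for op in raw.split("><"):  # fused op-type reprs are joined with "><"
        if "class" in op:
            # "<class 'torch.nn.modules.linear.Linear'>" -> "Linear"
            names.append(op[op.rfind(".") + 1 : op.rfind("'")])
        elif "method" in op:
            # "<method 'add' of 'torch._C._TensorBase' objects>" -> "add"
            start = op.find("'") + 1
            names.append(op[start : op.find("'", start)])
        else:
            names.append(op)
    return "&".join(names)

key = "<class 'torch.nn.modules.linear.Linear'>><method 'add' of 'torch._C._TensorBase' objects>"
print(parse_op_type(key))  # -> Linear&add
```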
48 changes: 15 additions & 33 deletions neural_compressor/torch/algorithms/static_quant/utility.py
@@ -43,13 +43,19 @@
     "<class 'torch.nn.modules.conv.Conv2d'>": "Conv2d",
     "<class 'torch.nn.modules.conv.Conv3d'>": "Conv3d",
     "<class 'torch.nn.modules.activation.ReLU'>": "ReLU",
     "<class 'torch.nn.modules.sparse.EmbeddingBag'>": "EmbeddingBag",
     "<method 'add' of 'torch._C._TensorBase' objects>": "add",  # for IPEX < 2.2
     "<method 'add' of 'torch._C.TensorBase' objects>": "add",  # for IPEX >= 2.2
     "<class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>": "AdaptiveAvgPool2d",
+    "Linear_Relu": "Linear",
+    "Linear_add": "Linear",
     "<class 'torch.nn.modules.linear.Linear'>": "Linear",
     "<class 'torch.nn.modules.pooling.MaxPool2d'>": "MaxPool2d",
-    "re": {"<built-in method matmul of type object at": "matmul"},
+    "re": {
+        "<built-in method matmul of type object at": "matmul",
+        "<built-in method add of type object at": "add",
+        "<built-in method bmm of type object at": "bmm",
+    },
 }
 
 BLOCK_PATTERNS = [
@@ -85,6 +91,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
     Returns:
         cfgs (dict): updated configs.
     """
+    ori_user_cfg = copy.deepcopy(user_cfg)
     tmp_user_cfg = OrderedDict()
     for op in user_cfg:  # map ipex op_name to pt op_name
         for i, op_name in enumerate(op):
@@ -94,9 +101,9 @@
                     ori_op = (tuple(ops), unify_op_type_mapping_ipex[op_infos_from_cfgs[ops]["op_type"]])
                     tmp_user_cfg[((ori_op[0],), ori_op[1])] = user_cfg[op]
                     break
-    user_cfg = tmp_user_cfg
-    for op_name in user_cfg:
-        inc_op_cfg = user_cfg[op_name]
+
+    for op_name in tmp_user_cfg:
+        inc_op_cfg = tmp_user_cfg[op_name]
         for i, name in enumerate(op_name[0]):
             # to int8
             ipex_op_cfg = op_infos_from_cfgs[name]
@@ -154,7 +161,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
             else:
                 pass
             cfgs[name[0]][name[1]][name[2]] = ipex_op_cfg
-    return cfgs, user_cfg
+    return cfgs, ori_user_cfg
 
 
 def generate_activation_observer(scheme, algorithm, smooth_quant=False, smooth_quant_enable=False):  # pragma: no cover
@@ -333,8 +340,8 @@ def get_quantizable_ops_recursively(model, example_inputs):  # pragma: no cover
         elif "method" in ipex_op_type:  # "<method 'add' of 'torch._C._TensorBase' objects>"
             method = ipex_op_type.split("'")[1]
             op_name_info.append((module_fqn, method))
-        elif "Convolution" in ipex_op_type:  # "Convolution_Relu"
-            op_name_info.append((module_fqn, "Conv2d"))
+        elif "_" in ipex_op_type:  # "Convolution_Relu", "Linear_Relu"
+            op_name_info.append((module_fqn, ipex_op_type.split("_")[0]))
         else:
             re_flag = False
             for pattern, unify_op_type in unify_op_type_mapping_ipex["re"].items():
@@ -394,32 +401,7 @@ def dump_model_op_stats(user_cfg):
     """
     res = dict()
     for k, v in user_cfg.items():
-        op_type_list = k[-1].split("><")
-        op_type = ""
-        for op in op_type_list:
-            if "class" in op:
-                op_type = (
-                    op[op.rfind(".") + 1 : op.rfind("'")]
-                    if op_type == ""
-                    else op_type + "&" + op[op.rfind(".") + 1 : op.rfind("'")]
-                )
-            elif "method" in op:
-                start = op.find("'") + 1
-                if start > 1:
-                    op_type = (
-                        op[start : op.find("'", start)]
-                        if op_type == ""
-                        else op_type + "&" + op[start : op.find("'", start)]
-                    )
-                else:
-                    start = op.find("method") + 7
-                    op_type = (
-                        op[start : op.find(" ", start)]
-                        if op_type == ""
-                        else op_type + "&" + op[start : op.find(" ", start)]
-                    )
-            else:
-                op_type = op if op_type == "" else op_type + "&" + op
+        op_type = k[1]
         if op_type not in res.keys():
             res[op_type] = {"INT8": 0, "BF16": 0, "FP32": 0}
         if v["weight"]["dtype"] == "int8":
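The payoff of the two changes in this file is visible here: `check_cfg_and_qconfig` now snapshots the incoming config with `copy.deepcopy` and returns that snapshot instead of the remapped dict, and the remapped keys already carry the unified op type at index 1, so `dump_model_op_stats` no longer needs to parse repr strings. A small self-contained sketch of the simplified counting path (the config entries below are hypothetical, and `setdefault` stands in for the membership check):

```python
# Hypothetical user_cfg as produced after op-type unification: each key is
# a (op_name_tuple, unified_op_type) pair, so the op type is just k[1].
user_cfg = {
    ((("fc1", "Linear"),), "Linear"): {"weight": {"dtype": "int8"}},
    ((("fc2", "Linear"),), "Linear"): {"weight": {"dtype": "fp32"}},
    ((("add_1", "add"),), "add"): {"weight": {"dtype": "int8"}},
}

res = {}
for k, v in user_cfg.items():
    op_type = k[1]  # already a short name such as "Linear" or "add"
    stats = res.setdefault(op_type, {"INT8": 0, "BF16": 0, "FP32": 0})
    if v["weight"]["dtype"] == "int8":
        stats["INT8"] += 1
    elif v["weight"]["dtype"] == "fp32":
        stats["FP32"] += 1

print(res)
# {'Linear': {'INT8': 1, 'BF16': 0, 'FP32': 1}, 'add': {'INT8': 1, 'BF16': 0, 'FP32': 0}}
```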
18 changes: 12 additions & 6 deletions test/3x/torch/quantization/test_static_quant.py
@@ -22,13 +22,18 @@ class Model(torch.nn.Module):
         def __init__(self):
             super(Model, self).__init__()
             self.fc1 = torch.nn.Linear(30, 50)
-            self.fc2 = torch.nn.Linear(50, 30)
-            self.fc3 = torch.nn.Linear(30, 5)
+            self.fc2 = torch.nn.Linear(50, 50)
+            self.fc3 = torch.nn.Linear(50, 30)
+            self.fc4 = torch.nn.Linear(30, 5)
+            self.relu = torch.nn.ReLU()
 
         def forward(self, x):
             out = self.fc1(x)
             out = self.fc2(out)
+            out = self.relu(out)
             out = self.fc3(out)
+            out = out + x
+            out = self.fc4(out)
             return out
 
     model = Model()
@@ -78,21 +83,22 @@ def test_static_quant_fallback(self):
         assert q_model is not None, "Quantization failed!"
 
         for op, op_info in q_model.tune_cfg[" "]["q_op_infos"].items():
-            if op_info["op_type"] == "<class 'torch.nn.modules.linear.Linear'>":
+            if op_info["op_type"] == "Linear":
                 dtype = q_model.tune_cfg[" "]["q_op_infos"][op]["input_tensor_infos"][0]["force_dtype"]
                 assert dtype == "torch.float32", "Failed to fallback linear op, please check!"
 
         # fallback by op_name
-        quant_config.set_local("fc1", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        quant_config = get_default_static_config()
+        quant_config.set_local("fc2", StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
         run_fn(prepared_model)
         q_model = convert(prepared_model)
         assert q_model is not None, "Quantization failed!"
 
         for op, op_info in q_model.tune_cfg[" "]["q_op_infos"].items():
-            if op_info["fqn"] == "fc2":
+            if op_info["fqn"] == "fc2":
                 dtype = q_model.tune_cfg[" "]["q_op_infos"][op]["input_tensor_infos"][0]["force_dtype"]
-                assert dtype == "torch.float32", "Failed to fallback fc1 layer, please check!"
+                assert dtype == "torch.float32", "Failed to fallback fc2 layer, please check!"
 
     @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
     @pytest.mark.parametrize(
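The reworked test model is worth a second look: `fc2` is now square (50→50), `fc3` projects back to the 30-dim input so the new residual `out = out + x` type-checks, and `fc4` produces the 5-dim output; the residual add plausibly exercises the tensor-`add` path that the expanded mappings above unify. A quick shape sanity check, assuming the `Model` class from the diff is in scope:

```python
import torch

model = Model()  # the test model defined in the diff above
x = torch.randn(4, 30)
out = model(x)
assert out.shape == (4, 5)  # fc3 restores 30 dims for the residual add; fc4 maps 30 -> 5
```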