Skip to content

Commit 019c6ba

Browse files
committed
update calibration.py
Signed-off-by: yuwenzho <[email protected]>
1 parent 4a2281d commit 019c6ba

File tree

1 file changed

+59
-48
lines changed

1 file changed

+59
-48
lines changed

neural_compressor/adaptor/ox_utils/calibration.py

Lines changed: 59 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ def augment_graph(self):
150150
new_white_nodes.append(new_white_node)
151151
self.white_nodes = new_white_nodes
152152

153-
# initializers = {i.name: i.data_type for i in model.graph.initializer}
154153
node_outputs = []
155154
for node in model.graph.node: # pylint: disable=no-member
156155
node_outputs.extend(node.output)
@@ -166,42 +165,16 @@ def augment_graph(self):
166165
tensors_to_dump.add(input)
167166
# add output tensors which should be dump
168167
tensors_to_dump.update([output for output in node.output if len(output) != 0])
169-
170-
# # calibrate output tensors
171-
# if not weight_only and not activation_only:
172-
# # update input tensors which should be dump
173-
# for input in node.input:
174-
# if len(input) != 0:
175-
# initializer_tensor = self.model_wrapper.get_initializer(input)
176-
# if initializer_tensor is None:
177-
# tensors_to_dump.add(input)
178-
# # update output tensors which should be dump
179-
# tensors_to_dump.update([output for output in node.output if len(output) != 0])
180-
# elif weight_only:
181-
# for input in node.input:
182-
# if (
183-
# self.already_quantized
184-
# and input.replace("_dequantized", "_quantized") in initializers
185-
# and len(input) != 0
186-
# ):
187-
# initializer_tensor = self.model_wrapper.get_initializer(input)
188-
# if initializer_tensor is None:
189-
# tensors_to_dump.add(input)
190-
# elif activation_only:
191-
# if len(node.input[0]) != 0:
192-
# tensors_to_dump.update([node.input[0]])
193168

194169
model_inputs = [i.name for i in model.graph.input]
195-
logger.debug("tensors to dump:")
196-
logger.debug(tensors_to_dump)
197170
for tensor in tensors_to_dump:
198171
if tensor not in node_outputs and tensor not in model_inputs:
199172
continue
200173
if self.augment_nodes:
201174
for augment_node_type in self.augment_nodes:
202175
if augment_node_type in ["DequantizeLinear"]:
203176
# insert DequantizeLinear node as output
204-
if tensor.endswith("_scale") or tensor.endswith("_zero_point"):
177+
if tensor.endswith("_scale") or tensor.endswith("_zero_point"): # pragma: no cover
205178
continue
206179

207180
if not self.dynamically_quantized:
@@ -256,6 +229,14 @@ def augment_graph(self):
256229
)
257230

258231
def get_activation_tensors_calib_range(self, q_config=None):
232+
"""Get calib ranges of activation tensors.
233+
234+
Args:
235+
q_config (dict, optional): quantization config. Defaults to None.
236+
237+
Returns:
238+
dict: calib ranges
239+
"""
259240
# conduct inference session and get intermediate outputs
260241
so = onnxruntime.SessionOptions()
261242
so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
@@ -397,27 +378,48 @@ def _collect_data(ort_inputs):
397378
return activation_tensors_calib_range
398379

399380
def get_weight_tensors_calib_range(self):
400-
initializer_tensors_to_dump = set()
381+
"""Get calib ranges of weight tensors.
382+
383+
Returns:
384+
dict: calib ranges
385+
"""
386+
model_nodes_names = [node.name for node in self.model.graph.node]
387+
388+
# if augmented_model is not None, it means self.white_nodes is already updated in augment_graph func
389+
# then skip update here
390+
if self.already_quantized and self.augmented_model is None:
391+
# mapping between fp32 node and int8 node
392+
new_white_nodes = []
393+
for white_node in self.white_nodes:
394+
new_white_node = white_node + "_quant"
395+
assert new_white_node in model_nodes_names, "no quantized {} in the " "graph".format(white_node)
396+
new_white_nodes.append(new_white_node)
397+
self.white_nodes = new_white_nodes
398+
399+
added_outputs = set()
400+
initializer_tensors_to_dump = []
401401
initializers = [init.name for init in self.model.graph.initializer]
402402
for node in self.model.graph.node: # pylint: disable=no-member
403403
should_be_dump = ((node.op_type in self.dump_op_types) and (node.name not in self.black_nodes)) or (
404404
node.name in self.white_nodes
405405
)
406406
if should_be_dump:
407407
for input in node.input:
408-
if (
409-
self.already_quantized
410-
and input.replace("_dequantized", "_quantized") in initializers
411-
and len(input) != 0
412-
) or (
413-
not self.already_quantized
414-
and input in initializers
415-
and len(input) != 0
416-
):
417-
initializer_tensors_to_dump.add(input)
418-
419-
logger.debug("initializer tensors to dump:")
420-
logger.debug(initializer_tensors_to_dump)
408+
if ((self.already_quantized and input.replace("_dequantized", "_quantized") in initializers) or
409+
(not self.already_quantized and input in initializers)) and len(input) != 0:
410+
added_outputs.add(input)
411+
412+
for tensor in added_outputs:
413+
if tensor not in initializers:
414+
continue
415+
if self.augment_nodes:
416+
for augment_node_type in self.augment_nodes:
417+
if augment_node_type in ["DequantizeLinear"]:
418+
if not (tensor.endswith("_scale") or tensor.endswith("_zero_point")):
419+
initializer_tensors_to_dump.append(tensor)
420+
else:
421+
initializer_tensors_to_dump.append(tensor)
422+
421423
weight_tensors_calib_range = {}
422424
for initializer_tensor_name in initializer_tensors_to_dump:
423425
initializer_tensor = self.model_wrapper.get_initializer(initializer_tensor_name)
@@ -436,7 +438,6 @@ def get_weight_tensors_calib_range(self):
436438

437439
def get_intermediate_outputs(self, q_config=None, activation_only=False, weight_only=False):
438440
"""Gather intermediate model outputs after running inference."""
439-
440441
output_dicts = {}
441442
if not activation_only and not weight_only:
442443
output_dicts = self.get_activation_tensors_calib_range(q_config)
@@ -543,7 +544,12 @@ def _map_calibration(self, node_output_names, output_dicts):
543544
return final_dict
544545

545546
def dump_minmax(self, q_config):
546-
"""Get min/max values of tensors."""
547+
"""Get calib ranges of tensors."""
548+
# pipeline of getting calib ranges of tensors during calibration:
549+
# 1. augment_graph(): insert activation tensors to model output
550+
# 2. get_intermediate_outputs():
551+
# 2.1 get_activation_tensors_calib_range(): get calib ranges of activation tensors using the augmented graph
552+
# 2.2 get_weight_tensors_calib_range(): get calib ranges of weight tensors
547553
self.augment_graph()
548554
node_output_names, output_dicts = self.get_intermediate_outputs(q_config)
549555
return self._map_calibration(node_output_names, output_dicts)
@@ -624,15 +630,20 @@ def dump_tensor(self, activation=True, weight=False, format=None):
624630
self.already_quantized = True
625631
self.dynamically_quantized = "DynamicQuantizeLinear" in [node.op_type for node in self.model.graph.node]
626632
is_qdq = format == "qdq"
627-
self.augment_graph()
633+
if activation:
634+
self.augment_graph(inspect_tensor=True) # add activation tensors to model output
628635
_, output_dicts = self.get_intermediate_outputs(activation_only=not weight, weight_only=not activation)
629636
iters = len(list(output_dicts.values())[-1])
630637
map_node_activation = [{} for _ in range(iters)]
631638
map_node_weight = {}
632639
self.white_nodes = [node.replace("_quant", "") for node in self.white_nodes]
633-
augmengted_wrapper = ONNXModel(self.augmented_model)
634-
map_output = augmengted_wrapper.output_name_to_node
635-
map_input = augmengted_wrapper.input_name_to_nodes
640+
641+
if activation and self.augmented_model is None:
642+
raise ValueError("augmented model should not be None when dump activation tensors.")
643+
# if activation tensors are not dumped, then use origin model wrapper
644+
model_wrapper = ONNXModel(self.augmented_model) if activation else self.model_wrapper
645+
map_output = model_wrapper.output_name_to_node
646+
map_input = model_wrapper.input_name_to_nodes
636647
model_output_names = [t.name for t in self.model.graph.output]
637648
model_input_names = [t.name for t in self.model.graph.input]
638649
model_initializer_names = [t.name for t in self.model.graph.initializer]

0 commit comments

Comments
 (0)