Skip to content

Commit 009c3b0

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 74a8362 commit 009c3b0

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

neural_compressor/adaptor/ox_utils/calibration.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def augment_graph(self):
159159
if should_be_dump:
160160
# add input tensors which should be dump
161161
for input in node.input:
162-
if len(input) != 0: # to prevent input is ""
162+
if len(input) != 0: # to prevent input is ""
163163
initializer_tensor = self.model_wrapper.get_initializer(input)
164164
if initializer_tensor is None:
165165
tensors_to_dump.add(input)
@@ -174,7 +174,7 @@ def augment_graph(self):
174174
for augment_node_type in self.augment_nodes:
175175
if augment_node_type in ["DequantizeLinear"]:
176176
# insert DequantizeLinear node as output
177-
if tensor.endswith("_scale") or tensor.endswith("_zero_point"): # pragma: no cover
177+
if tensor.endswith("_scale") or tensor.endswith("_zero_point"): # pragma: no cover
178178
continue
179179

180180
if not self.dynamically_quantized:
@@ -333,8 +333,9 @@ def _collect_data(ort_inputs):
333333
# per iteration in the future.
334334
if calibrator.method_name == "minmax":
335335
calibrator.collect(output)
336-
activation_tensors_calib_range[node_output_names[output_idx]] = \
337-
[list(calibrator.calib_range)]
336+
activation_tensors_calib_range[node_output_names[output_idx]] = [
337+
list(calibrator.calib_range)
338+
]
338339
name_to_calibrator[node_output_names[output_idx]] = calibrator
339340
else:
340341
intermediate_tensor.setdefault((node_output_names[output_idx], node_name), []).append(
@@ -376,7 +377,7 @@ def _collect_data(ort_inputs):
376377
self._dataloder_for_next_split_model = ort_inputs_for_next_split_model
377378

378379
return activation_tensors_calib_range
379-
380+
380381
def get_weight_tensors_calib_range(self):
381382
"""Get calib ranges of weight tensors.
382383
@@ -405,8 +406,10 @@ def get_weight_tensors_calib_range(self):
405406
)
406407
if should_be_dump:
407408
for input in node.input:
408-
if ((self.already_quantized and input.replace("_dequantized", "_quantized") in initializers) or
409-
(not self.already_quantized and input in initializers)) and len(input) != 0:
409+
if (
410+
(self.already_quantized and input.replace("_dequantized", "_quantized") in initializers)
411+
or (not self.already_quantized and input in initializers)
412+
) and len(input) != 0:
410413
added_outputs.add(input)
411414

412415
for tensor in added_outputs:
@@ -429,13 +432,13 @@ def get_weight_tensors_calib_range(self):
429432
continue
430433

431434
initializer_tensor = numpy_helper.to_array(initializer_tensor)
432-
calibrator = CALIBRATOR["minmax"]() # use minmax method to calibrate initializer tensors
435+
calibrator = CALIBRATOR["minmax"]() # use minmax method to calibrate initializer tensors
433436
calibrator.collect(initializer_tensor)
434437
weight_tensors_calib_range[initializer_tensor_name] = [list(calibrator.calib_range)]
435438
calibrator.clear()
436439
del calibrator
437440
return weight_tensors_calib_range
438-
441+
439442
def get_intermediate_outputs(self, q_config=None, activation_only=False, weight_only=False):
440443
"""Gather intermediate model outputs after running inference."""
441444
output_dicts = {}
@@ -631,7 +634,7 @@ def dump_tensor(self, activation=True, weight=False, format=None):
631634
self.dynamically_quantized = "DynamicQuantizeLinear" in [node.op_type for node in self.model.graph.node]
632635
is_qdq = format == "qdq"
633636
if activation:
634-
self.augment_graph(inspect_tensor=True) # add activation tensors to model output
637+
self.augment_graph(inspect_tensor=True) # add activation tensors to model output
635638
_, output_dicts = self.get_intermediate_outputs(activation_only=not weight, weight_only=not activation)
636639
iters = len(list(output_dicts.values())[-1])
637640
map_node_activation = [{} for _ in range(iters)]

0 commit comments

Comments
 (0)