diff --git a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py
index 449c5988438f9..40b24c074ae46 100644
--- a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py
+++ b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py
@@ -202,6 +202,7 @@ def __init__(
         op_types_to_quantize: tuple[str, ...] | None = None,
         quant_axes: tuple[tuple[str, int], ...] | None = None,
         bits: int = 4,
+        channel_wised_quantize: bool = False,
     ):
         """
         This is a class for weight only affine quantization configuration.
@@ -236,6 +237,9 @@ def __init__(
         self.is_symmetric = is_symmetric
         self.bits = bits
         self.accuracy_level = accuracy_level
+        self.channel_wised_quantize = channel_wised_quantize
+        if channel_wised_quantize and quant_format == QuantFormat.QOperator:
+            raise NotImplementedError("channel_wised_quantize is not supported with QuantFormat.QOperator yet")
 
 
 class NVAWQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
@@ -734,6 +738,26 @@ def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, Gr
     return None, None
 
 
+# transpose an int4 matrix packed as uint8 (two nibbles per byte, low nibble first)
+def transpose_packed_int4_matrix(packed, rows, cols):
+    # unpack to an int4 matrix
+    total = rows * cols
+    high = (packed >> 4) & 0x0F
+    low = packed & 0x0F
+    int4_vals = np.empty(total, dtype=np.uint8)
+    int4_vals[0::2] = low
+    int4_vals[1::2] = high
+    int4_matrix = int4_vals.reshape((rows, cols))
+
+    # transpose the int4 matrix
+    int4_matrix_transposed = int4_matrix.T
+
+    # pack back to uint8
+    flat = int4_matrix_transposed.reshape(-1)
+    packed = ((flat[1::2] << 4) & 0xF0) | (flat[0::2] & 0x0F)
+    return packed.astype(np.uint8)
+
+
 class DefaultWeightOnlyQuantizer:
     def __init__(self, config: DefaultWeightOnlyQuantConfig):
         self.config = config
@@ -770,6 +794,10 @@ def qbits_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.n
                 packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
             )
         else:
+            # block size equals rows (K) when channel-wise quantization is enabled
+            block_size = rows if self.config.channel_wised_quantize else self.config.block_size
+            k_blocks = (rows + block_size - 1) // block_size
+
             assert qbits == 4, "QDQ format only support 4 bits quantization"
             packed = np.zeros((rows * cols + 1) // 2, dtype="uint8")
             zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8")
@@ -812,6 +840,16 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis
             )
             scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales")
 
+            # if QDQ, CW and SYM are all enabled, optimize for Intel NPU: transposing the weight to NHWC format will increase performance
+            qdq_opt_for_intel_npu_enabled = self.config.quant_format == QuantFormat.QDQ \
+                and self.config.channel_wised_quantize and self.config.is_symmetric
+            if qdq_opt_for_intel_npu_enabled:
+                rows, cols = b_ndarray.shape
+                packed = transpose_packed_int4_matrix(packed, rows, cols)
+                scales = scales.reshape((cols, 1))  # (cols, 1)
+                b_quant = onnx.helper.make_tensor(b_tensor.name + f"_DQ_Q{bits}", qtype, [cols, rows], packed.tobytes(), True)
+                scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales")
+
             for input in b_graph.input:
                 if input.name == input_b:
                     b_graph.input.remove(input)
@@ -849,7 +887,9 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis
         else:
             dq_input_names = [b_quant.name, scales_tensor.name]
             dq_output_names = [b_quant.name + "_output"]
-            matmul_input_names = [node.input[0], dq_output_names[0]]
+            tp_input_names = [dq_output_names[0]]
+            tp_output_names = [dq_output_names[0] + "_transposed"]
+            matmul_input_names = [node.input[0], tp_output_names[0] if qdq_opt_for_intel_npu_enabled else dq_output_names[0]]
             matmul_output_names = [node.output[0]]
             if not self.config.is_symmetric:
                 zp_tensor = onnx.helper.make_tensor(
@@ -857,7 +897,11 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis
                 )
                 dq_input_names.append(zp_tensor.name)
                 b_graph.initializer.extend([zp_tensor])
-            dq_kwargs = {"axis": 0, "block_size": self.config.block_size}
+            rows, cols = b_ndarray.shape
+            dq_kwargs = {
+                "axis": 1 if qdq_opt_for_intel_npu_enabled else 0,
+                "block_size": rows if self.config.channel_wised_quantize else self.config.block_size
+            }
             dq_node = onnx.helper.make_node(
                 "DequantizeLinear",
                 inputs=dq_input_names,
@@ -871,7 +915,16 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis
                 outputs=matmul_output_names,
                 name=node.name + f"_matmul_Q{bits}" if node.name else "",
             )
-            output_nodes.extend([dq_node, matmul_node])
+            if qdq_opt_for_intel_npu_enabled:
+                tp_node = onnx.helper.make_node(
+                    "Transpose",
+                    inputs=tp_input_names,
+                    outputs=tp_output_names,
+                    perm=[1, 0],
+                )
+                output_nodes.extend([dq_node, tp_node, matmul_node])
+            else:
+                output_nodes.extend([dq_node, matmul_node])
 
         return output_nodes
 
@@ -1136,6 +1189,7 @@ def __init__(
         quant_format=QuantFormat.QOperator,
         op_types_to_quantize: tuple[str, ...] | None = None,
         quant_axes: tuple[tuple[str, int], ...] | None = None,
+        channel_wised_quantize: bool = False,
         algo_config: WeightOnlyQuantConfig | None = None,
     ):
         if nodes_to_exclude is None:
@@ -1158,6 +1212,7 @@ def __init__(
                 op_types_to_quantize=op_types_to_quantize,
                 quant_axes=quant_axes,
                 bits=4,  # default to 4 bits
+                channel_wised_quantize=channel_wised_quantize,
             )
 
         self.algo_config = algo_config
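
Usage sketch for the new flag (illustrative only, not part of the diff): "model.onnx" and "model_int4.onnx" are placeholder paths, and channel_wised_quantize requires quant_format=QuantFormat.QDQ, matching the NotImplementedError guard added to DefaultWeightOnlyQuantConfig. This assumes the existing public MatMulNBitsQuantizer API otherwise behaves as before:

    import onnx
    from onnxruntime.quantization import QuantFormat
    from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer

    model = onnx.load("model.onnx")  # placeholder input path
    quantizer = MatMulNBitsQuantizer(
        model,
        is_symmetric=True,              # symmetric + QDQ enables the Intel NPU transpose path
        quant_format=QuantFormat.QDQ,   # QOperator raises NotImplementedError with the new flag
        channel_wised_quantize=True,    # new flag: one quantization block per column (block size = K)
    )
    quantizer.process()
    quantizer.model.save_model_to_file("model_int4.onnx", use_external_data_format=True)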
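
A quick round-trip check of the new transpose_packed_int4_matrix helper (a standalone sketch; unpack_int4 is a throwaway reference implementation written here, not something in the module, and the import only works once this change is in place). It packs a random even-sized int4 matrix with the same low-nibble-first convention the helper uses, transposes it in packed form, and verifies the result against a plain NumPy transpose:

    import numpy as np
    from onnxruntime.quantization.matmul_nbits_quantizer import transpose_packed_int4_matrix

    def unpack_int4(packed, rows, cols):
        # low nibble holds even-indexed elements, high nibble holds odd-indexed elements
        vals = np.empty(rows * cols, dtype=np.uint8)
        vals[0::2] = packed & 0x0F
        vals[1::2] = (packed >> 4) & 0x0F
        return vals.reshape(rows, cols)

    rows, cols = 8, 4  # rows * cols must be even so every byte holds two nibbles
    original = np.random.randint(0, 16, size=rows * cols, dtype=np.uint8)
    packed = ((original[1::2] << 4) & 0xF0) | (original[0::2] & 0x0F)

    transposed = transpose_packed_int4_matrix(packed, rows, cols)
    assert np.array_equal(unpack_int4(transposed, cols, rows), unpack_int4(packed, rows, cols).T)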