diff --git a/support_status.md b/support_status.md
index f4f736539..bf3baeb7b 100644
--- a/support_status.md
+++ b/support_status.md
@@ -27,6 +27,7 @@
 | AvgPool3D | 1 ~ 17 |
 | BatchMatMul | 1 ~ 17 |
 | BatchMatMulV2 | 1 ~ 17 |
+| BatchMatMulV3 | 1 ~ 17 |
 | BatchToSpaceND | 1 ~ 17 |
 | BiasAdd | 1 ~ 17 |
 | BiasAddV1 | 1 ~ 17 |
diff --git a/tests/test_backend.py b/tests/test_backend.py
index c2d5960ec..b9674a63b 100755
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -1075,6 +1075,15 @@ def func(x, y):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: x_val}, rtol=1e-5)
 
+    @check_tf_min_version("2.6")
+    def test_matmulinteger(self):
+        x_val = np.array([1, 2, -3, -4], dtype=np.int8).reshape((2, 2))
+        y_val = np.array([1, 2, -3, -4], dtype=np.int8).reshape((2, 2))
+        def func(x, y):
+            x_ = tf.matmul(x, y, output_type=tf.int32)
+            return tf.identity(x_, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
+
     @check_onnxruntime_incompatibility("Sub")
     def test_sub(self):
         x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
diff --git a/tf2onnx/onnx_opset/math.py b/tf2onnx/onnx_opset/math.py
index bdebd1fa6..719a5114c 100644
--- a/tf2onnx/onnx_opset/math.py
+++ b/tf2onnx/onnx_opset/math.py
@@ -363,13 +363,14 @@ def version_1(cls, ctx, node, **kwargs):
                                   name=op_name, shapes=shapes, dtypes=dtypes)
 
 
-@tf_op(["MatMul", "BatchMatMul", "BatchMatMulV2"])
+@tf_op(["MatMul", "BatchMatMul", "BatchMatMulV2", "BatchMatMulV3"])
 class MatMul:
     @classmethod
     def version_1(cls, ctx, node, **kwargs):
         # tensorflow allows transpose and conjugated. If found, insert the required transpose.
         # We could use Gemm as well but tensorflow does not pass bias in matmul.
-        node.type = "MatMul"
+        if node.type != "MatMulInteger":
+            node.type = "MatMul"
 
         attrs = ["transpose_a", "transpose_b", "adjoint_a", "adjoint_b", "adj_x", "adj_y"]
         attrs_val = [node.get_attr(attr) for attr in attrs]
@@ -408,7 +409,21 @@ def version_1(cls, ctx, node, **kwargs):
             val = node.get_attr(i)
             if val is not None and val.i != 0:
                 raise ValueError(node.type + " attribute " + i + " is not supported")
 
+    @classmethod
+    def version_10(cls, ctx, node, **kwargs):
+        if (ctx.get_dtype(node.input[0]) in [onnx_pb.TensorProto.INT8, onnx_pb.TensorProto.UINT8] and
+                ctx.get_dtype(node.input[1]) in [onnx_pb.TensorProto.INT8, onnx_pb.TensorProto.UINT8] and
+                ctx.get_dtype(node.output[0]) == onnx_pb.TensorProto.INT32):
+            node.type = "MatMulInteger"
+            zpdata_a = np.zeros(1, dtype=utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[0])))
+            zero_point_node_a = ctx.make_const(utils.make_name("zero_point_a"), zpdata_a)
+            zpdata_b = np.zeros(1, dtype=utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[1])))
+            zero_point_node_b = ctx.make_const(utils.make_name("zero_point_b"), zpdata_b)
+            ctx.replace_inputs(node, [node.input[0], node.input[1],
+                                      zero_point_node_a.output[0], zero_point_node_b.output[0]])
+        cls.version_1(ctx, node, **kwargs)
+
 
 @tf_op("Erf")
 class Erf:
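
With both zero points pinned at zero, ONNX `MatMulInteger` reduces to a plain integer matrix product accumulated in int32, which is exactly what `tf.matmul(..., output_type=tf.int32)` (lowered by TF 2.6+ to `BatchMatMulV3`) computes. A minimal numpy sketch of the converted graph's arithmetic, using the values from `test_matmulinteger` (not part of the diff):

```python
import numpy as np

# The zero points created by version_10 are constant 0, so MatMulInteger
# computes (A - 0) @ (B - 0) in int32 -- a plain integer matrix product.
x = np.array([1, 2, -3, -4], dtype=np.int8).reshape((2, 2))
y = np.array([1, 2, -3, -4], dtype=np.int8).reshape((2, 2))
zp_a = np.zeros(1, dtype=np.int8)  # zero_point_a
zp_b = np.zeros(1, dtype=np.int8)  # zero_point_b
result = (x.astype(np.int32) - zp_a) @ (y.astype(np.int32) - zp_b)
print(result)  # [[-5 -6]
               #  [ 9 10]]
```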
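
For an end-to-end check of the new path, something like the sketch below should yield a graph containing a `MatMulInteger` node; the function name and shapes are illustrative, and since `MatMulInteger` is an opset-10 op, the conversion must use opset >= 10:

```python
import tensorflow as tf
import tf2onnx

@tf.function
def int_matmul(x, y):
    # output_type requires TF >= 2.6 and lowers to BatchMatMulV3
    return tf.matmul(x, y, output_type=tf.int32)

spec = (tf.TensorSpec((2, 2), tf.int8, name="x"),
        tf.TensorSpec((2, 2), tf.int8, name="y"))
model_proto, _ = tf2onnx.convert.from_function(
    int_matmul, input_signature=spec, opset=13)
print([n.op_type for n in model_proto.graph.node])  # expect 'MatMulInteger'
```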