diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index b363dc6952d8..1472f2f6fb9d 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -25,7 +25,7 @@
 from . import _ffi_api
 from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar, Var
 from ..expr import Tuple as RxTuple
-from ..struct_info import StructInfo, TensorStructInfo
+from ..struct_info import StructInfo, TensorStructInfo, TupleStructInfo
 from ...ir import PrimExpr
 from ..utils import args_converter
 
@@ -97,7 +97,9 @@ def call_tir(
     ret: Call
         A call node for the call_tir operator.
     """
-    if isinstance(args, Expr) and not isinstance(args, RxTuple):  # type: ignore
+    if isinstance(args, Expr) and not (
+        isinstance(args, RxTuple) or isinstance(args.struct_info_, TupleStructInfo)
+    ):  # type: ignore
         args = RxTuple((args,))
 
     if not isinstance(out_sinfo, list):
@@ -152,7 +154,9 @@ def call_tir_with_grad(
     ret: Call
         A call node for the call_tir_with_grad operator.
     """
-    if isinstance(args, Expr) and not isinstance(args, RxTuple):  # type: ignore
+    if isinstance(args, Expr) and not (
+        isinstance(args, RxTuple) or isinstance(args.struct_info_, TupleStructInfo)
+    ):  # type: ignore
         args = RxTuple((args,))
 
     if not isinstance(out_sinfo, list):
@@ -220,7 +224,9 @@ def call_tir_inplace(
     ret: Call
        A call node for the call_tir operator.
    """
-    if isinstance(args, Expr) and not isinstance(args, RxTuple):  # type: ignore
+    if isinstance(args, Expr) and not (
+        isinstance(args, RxTuple) or isinstance(args.struct_info_, TupleStructInfo)
+    ):  # type: ignore
         args = RxTuple((args,))
 
     if not isinstance(inplace_indices, list):
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 01d0d04be0cc..eb86e513076f 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -263,7 +263,7 @@ RELAY_REGISTER_OP("relax.call_tir")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
     .set_attr<Bool>("FPurity", Bool(true));
 
-Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
+Expr MakeCallTIR(const Expr& func, const Expr& args, const Array<TensorStructInfo>& out_sinfo_list,
                  Optional<Expr> packed_ints) {
   for (const TensorStructInfo& sinfo : out_sinfo_list) {
     const auto* shape = sinfo->shape.as<ShapeExprNode>();
@@ -307,9 +307,9 @@ RELAY_REGISTER_OP("relax.call_tir_with_grad")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
     .set_attr<Bool>("FPurity", Bool(true));
 
-Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
-                         String te_grad_name, Map<String, ObjectRef> te_grad_kwargs,
-                         Optional<Expr> packed_ints) {
+Expr MakeCallTIRWithGrad(const Expr& func, const Expr& args,
+                         const Array<TensorStructInfo>& out_sinfo_list, String te_grad_name,
+                         Map<String, ObjectRef> te_grad_kwargs, Optional<Expr> packed_ints) {
   for (const TensorStructInfo& sinfo : out_sinfo_list) {
     const auto* shape = sinfo->shape.as<ShapeExprNode>();
     CHECK(shape != nullptr)
@@ -364,7 +364,8 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& c
   }
 
   // check the range for inplace indices, make sure at least one is not -1, ensure they're unique
-  size_t num_args = Downcast<Tuple>(call->args[1])->fields.size();
+  auto arg_sinfo = GetStructInfoAs<TupleStructInfoNode>(call->args[1]);
+  size_t num_args = arg_sinfo->fields.size();
   std::unordered_set<int> encountered;
   for (size_t i = 0; i < attrs->inplace_indices.size(); i++) {
     int index = attrs->inplace_indices[i].IntValue();
@@ -391,14 +392,13 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& c
   // for safety, we will make sure the output shape for each in-place argument exactly matches the
   // input shape
   // TODO(@slyubomirsky): eventually we will want to handle cases where that is not true
-  Tuple call_args = Downcast<Tuple>(call->args[1]);
   if (attrs->inplace_indices.size() == 1) {
     auto* out_sinfo = call->sinfo_args[0].as<TensorStructInfoNode>();
     if (!out_sinfo) {
       ctx->ReportFatal(Diagnostic::Error(call) << "The output struct info must be a tensor");
     }
-    auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(
-        call_args->fields[attrs->inplace_indices[0].IntValue()]);
+    auto* input_sinfo =
+        arg_sinfo->fields[attrs->inplace_indices[0].IntValue()].as<TensorStructInfoNode>();
     if (!input_sinfo || !input_sinfo->shape.defined() ||
         !CanProveShapeEqual(input_sinfo->shape.value(), out_sinfo->shape.value(),
                             ctx->GetAnalyzer())) {
@@ -412,24 +412,23 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& c
   } else {
     auto out_sinfos = call->sinfo_args[0].as<TupleStructInfoNode>()->fields;
     for (size_t i = 0; i < attrs->inplace_indices.size(); i++) {
-      if (attrs->inplace_indices[i].IntValue() == -1) {
+      int inplace_index = attrs->inplace_indices[i].IntValue();
+      if (inplace_index == -1) {
         continue;
       }
       auto* out_sinfo = out_sinfos[i].as<TensorStructInfoNode>();
       if (!out_sinfo) {
         ctx->ReportFatal(Diagnostic::Error(call) << "The output struct info must be a tensor");
       }
-      auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(
-          call_args->fields[attrs->inplace_indices[i].IntValue()]);
+      auto* input_sinfo = arg_sinfo->fields[inplace_index].as<TensorStructInfoNode>();
       if (!input_sinfo || !input_sinfo->shape.defined() ||
           !CanProveShapeEqual(input_sinfo->shape.value(), out_sinfo->shape.value(),
                               ctx->GetAnalyzer())) {
         ctx->ReportFatal(Diagnostic::Error(call)
                          << "The shape of output " << i << " must match that of input "
-                         << attrs->inplace_indices[i].IntValue() << ", whereas we have "
-                         << out_sinfo->shape.value() << " in output " << i << " versus "
-                         << input_sinfo->shape.value() << " in input "
-                         << attrs->inplace_indices[i].IntValue());
+                         << inplace_index << ", whereas we have " << out_sinfo->shape.value()
+                         << " in output " << i << " versus " << input_sinfo->shape.value()
+                         << " in input " << attrs->inplace_indices[i].IntValue());
       }
     }
   }
@@ -453,7 +452,7 @@ RELAY_REGISTER_OP("relax.call_tir_inplace")
     // arguments will no longer be live)
     .set_attr<Bool>("FPurity", Bool(true));
 
-Expr MakeCallTIRInplace(Expr func, Tuple args, Array<Integer> inplace_indices,
+Expr MakeCallTIRInplace(const Expr& func, const Expr& args, const Array<Integer>& inplace_indices,
                         Array<TensorStructInfo> out_sinfo_list, Optional<Expr> packed_ints) {
   for (const TensorStructInfo& sinfo : out_sinfo_list) {
     const auto* shape = sinfo->shape.as<ShapeExprNode>();
diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc
index e040ccea1485..0b3f661fc774 100644
--- a/src/relax/transform/call_tir_rewrite.cc
+++ b/src/relax/transform/call_tir_rewrite.cc
@@ -78,7 +78,7 @@ class CallTIRMutator : public ExprMutator {
             << "If calling call_tir_inplace and there is one output, its in-place index must not"
                " be -1.";
         outs.push_back(
-            Downcast<Tuple>(call->args[1])->fields[inplace_attrs->inplace_indices[0].IntValue()]);
+            GetTupleIndex(call->args[1], inplace_attrs->inplace_indices[0].IntValue()));
       }
     } else if (const auto& _tuple_sinfo = MatchStructInfo<TupleStructInfo>(expr)) {
       // multiple output case
@@ -101,8 +101,8 @@ class CallTIRMutator : public ExprMutator {
                                       Attrs()),
                                  "alloc"));
         } else {
-          outs.push_back(Downcast<Tuple>(call->args[1])
-                             ->fields[inplace_attrs->inplace_indices[i].IntValue()]);
+          outs.push_back(
+              GetTupleIndex(call->args[1], inplace_attrs->inplace_indices[i].IntValue()));
         }
       }
     } else {
@@ -112,8 +112,10 @@ class CallTIRMutator : public ExprMutator {
     }
 
     Array<Expr> args;
-    if (call->args[1].as<TupleNode>()) {
-      args = Downcast<Tuple>(call->args[1])->fields;
+    if (const auto* tuple_info = GetStructInfoAs<TupleStructInfoNode>(call->args[1])) {
+      for (size_t i = 0; i < tuple_info->fields.size(); i++) {
+        args.push_back(GetTupleIndex(call->args[1], i));
+      }
       // for call_tir_inplace, don't reinsert in-place args, only the newly allocated ones
       if (!is_inplace) {
         args.insert(args.end(), outs.begin(), outs.end());
@@ -150,6 +152,18 @@ class CallTIRMutator : public ExprMutator {
     return GetRef<Expr>(call);
   }
+
+ private:
+  // If e is a tuple literal, return the field denoted by the index.
+  // Otherwise, insert a tuple get item for that field and return the
+  // var the result is bound to.
+  Expr GetTupleIndex(const Expr& e, int index) {
+    if (const auto* tuple_node = e.as<TupleNode>()) {
+      return tuple_node->fields[index];
+    }
+    auto out = builder_->Emit(TupleGetItem(e, index));
+    return out;
+  }
 };
 
 Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); }
diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py
index 9ab2ffc60536..475f3c847a93 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -123,6 +123,64 @@ def foo(x: R.Tensor(("m", "n"), "float32")):
     assert s2.op.name_hint == "exp"
 
 
+def test_call_tir_rewrite_separate_tuple():
+    # if the arguments to call_tir are tuple-typed but not a tuple literal,
+    # the rewrite should index into the tuple
+    @tvm.script.ir_module
+    class TestCallTIRRewrite:
+        @T.prim_func
+        def add(
+            A: T.Buffer((2, 3), "float32"),
+            B: T.Buffer((2, 3), "float32"),
+            C: T.Buffer((2, 3), "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("T_add"):
+                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(A[ax0, ax1], B[ax0, ax1])
+                    T.writes(C[ax0, ax1])
+                    C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1]
+
+        @R.function
+        def foo(x: R.Tensor((2, 3), "float32")):
+            # we expect RemovePurityChecking to have been used before this point
+            R.func_attr({"relax.force_pure": True})
+            tup = (x, x)
+            gv0 = R.call_tir(TestCallTIRRewrite.add, tup, R.Tensor((2, 3), dtype="float32"))
+            return gv0
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def add(
+            A: T.Buffer((2, 3), "float32"),
+            B: T.Buffer((2, 3), "float32"),
+            C: T.Buffer((2, 3), "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("T_add"):
+                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(A[ax0, ax1], B[ax0, ax1])
+                    T.writes(C[ax0, ax1])
+                    C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1]
+
+        @R.function
+        def foo(x: R.Tensor((2, 3), "float32")):
+            R.func_attr({"relax.force_pure": True})
+            tup = (x, x)
+            alloc = R.builtin.alloc_tensor(R.shape([2, 3]), dtype="float32", runtime_device_index=0)
+            v0 = tup[0]
+            v1 = tup[1]
+            _ = Expected.add(v0, v1, alloc)
+            gv0 = alloc
+            return gv0
+
+    new_mod = relax.transform.CallTIRRewrite()(TestCallTIRRewrite)
+    assert structural_equal(new_mod, Expected)
+
+
 def test_transform_remove_purity_checking():
     @tvm.script.ir_module
     class Before:
diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py
index 82a6d6a2a4d1..4ef93e578020 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -232,9 +232,11 @@ def main(
         ) -> R.Tuple(
             R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")
         ):
+            # also make sure it works with a tuple bound separately
+            tup = (x, y, z)
             res = R.call_tir_inplace(
                 TestCallTIRInplaceE2ESimple.copy,
-                (x, y, z),
+                tup,
                 [0, 1, -1],
                 [R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")],
3), "int32")], ) @@ -305,6 +307,48 @@ def main( tvm.testing.assert_allclose(out.numpy(), expected.numpy(), rtol=1e-7, atol=1e-7) +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_call_tir_reuse_tuple_input(exec_mode): + # read and write from the same tensor + @tvm.script.ir_module + class TestCallTIRTupleInput: + @T.prim_func + def add( + A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32"), C: T.Buffer((2, 3), "int32") + ): + T.func_attr({"tir.noalias": True}) + for i0, i1 in T.grid(T.int64(2), T.int64(3)): + with T.block("T_add"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[ax0, ax1], B[ax0, ax1]) + T.writes(C[ax0, ax1]) + C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1] + + @R.function + def main( + x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32") + ) -> R.Tuple(R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")): + tup = (x, y) + res1 = R.call_tir(TestCallTIRTupleInput.add, tup, out_sinfo=R.Tensor((2, 3), "int32")) + res2 = R.call_tir(TestCallTIRTupleInput.add, tup, out_sinfo=R.Tensor((2, 3), "int32")) + return (res1, res2) + + mod = TestCallTIRTupleInput + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + x = tvm.nd.array(np.ones((2, 3)).astype(np.int32)) + y = tvm.nd.array(np.ones((2, 3)).astype(np.int32)) + vm.set_input("main", x, y) + vm.invoke_stateful("main") + out = vm.get_outputs("main") + expected = tvm.nd.array(np.full((2, 3), 2).astype(np.int32)) + tvm.testing.assert_allclose(out[0].numpy(), expected.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(out[1].numpy(), expected.numpy(), rtol=1e-7, atol=1e-7) + + @pytest.mark.parametrize("exec_mode", EXEC_MODE) def test_vm_emit_te_extern(exec_mode): if not tvm.get_global_func("tvm.contrib.cblas.matmul", True): @@ -661,7 +705,8 @@ def relax_matmul_tir( ) -> R.Tensor((32, 32), dtype="float32"): cls = TestVMSubFunction with R.dataflow(): - gv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) + tup = (x, w) + gv0 = R.call_tir(cls.tir_matmul, tup, R.Tensor((32, 32), dtype="float32")) R.output(gv0) return gv0