14 changes: 10 additions & 4 deletions python/tvm/relax/op/base.py
@@ -25,7 +25,7 @@
from . import _ffi_api
from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar, Var
from ..expr import Tuple as RxTuple
from ..struct_info import StructInfo, TensorStructInfo
from ..struct_info import StructInfo, TensorStructInfo, TupleStructInfo
from ...ir import PrimExpr
from ..utils import args_converter

@@ -97,7 +97,9 @@ def call_tir(
ret: Call
A call node for the call_tir operator.
"""
if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
if isinstance(args, Expr) and not (
isinstance(args, RxTuple) or isinstance(args.struct_info_, TupleStructInfo)
): # type: ignore
args = RxTuple((args,))

if not isinstance(out_sinfo, list):
@@ -152,7 +154,9 @@ def call_tir_with_grad(
ret: Call
A call node for the call_tir_with_grad operator.
"""
if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
if isinstance(args, Expr) and not (
isinstance(args, RxTuple) or isinstance(args.struct_info_, TupleStructInfo)
): # type: ignore
args = RxTuple((args,))

if not isinstance(out_sinfo, list):
@@ -220,7 +224,9 @@ def call_tir_inplace(
ret: Call
A call node for the call_tir operator.
"""
if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
if isinstance(args, Expr) and not (
isinstance(args, RxTuple) or isinstance(args.struct_info_, TupleStructInfo)
): # type: ignore
args = RxTuple((args,))

if not isinstance(inplace_indices, list):
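For context, a minimal sketch of what the new guard distinguishes at the Python level. This is illustrative only; the variable and struct-info objects below are constructed ad hoc and do not appear in the patch:

```python
# Illustrative sketch, assuming a TVM build with the Relax Python API.
from tvm import relax
from tvm.relax import TensorStructInfo, TupleStructInfo

t = TensorStructInfo([2, 3], "float32")
# A tuple-typed variable, but not a relax.Tuple literal.
tup_var = relax.Var("tup", TupleStructInfo([t, t]))

# Old guard: `isinstance(tup_var, RxTuple)` is False, so call_tir re-wrapped the
# argument as RxTuple((tup_var,)) -- a one-element tuple containing a tuple,
# whose arity no longer matches the PrimFunc's parameters.
# New guard: `tup_var.struct_info_` is a TupleStructInfo, so the variable is
# forwarded to the call_tir node unchanged.
```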
31 changes: 15 additions & 16 deletions src/relax/op/op.cc
@@ -263,7 +263,7 @@ RELAY_REGISTER_OP("relax.call_tir")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
.set_attr<Bool>("FPurity", Bool(true));

Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
Expr MakeCallTIR(const Expr& func, const Expr& args, const Array<TensorStructInfo>& out_sinfo_list,
Optional<Expr> packed_ints) {
for (const TensorStructInfo& sinfo : out_sinfo_list) {
const auto* shape = sinfo->shape.as<ShapeExprNode>();
@@ -307,9 +307,9 @@ RELAY_REGISTER_OP("relax.call_tir_with_grad")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
.set_attr<Bool>("FPurity", Bool(true));

Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
String te_grad_name, Map<String, ObjectRef> te_grad_kwargs,
Optional<Expr> packed_ints) {
Expr MakeCallTIRWithGrad(const Expr& func, const Expr& args,
const Array<TensorStructInfo>& out_sinfo_list, String te_grad_name,
Map<String, ObjectRef> te_grad_kwargs, Optional<Expr> packed_ints) {
for (const TensorStructInfo& sinfo : out_sinfo_list) {
const auto* shape = sinfo->shape.as<ShapeExprNode>();
CHECK(shape != nullptr)
@@ -364,7 +364,8 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& c
}

// check the range for inplace indices, make sure at least one is not -1, ensure they're unique
size_t num_args = Downcast<Tuple>(call->args[1])->fields.size();
auto arg_sinfo = GetStructInfoAs<TupleStructInfoNode>(call->args[1]);
size_t num_args = arg_sinfo->fields.size();
std::unordered_set<int> encountered;
for (size_t i = 0; i < attrs->inplace_indices.size(); i++) {
int index = attrs->inplace_indices[i].IntValue();
@@ -391,14 +392,13 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& c
// for safety, we will make sure the output shape for each in-place argument exactly matches the
// input shape
// TODO(@slyubomirsky): eventually we will want to handle cases where that is not true
Tuple call_args = Downcast<Tuple>(call->args[1]);
if (attrs->inplace_indices.size() == 1) {
auto* out_sinfo = call->sinfo_args[0].as<TensorStructInfoNode>();
if (!out_sinfo) {
ctx->ReportFatal(Diagnostic::Error(call) << "The output struct info must be a tensor");
}
auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(
call_args->fields[attrs->inplace_indices[0].IntValue()]);
auto* input_sinfo =
arg_sinfo->fields[attrs->inplace_indices[0].IntValue()].as<TensorStructInfoNode>();
if (!input_sinfo || !input_sinfo->shape.defined() ||
!CanProveShapeEqual(input_sinfo->shape.value(), out_sinfo->shape.value(),
ctx->GetAnalyzer())) {
@@ -412,24 +412,23 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& c
} else {
auto out_sinfos = call->sinfo_args[0].as<TupleStructInfoNode>()->fields;
for (size_t i = 0; i < attrs->inplace_indices.size(); i++) {
if (attrs->inplace_indices[i].IntValue() == -1) {
int inplace_index = attrs->inplace_indices[i].IntValue();
if (inplace_index == -1) {
continue;
}
auto* out_sinfo = out_sinfos[i].as<TensorStructInfoNode>();
if (!out_sinfo) {
ctx->ReportFatal(Diagnostic::Error(call) << "The output struct info must be a tensor");
}
auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(
call_args->fields[attrs->inplace_indices[i].IntValue()]);
auto* input_sinfo = arg_sinfo->fields[inplace_index].as<TensorStructInfoNode>();
if (!input_sinfo || !input_sinfo->shape.defined() ||
!CanProveShapeEqual(input_sinfo->shape.value(), out_sinfo->shape.value(),
ctx->GetAnalyzer())) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "The shape of output " << i << " must match that of input "
<< attrs->inplace_indices[i].IntValue() << ", whereas we have "
<< out_sinfo->shape.value() << " in output " << i << " versus "
<< input_sinfo->shape.value() << " in input "
<< attrs->inplace_indices[i].IntValue());
<< inplace_index << ", whereas we have " << out_sinfo->shape.value()
<< " in output " << i << " versus " << input_sinfo->shape.value()
<< " in input " << attrs->inplace_indices[i].IntValue());
}
}
}
@@ -453,7 +452,7 @@ RELAY_REGISTER_OP("relax.call_tir_inplace")
// arguments will no longer be live)
.set_attr<Bool>("FPurity", Bool(true));

Expr MakeCallTIRInplace(Expr func, Tuple args, Array<Integer> inplace_indices,
Expr MakeCallTIRInplace(const Expr& func, const Expr& args, const Array<Integer>& inplace_indices,
Array<TensorStructInfo> out_sinfo_list, Optional<Expr> packed_ints) {
for (const TensorStructInfo& sinfo : out_sinfo_list) {
const auto* shape = sinfo->shape.as<ShapeExprNode>();
24 changes: 19 additions & 5 deletions src/relax/transform/call_tir_rewrite.cc
@@ -78,7 +78,7 @@ class CallTIRMutator : public ExprMutator {
<< "If calling call_tir_inplace and there is one output, its in-place index must not"
" be -1.";
outs.push_back(
Downcast<Tuple>(call->args[1])->fields[inplace_attrs->inplace_indices[0].IntValue()]);
GetTupleIndex(call->args[1], inplace_attrs->inplace_indices[0].IntValue()));
}
} else if (const auto& _tuple_sinfo = MatchStructInfo<TupleStructInfo>(expr)) {
// multiple output case
@@ -101,8 +101,8 @@ class CallTIRMutator : public ExprMutator {
Attrs()),
"alloc"));
} else {
outs.push_back(Downcast<Tuple>(call->args[1])
->fields[inplace_attrs->inplace_indices[i].IntValue()]);
outs.push_back(
GetTupleIndex(call->args[1], inplace_attrs->inplace_indices[i].IntValue()));
}
}
} else {
@@ -112,8 +112,10 @@ class CallTIRMutator : public ExprMutator {
}

Array<Expr> args;
if (call->args[1].as<TupleNode>()) {
args = Downcast<Tuple>(call->args[1])->fields;
if (const auto* tuple_info = GetStructInfoAs<TupleStructInfoNode>(call->args[1])) {
for (size_t i = 0; i < tuple_info->fields.size(); i++) {
args.push_back(GetTupleIndex(call->args[1], i));
}
// for call_tir_inplace, don't reinsert in-place args, only the newly allocated ones
if (!is_inplace) {
args.insert(args.end(), outs.begin(), outs.end());
@@ -150,6 +152,18 @@ class CallTIRMutator : public ExprMutator {

return GetRef<Expr>(call);
}

private:
// If e is a tuple literal, return the field denoted by the index.
// Otherwise, insert a tuple get item for that field and return the
// var the result is bound to.
Expr GetTupleIndex(const Expr& e, int index) {
if (const auto* tuple_node = e.as<TupleNode>()) {
return tuple_node->fields[index];
}
auto out = builder_->Emit(TupleGetItem(e, index));
return out;
}
};

Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); }
58 changes: 58 additions & 0 deletions tests/python/relax/test_transform.py
@@ -123,6 +123,64 @@ def foo(x: R.Tensor(("m", "n"), "float32")):
assert s2.op.name_hint == "exp"


def test_call_tir_rewrite_separate_tuple():
# if the arguments to call_tir are tuple-typed but not a tuple literal,
# the rewrite should index into the tuple
@tvm.script.ir_module
class TestCallTIRRewrite:
@T.prim_func
def add(
A: T.Buffer((2, 3), "float32"),
B: T.Buffer((2, 3), "float32"),
C: T.Buffer((2, 3), "float32"),
):
T.func_attr({"tir.noalias": True})
for i0, i1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_add"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
T.reads(A[ax0, ax1], B[ax0, ax1])
T.writes(C[ax0, ax1])
C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1]

@R.function
def foo(x: R.Tensor((2, 3), "float32")):
# we expect RemovePurityChecking to have been used before this point
R.func_attr({"relax.force_pure": True})
tup = (x, x)
gv0 = R.call_tir(TestCallTIRRewrite.add, tup, R.Tensor((2, 3), dtype="float32"))
return gv0

@tvm.script.ir_module
class Expected:
@T.prim_func
def add(
A: T.Buffer((2, 3), "float32"),
B: T.Buffer((2, 3), "float32"),
C: T.Buffer((2, 3), "float32"),
):
T.func_attr({"tir.noalias": True})
for i0, i1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_add"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
T.reads(A[ax0, ax1], B[ax0, ax1])
T.writes(C[ax0, ax1])
C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1]

@R.function
def foo(x: R.Tensor((2, 3), "float32")):
R.func_attr({"relax.force_pure": True})
tup = (x, x)
alloc = R.builtin.alloc_tensor(R.shape([2, 3]), dtype="float32", runtime_device_index=0)
v0 = tup[0]
v1 = tup[1]
_ = Expected.add(v0, v1, alloc)
gv0 = alloc
return gv0

new_mod = relax.transform.CallTIRRewrite()(TestCallTIRRewrite)
assert structural_equal(new_mod, Expected)


def test_transform_remove_purity_checking():
@tvm.script.ir_module
class Before:
49 changes: 47 additions & 2 deletions tests/python/relax/test_vm_build.py
@@ -232,9 +232,11 @@ def main(
) -> R.Tuple(
R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")
):
# also make sure it works with a tuple bound separately
tup = (x, y, z)
res = R.call_tir_inplace(
TestCallTIRInplaceE2ESimple.copy,
(x, y, z),
tup,
[0, 1, -1],
[R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")],
)
@@ -305,6 +307,48 @@ def main(
tvm.testing.assert_allclose(out.numpy(), expected.numpy(), rtol=1e-7, atol=1e-7)


@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_call_tir_reuse_tuple_input(exec_mode):
# reuse the same tuple-typed argument across multiple call_tir calls
@tvm.script.ir_module
class TestCallTIRTupleInput:
@T.prim_func
def add(
A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32"), C: T.Buffer((2, 3), "int32")
):
T.func_attr({"tir.noalias": True})
for i0, i1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_add"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
T.reads(A[ax0, ax1], B[ax0, ax1])
T.writes(C[ax0, ax1])
C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1]

@R.function
def main(
x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")
) -> R.Tuple(R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32")):
tup = (x, y)
res1 = R.call_tir(TestCallTIRTupleInput.add, tup, out_sinfo=R.Tensor((2, 3), "int32"))
res2 = R.call_tir(TestCallTIRTupleInput.add, tup, out_sinfo=R.Tensor((2, 3), "int32"))
return (res1, res2)

mod = TestCallTIRTupleInput

target = tvm.target.Target("llvm", host="llvm")
ex = relax.build(mod, target, exec_mode=exec_mode)
vm = relax.VirtualMachine(ex, tvm.cpu())

x = tvm.nd.array(np.ones((2, 3)).astype(np.int32))
y = tvm.nd.array(np.ones((2, 3)).astype(np.int32))
vm.set_input("main", x, y)
vm.invoke_stateful("main")
out = vm.get_outputs("main")
expected = tvm.nd.array(np.full((2, 3), 2).astype(np.int32))
tvm.testing.assert_allclose(out[0].numpy(), expected.numpy(), rtol=1e-7, atol=1e-7)
tvm.testing.assert_allclose(out[1].numpy(), expected.numpy(), rtol=1e-7, atol=1e-7)


@pytest.mark.parametrize("exec_mode", EXEC_MODE)
def test_vm_emit_te_extern(exec_mode):
if not tvm.get_global_func("tvm.contrib.cblas.matmul", True):
@@ -661,7 +705,8 @@ def relax_matmul_tir(
) -> R.Tensor((32, 32), dtype="float32"):
cls = TestVMSubFunction
with R.dataflow():
gv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32"))
tup = (x, w)
gv0 = R.call_tir(cls.tir_matmul, tup, R.Tensor((32, 32), dtype="float32"))
R.output(gv0)
return gv0
