Skip to content

[Bug] RemoveUnusedOutputs gives unexpected results #17247

@Cookiee235

Description

@Cookiee235

Hi all, the RemoveUnusedOutputs pass seems to produce an unexpected optimized result. Because this API (relax.transform.RemoveUnusedOutputs) lacks detailed documentation, I cannot confirm whether the optimization result is actually wrong.

In addition, there is a second bug in the API tvm.ir.assert_structural_equal: when the exact same mod is passed as both arguments, the API judges the two structures to be unequal. It is triggered by IRs that contain the string "nan".

Actual behavior

## Output IR after the RemoveUnusedOutputs pass
@I.ir_module
class Module:
    @R.function
    def main(v0_0: R.Tensor((1,), dtype="int32"), v1_0: R.Tensor((42,), dtype="int32")) -> R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            res: R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))) = R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan"))
            R.output(res)
        return res
----------------------------------------------------------------------------------------------------------------------------------
Traceback (most recent call last):
  File "/share_container/optfuzz/res/bugs/assert_structure.py", line 66, in <module>
    tvm.ir.assert_structural_equal(mod, mod)
  File "/software/tvm-lunder/python/tvm/ir/base.py", line 256, in assert_structural_equal
    _ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars)  # type: ignore # pylint: disable=no-member
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
    raise_last_ffi_error()
  File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
ValueError: Traceback (most recent call last):
  5: _ZN3tvm7runtime13PackedFuncObj
  4: tvm::runtime::TypedPackedFunc<bool (tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)>::AssignTypedLambda<tvm::{lambda(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)#3}>(tvm::{lambda(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, bool)#3}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  3: tvm::SEqualHandlerDefault::Impl::Equal(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool)
  2: tvm::SEqualHandlerDefault::Impl::RunTasks()
  1: tvm::SEqualHandlerDefault::DispatchSEqualReduce(tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, bool, tvm::runtime::Optional<tvm::ObjectPathPair> const&)
  0: tvm::SEqualHandlerDefault::Impl::CheckResult(bool, tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef const&, tvm::runtime::Optional<tvm::ObjectPathPair> const&)
  File "/software/tvm-lunder/src/node/structural_equal.cc", line 392
ValueError: StructuralEqual check failed, caused by lhs at <root>.functions[I.GlobalVar("main")].body.blocks[0].bindings[0].value.fields[0].value.value:
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(v0_0: R.Tensor((1,), dtype="int32"), v1_0: R.Tensor((42,), dtype="int32")) -> R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            res: R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))) = R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan"))
                                                                                                                                                  ^^^^^
            R.output(res)
        return res
and rhs at <root>.functions[I.GlobalVar("main")].body.blocks[0].bindings[0].value.fields[0].value.value:
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(v0_0: R.Tensor((1,), dtype="int32"), v1_0: R.Tensor((42,), dtype="int32")) -> R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            res: R.Tuple(R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan")), R.Prim(value=T.float64("nan"))) = R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan")), R.prim_value(T.float64("nan"))
                                                                                                                                                  ^^^^^
            R.output(res)
        return res

Steps to reproduce

import tvm
from tvm import relax

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

# Reproducer module: three private TIR kernels that each fill a buffer with a
# constant, a private Relax function `func` that calls all three, and a
# non-private Relax function `main` that returns `func`'s whole result tuple.
@I.ir_module
class Module:
    # Writes 1 into every element of a 16x16 int32 buffer.
    @T.prim_func(private=True)
    def ones(T_full: T.Buffer((T.int64(16), T.int64(16)), "int32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("T_full"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                # The block reads nothing and writes exactly one output element.
                T.reads()
                T.writes(T_full[v_ax0, v_ax1])
                T_full[v_ax0, v_ax1] = 1

    # Writes 0 into every element of a 16x16 int32 buffer.
    @T.prim_func(private=True)
    def zeros(T_full: T.Buffer((T.int64(16), T.int64(16)), "int32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("T_full"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                # The block reads nothing and writes exactly one output element.
                T.reads()
                T.writes(T_full[v_ax0, v_ax1])
                T_full[v_ax0, v_ax1] = 0

    # Same as `zeros`, but for a 32x32 int32 buffer.
    @T.prim_func(private=True)
    def zeros1(T_full: T.Buffer((T.int64(32), T.int64(32)), "int32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(32), T.int64(32)):
            with T.block("T_full"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                # The block reads nothing and writes exactly one output element.
                T.reads()
                T.writes(T_full[v_ax0, v_ax1])
                T_full[v_ax0, v_ax1] = 0

    # Private Relax function with no parameters. It calls the three kernels
    # above with an empty argument tuple and returns their results as (A, B, C).
    @R.function(private=True)
    def func() -> R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((16, 16), dtype="int32"), R.Tensor((32, 32), dtype="int32")):
        cls = Module
        A = R.call_tir(cls.zeros, R.tuple(), out_sinfo=R.Tensor((16, 16), dtype="int32"))
        B = R.call_tir(cls.ones, R.tuple(), out_sinfo=R.Tensor((16, 16), dtype="int32"))
        C = R.call_tir(cls.zeros1, R.tuple(), out_sinfo=R.Tensor((32, 32), dtype="int32"))
        return (A, B, C)

    # Non-private Relax function. The parameters v0_0 and v1_0 are declared but
    # never referenced in the body. The full tuple returned by `func` is bound
    # to `res` and returned unchanged, so every element of `func`'s output is
    # part of this function's result.
    @R.function
    def main(v0_0: R.Tensor((1,), dtype="int32"), v1_0: R.Tensor((42,), dtype="int32")) -> R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((16, 16), dtype="int32"), R.Tensor((32, 32), dtype="int32")):
        R.func_attr({"num_input": 2})
        cls = Module
        with R.dataflow():
            res: R.Tuple(R.Tensor((16, 16), dtype="int32"), R.Tensor((16, 16), dtype="int32"), R.Tensor((32, 32), dtype="int32")) = cls.func()
            R.output(res)
        return res


# Show the module as written, before any transformation.
mod = Module
mod.show()

# Apply the pass under test and show the transformed module.
mod = relax.transform.RemoveUnusedOutputs()(mod)
mod.show()  # Is this IR correct?
# Comparing the transformed module with itself raises the ValueError shown in
# the traceback under "Actual behavior".
tvm.ir.assert_structural_equal(mod, mod)  # Not equal! Why?

cc @Lunderberg @junrushao

Metadata

Metadata

Assignees

No one assigned

    Labels

    needs-triage: PRs or issues that need to be investigated by maintainers to find the right assignees to address it
    type: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions