Skip to content

Commit c8deb7f

Browse files
authored
Overriding the StructuralEqual() for easy usage (#16908)
* Overriding the Structural Equal() for easy usage * lint error fixed * fixing white space lint error * whitespace lint error
1 parent 114ad70 commit c8deb7f

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

include/tvm/node/structural_equal.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,11 @@ class StructuralEqual : public BaseValueEqual {
108108
* \brief Compare objects via strutural equal.
109109
* \param lhs The left operand.
110110
* \param rhs The right operand.
111+
* \param map_free_params Whether or not to map free variables.
111112
* \return The comparison result.
112113
*/
113-
TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
114+
TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs,
115+
const bool map_free_params = false) const;
114116
};
115117

116118
/*!

src/node/structural_equal.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -563,8 +563,9 @@ TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch")
563563
return first_mismatch;
564564
});
565565

566-
bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
567-
return SEqualHandlerDefault(false, nullptr, false).Equal(lhs, rhs, false);
566+
bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs,
567+
bool map_free_params) const {
568+
return SEqualHandlerDefault(false, nullptr, false).Equal(lhs, rhs, map_free_params);
568569
}
569570

570571
bool NDArrayEqual(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs,

0 commit comments

Comments
 (0)