Skip to content

Commit bd02406

Browse files
committed
Disable InferType if it was done and no changes after previous pass
This optimizatin allows to speedup PatternRewriter transformations by reusing of preious type inferred expression instead of perform InferType multiple times
1 parent 567eeed commit bd02406

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

src/relay/ir/dataflow_matcher.cc

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -851,24 +851,32 @@ Expr PatternRewriter::Rewrite(const Array<DFPatternCallback>& callbacks, const E
851851
std::unordered_map<DFPatternCallback, bool, ObjectPtrHash, ObjectPtrEqual> done;
852852
do {
853853
last = post;
854+
// We don't have to call InferType if previous pass has not modified anything
855+
// We can just take previous typed state of the expression
856+
bool types_invalidated = true;
854857
for (auto callback : callbacks) {
855858
if (!done[callback]) {
856859
auto before = post;
860+
auto post_typed = post;
857861
callback_ = callback;
858-
if (callback_->require_type) {
859-
post = InferTypeWithModule(post, mod_);
862+
if (callback_->require_type && types_invalidated) {
863+
post_typed = InferTypeWithModule(post, mod_);
860864
}
861865
auto grouper = PatternGrouper();
862-
groups_ = grouper.GroupMatches(callback_->pattern, post);
866+
groups_ = grouper.GroupMatches(callback_->pattern, post_typed);
863867
gid_assignments_ = grouper.GetGIDAssignments();
864868
memo_.clear();
865869
VLOG(1) << "pre rewritten:" << std::endl << PrettyPrint(pre);
866-
post = this->VisitExpr(post);
870+
post = this->VisitExpr(post_typed);
867871
VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post);
868872
count++;
869-
if (callback_->rewrite_once) {
870-
bool current_equal = (*structural_equal)(before, post, false, true);
871-
if (!current_equal) {
873+
bool current_equal = (*structural_equal)(before, post, false, true);
874+
if (callback_->require_type && current_equal) {
875+
types_invalidated = false;
876+
post = post_typed;
877+
} else {
878+
types_invalidated = true;
879+
if (callback_->rewrite_once) {
872880
done[callback] = true;
873881
}
874882
}

0 commit comments

Comments
 (0)