Skip to content

[mlir][IR] Trigger notifyOperationRemoved callback for nested ops #66771

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

When cloning an op, the notifyOperationInserted callback is triggered for all nested ops. Similarly, the notifyOperationRemoved callback should be triggered for all nested ops when removing an op.

Listeners may inspect the IR during a notifyOperationRemoved callback. Therefore, when multiple ops are removed in a single RewriterBase::eraseOp call, the notifications must be triggered in an order in which the ops could have been removed one-by-one:

  • Op removals must be interleaved with notifyOperationRemoved callbacks. A callback is triggered right before the respective op is removed.
  • Ops are removed post-order and in reverse order. Other traversal orders could delete an op that still has uses. (This is not avoidable in graph regions and with cyclic block graphs.)

Note: Imported from https://reviews.llvm.org/D144193.

When cloning an op, the `notifyOperationInserted` callback is triggered for all nested ops. Similarly, the `notifyOperationRemoved` callback should be triggered for all nested ops when removing an op.

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:bufferization Bufferization infrastructure labels Sep 19, 2023
@llvmbot
Copy link
Member

llvmbot commented Sep 19, 2023

@llvm/pr-subscribers-mlir-bufferization
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Changes

When cloning an op, the notifyOperationInserted callback is triggered for all nested ops. Similarly, the notifyOperationRemoved callback should be triggered for all nested ops when removing an op.

Listeners may inspect the IR during a notifyOperationRemoved callback. Therefore, when multiple ops are removed in a single RewriterBase::eraseOp call, the notifications must be triggered in an order in which the ops could have been removed one-by-one:

  • Op removals must be interleaved with notifyOperationRemoved callbacks. A callback is triggered right before the respective op is removed.
  • Ops are removed post-order and in reverse order. Other traversal orders could delete an op that still has uses. (This is not avoidable in graph regions and with cyclic block graphs.)

Note: Imported from https://reviews.llvm.org/D144193.


Full diff: https://github.com/llvm/llvm-project/pull/66771.diff

7 Files Affected:

  • (modified) mlir/include/mlir/IR/RegionKindInterface.h (+6)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+3-6)
  • (modified) mlir/lib/IR/PatternMatch.cpp (+68-4)
  • (modified) mlir/lib/IR/RegionKindInterface.cpp (+10-2)
  • (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+3-6)
  • (modified) mlir/test/Transforms/test-strict-pattern-driver.mlir (+153-7)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+8)
diff --git a/mlir/include/mlir/IR/RegionKindInterface.h b/mlir/include/mlir/IR/RegionKindInterface.h
index 46bfe717533a84a..d6d3aeeb9bd0526 100644
--- a/mlir/include/mlir/IR/RegionKindInterface.h
+++ b/mlir/include/mlir/IR/RegionKindInterface.h
@@ -43,6 +43,12 @@ class HasOnlyGraphRegion : public TraitBase<ConcreteType, HasOnlyGraphRegion> {
 /// not implement the RegionKindInterface.
 bool mayHaveSSADominance(Region &region);
 
+/// Return "true" if the given region may be a graph region without SSA
+/// dominance. This function returns "true" in case the owner op is an
+/// unregistered op. It returns "false" if it is a registered op that does not
+/// implement the RegionKindInterface.
+bool mayBeGraphRegion(Region &region);
+
 } // namespace mlir
 
 #include "mlir/IR/RegionKindInterface.h.inc"
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index cad78b3e65b2313..c34f422292cb4f0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -394,12 +394,9 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
 
 protected:
   void notifyOperationRemoved(Operation *op) override {
-    // TODO: Walk can be removed when D144193 has landed.
-    op->walk([&](Operation *op) {
-      erasedOps.insert(op);
-      // Erase if present.
-      toMemrefOps.erase(op);
-    });
+    erasedOps.insert(op);
+    // Erase if present.
+    toMemrefOps.erase(op);
   }
 
   void notifyOperationInserted(Operation *op) override {
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index db920c14ea08dc7..5e9b9b2a810a4c5 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -8,6 +8,8 @@
 
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Iterators.h"
+#include "mlir/IR/RegionKindInterface.h"
 
 using namespace mlir;
 
@@ -275,7 +277,7 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
   for (auto it : llvm::zip(op->getResults(), newValues))
     replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
 
-  // Erase the op.
+  // Erase op and notify listener.
   eraseOp(op);
 }
 
@@ -295,7 +297,7 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
   for (auto it : llvm::zip(op->getResults(), newOp->getResults()))
     replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
 
-  // Erase the old op.
+  // Erase op and notify listener.
   eraseOp(op);
 }
 
@@ -303,9 +305,71 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
 /// the given operation *must* be known to be dead.
 void RewriterBase::eraseOp(Operation *op) {
   assert(op->use_empty() && "expected 'op' to have no uses");
-  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+  auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
+
+  // Fast path: If no listener is attached, the op can be dropped in one go.
+  if (!rewriteListener) {
+    op->erase();
+    return;
+  }
+
+  // Helper function that erases a single op.
+  auto eraseSingleOp = [&](Operation *op) {
+#ifndef NDEBUG
+    // All nested ops should have been erased already.
+    assert(
+        llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
+        "expected empty regions");
+    // All users should have been erased already if the op is in a region with
+    // SSA dominance.
+    if (!op->use_empty() && op->getParentOp())
+      assert(mayBeGraphRegion(*op->getParentRegion()) &&
+             "expected that op has no uses");
+#endif // NDEBUG
     rewriteListener->notifyOperationRemoved(op);
-  op->erase();
+
+    // Explicitly drop all uses in case the op is in a graph region.
+    op->dropAllUses();
+    op->erase();
+  };
+
+  // Nested ops must be erased one-by-one, so that listeners have a consistent
+  // view of the IR every time a notification is triggered. Users must be
+  // erased before definitions. I.e., post-order, reverse dominance.
+  std::function<void(Operation *)> eraseTree = [&](Operation *op) {
+    // Erase nested ops.
+    for (Region &r : llvm::reverse(op->getRegions())) {
+      // Erase all blocks in the right order. Successors should be erased
+      // before predecessors because successor blocks may use values defined
+      // in predecessor blocks. A post-order traversal of blocks within a
+      // region visits successors before predecessors. Repeat the traversal
+      // until the region is empty. (The block graph could be disconnected.)
+      while (!r.empty()) {
+        SmallVector<Block *> erasedBlocks;
+        for (Block *b : llvm::post_order(&r.front())) {
+          // Visit ops in reverse order.
+          for (Operation &op :
+               llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
+            eraseTree(&op);
+          // Do not erase the block immediately. This is not supprted by the
+          // post_order iterator.
+          erasedBlocks.push_back(b);
+        }
+        for (Block *b : erasedBlocks) {
+          // Explicitly drop all uses in case there is a cycle in the block
+          // graph.
+          for (BlockArgument bbArg : b->getArguments())
+            bbArg.dropAllUses();
+          b->dropAllUses();
+          b->erase();
+        }
+      }
+    }
+    // Then erase the enclosing op.
+    eraseSingleOp(op);
+  };
+
+  eraseTree(op);
 }
 
 void RewriterBase::eraseBlock(Block *block) {
diff --git a/mlir/lib/IR/RegionKindInterface.cpp b/mlir/lib/IR/RegionKindInterface.cpp
index cbef3025a5dd626..007f4cf92dbc7ae 100644
--- a/mlir/lib/IR/RegionKindInterface.cpp
+++ b/mlir/lib/IR/RegionKindInterface.cpp
@@ -18,9 +18,17 @@ using namespace mlir;
 #include "mlir/IR/RegionKindInterface.cpp.inc"
 
 bool mlir::mayHaveSSADominance(Region &region) {
-  auto regionKindOp =
-      dyn_cast_if_present<RegionKindInterface>(region.getParentOp());
+  auto regionKindOp = dyn_cast<RegionKindInterface>(region.getParentOp());
   if (!regionKindOp)
     return true;
   return regionKindOp.hasSSADominance(region.getRegionNumber());
 }
+
+bool mlir::mayBeGraphRegion(Region &region) {
+  if (!region.getParentOp()->isRegistered())
+    return true;
+  auto regionKindOp = dyn_cast<RegionKindInterface>(region.getParentOp());
+  if (!regionKindOp)
+    return false;
+  return !regionKindOp.hasSSADominance(region.getRegionNumber());
+}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index fba4944f130c230..8e2bfe557c555f3 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -421,8 +421,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
 
     // If the operation is trivially dead - remove it.
     if (isOpTriviallyDead(op)) {
-      notifyOperationRemoved(op);
-      op->erase();
+      eraseOp(op);
       changed = true;
 
       LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
@@ -567,10 +566,8 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
     config.listener->notifyOperationRemoved(op);
 
   addOperandsToWorklist(op->getOperands());
-  op->walk([this](Operation *operation) {
-    worklist.remove(operation);
-    folder.notifyRemoval(operation);
-  });
+  worklist.remove(op);
+  folder.notifyRemoval(op);
 
   if (config.strictMode != GreedyRewriteStrictness::AnyOp)
     strictModeFilteredOps.erase(op);
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 5df2d6d1fdeeb38..a5ab8f97c74ce33 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -12,9 +12,9 @@
 
 // CHECK-EN-LABEL: func @test_erase
 //  CHECK-EN-SAME:     pattern_driver_all_erased = true, pattern_driver_changed = true}
-//       CHECK-EN:   test.arg0
-//       CHECK-EN:   test.arg1
-//   CHECK-EN-NOT:   test.erase_op
+//       CHECK-EN:   "test.arg0"
+//       CHECK-EN:   "test.arg1"
+//   CHECK-EN-NOT:   "test.erase_op"
 func.func @test_erase() {
   %0 = "test.arg0"() : () -> (i32)
   %1 = "test.arg1"() : () -> (i32)
@@ -51,13 +51,13 @@ func.func @test_replace_with_new_op() {
 
 // CHECK-EN-LABEL: func @test_replace_with_erase_op
 //  CHECK-EN-SAME:     {pattern_driver_all_erased = true, pattern_driver_changed = true}
-//   CHECK-EN-NOT:   test.replace_with_new_op
-//   CHECK-EN-NOT:   test.erase_op
+//   CHECK-EN-NOT:   "test.replace_with_new_op"
+//   CHECK-EN-NOT:   "test.erase_op"
 
 // CHECK-EX-LABEL: func @test_replace_with_erase_op
 //  CHECK-EX-SAME:     {pattern_driver_all_erased = true, pattern_driver_changed = true}
-//   CHECK-EX-NOT:   test.replace_with_new_op
-//       CHECK-EX:   test.erase_op
+//   CHECK-EX-NOT:   "test.replace_with_new_op"
+//       CHECK-EX:   "test.erase_op"
 func.func @test_replace_with_erase_op() {
   "test.replace_with_new_op"() {create_erase_op} : () -> ()
   return
@@ -83,3 +83,149 @@ func.func @test_trigger_rewrite_through_block() {
   // in turn, replaces the successor with bb3.
   "test.implicit_change_op"() [^bb1] : () -> ()
 }
+
+// -----
+
+// CHECK-AN: notifyOperationRemoved: test.foo_b
+// CHECK-AN: notifyOperationRemoved: test.foo_a
+// CHECK-AN: notifyOperationRemoved: test.graph_region
+// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN-LABEL: func @test_remove_graph_region()
+//  CHECK-AN-NEXT:   return
+func.func @test_remove_graph_region() {
+  "test.erase_op"() ({
+    test.graph_region {
+      %0 = "test.foo_a"(%1) : (i1) -> (i1)
+      %1 = "test.foo_b"(%0) : (i1) -> (i1)
+    }
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.bar
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.foo
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.dummy_op
+// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN-LABEL: func @test_remove_cyclic_blocks()
+//  CHECK-AN-NEXT:   return
+func.func @test_remove_cyclic_blocks() {
+  "test.erase_op"() ({
+    %x = "test.dummy_op"() : () -> (i1)
+    cf.br ^bb1(%x: i1)
+  ^bb1(%arg0: i1):
+    "test.foo"(%x) : (i1) -> ()
+    cf.br ^bb2(%arg0: i1)
+  ^bb2(%arg1: i1):
+    "test.bar"(%x) : (i1) -> ()
+    cf.br ^bb1(%arg1: i1)
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationRemoved: test.dummy_op
+// CHECK-AN: notifyOperationRemoved: test.bar
+// CHECK-AN: notifyOperationRemoved: test.qux
+// CHECK-AN: notifyOperationRemoved: test.qux_unreachable
+// CHECK-AN: notifyOperationRemoved: test.nested_dummy
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.foo
+// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN-LABEL: func @test_remove_dead_blocks()
+//  CHECK-AN-NEXT:   return
+func.func @test_remove_dead_blocks() {
+  "test.erase_op"() ({
+    "test.dummy_op"() : () -> (i1)
+  // The following blocks are not reachable. Still, ^bb2 should be deleted
+  // befire ^bb1.
+  ^bb1(%arg0: i1):
+    "test.foo"() : () -> ()
+    cf.br ^bb2(%arg0: i1)
+  ^bb2(%arg1: i1):
+    "test.nested_dummy"() ({
+      "test.qux"() : () -> ()
+    // The following block is unreachable.
+    ^bb3:
+      "test.qux_unreachable"() : () -> ()
+    }) : () -> ()
+    "test.bar"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// test.nested_* must be deleted before test.foo.
+// test.bar must be deleted before test.foo.
+
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.bar
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.nested_b
+// CHECK-AN: notifyOperationRemoved: test.nested_a
+// CHECK-AN: notifyOperationRemoved: test.nested_d
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.nested_e
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.nested_c
+// CHECK-AN: notifyOperationRemoved: test.foo
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.dummy_op
+// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN-LABEL: func @test_remove_nested_ops()
+//  CHECK-AN-NEXT:   return
+func.func @test_remove_nested_ops() {
+  "test.erase_op"() ({
+    %x = "test.dummy_op"() : () -> (i1)
+    cf.br ^bb1(%x: i1)
+  ^bb1(%arg0: i1):
+    "test.foo"() ({
+      "test.nested_a"() : () -> ()
+      "test.nested_b"() : () -> ()
+    ^dead1:
+      "test.nested_c"() : () -> ()
+      cf.br ^dead3
+    ^dead2:
+      "test.nested_d"() : () -> ()
+    ^dead3:
+      "test.nested_e"() : () -> ()
+      cf.br ^dead2
+    }) : () -> ()
+    cf.br ^bb2(%arg0: i1)
+  ^bb2(%arg1: i1):
+    "test.bar"(%x) : (i1) -> ()
+    cf.br ^bb1(%arg1: i1)
+  }) : () -> ()
+  return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationRemoved: test.qux
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.foo
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.bar
+// CHECK-AN: notifyOperationRemoved: cf.cond_br
+// CHECK-AN-LABEL: func @test_remove_diamond(
+//  CHECK-AN-NEXT:   return
+func.func @test_remove_diamond(%c: i1) {
+  "test.erase_op"() ({
+    cf.cond_br %c, ^bb1, ^bb2
+  ^bb1:
+    "test.foo"() : () -> ()
+    cf.br ^bb3
+  ^bb2:
+    "test.bar"() : () -> ()
+    cf.br ^bb3
+  ^bb3:
+    "test.qux"() : () -> ()
+  }) : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index e23ed105e383390..2e3bc76009ca208 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -239,6 +239,12 @@ struct TestPatternDriver
       llvm::cl::init(GreedyRewriteConfig().maxIterations)};
 };
 
+struct DumpNotifications : public RewriterBase::Listener {
+  void notifyOperationRemoved(Operation *op) override {
+    llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
+  }
+};
+
 struct TestStrictPatternDriver
     : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
 public:
@@ -275,7 +281,9 @@ struct TestStrictPatternDriver
       }
     });
 
+    DumpNotifications dumpNotifications;
     GreedyRewriteConfig config;
+    config.listener = &dumpNotifications;
     if (strictMode == "AnyOp") {
       config.strictMode = GreedyRewriteStrictness::AnyOp;
     } else if (strictMode == "ExistingAndNewOps") {

@matthias-springer
Copy link
Member Author

This change is already quite old but never got landed. It was accepted on Phabricator (https://reviews.llvm.org/D144193), but I rewrote large parts after it was accepted. Mehdi asked for another review by River, but because it was already accepted it probably did not show up on the review list and I dropped the ball...

@matthias-springer matthias-springer changed the title [mlir][IR] Trigger notifyOperationRemoved callback for nested ops [mlir][IR] Trigger notifyOperationRemoved callback for nested ops Sep 19, 2023
@joker-eph
Copy link
Collaborator

River no longer gets involved with reviews these days, so LGTM.

@matthias-springer matthias-springer merged commit 695a5a6 into llvm:main Sep 20, 2023
HerrCai0907 added a commit that referenced this pull request Mar 5, 2024
#66771 introduce `llvm::post_order(&r.front())` which is equal to
`r.front().getSuccessor(...)`.
It will visit the succ block of current block. But actually here need to
visit all block of region in reverse order.
Fixes: #77420.
matthias-springer added a commit that referenced this pull request Mar 6, 2024
matthias-springer added a commit that referenced this pull request Mar 8, 2024
matthias-springer added a commit that referenced this pull request Mar 10, 2024
matthias-springer added a commit that referenced this pull request Mar 15, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants