-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][IR] Trigger notifyOperationRemoved
callback for nested ops
#66771
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][IR] Trigger notifyOperationRemoved
callback for nested ops
#66771
Conversation
When cloning an op, the `notifyOperationInserted` callback is triggered for all nested ops. Similarly, the `notifyOperationRemoved` callback should be triggered for all nested ops when removing an op. BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC
@llvm/pr-subscribers-mlir-bufferization @llvm/pr-subscribers-mlir-core ChangesWhen cloning an op, the Listeners may inspect the IR during a
Note: Imported from https://reviews.llvm.org/D144193. Full diff: https://github.com/llvm/llvm-project/pull/66771.diff 7 Files Affected:
diff --git a/mlir/include/mlir/IR/RegionKindInterface.h b/mlir/include/mlir/IR/RegionKindInterface.h
index 46bfe717533a84a..d6d3aeeb9bd0526 100644
--- a/mlir/include/mlir/IR/RegionKindInterface.h
+++ b/mlir/include/mlir/IR/RegionKindInterface.h
@@ -43,6 +43,12 @@ class HasOnlyGraphRegion : public TraitBase<ConcreteType, HasOnlyGraphRegion> {
/// not implement the RegionKindInterface.
bool mayHaveSSADominance(Region ®ion);
+/// Return "true" if the given region may be a graph region without SSA
+/// dominance. This function returns "true" in case the owner op is an
+/// unregistered op. It returns "false" if it is a registered op that does not
+/// implement the RegionKindInterface.
+bool mayBeGraphRegion(Region ®ion);
+
} // namespace mlir
#include "mlir/IR/RegionKindInterface.h.inc"
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index cad78b3e65b2313..c34f422292cb4f0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -394,12 +394,9 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
protected:
void notifyOperationRemoved(Operation *op) override {
- // TODO: Walk can be removed when D144193 has landed.
- op->walk([&](Operation *op) {
- erasedOps.insert(op);
- // Erase if present.
- toMemrefOps.erase(op);
- });
+ erasedOps.insert(op);
+ // Erase if present.
+ toMemrefOps.erase(op);
}
void notifyOperationInserted(Operation *op) override {
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index db920c14ea08dc7..5e9b9b2a810a4c5 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -8,6 +8,8 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Iterators.h"
+#include "mlir/IR/RegionKindInterface.h"
using namespace mlir;
@@ -275,7 +277,7 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
for (auto it : llvm::zip(op->getResults(), newValues))
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
- // Erase the op.
+ // Erase op and notify listener.
eraseOp(op);
}
@@ -295,7 +297,7 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
for (auto it : llvm::zip(op->getResults(), newOp->getResults()))
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
- // Erase the old op.
+ // Erase op and notify listener.
eraseOp(op);
}
@@ -303,9 +305,71 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
/// the given operation *must* be known to be dead.
void RewriterBase::eraseOp(Operation *op) {
assert(op->use_empty() && "expected 'op' to have no uses");
- if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+ auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
+
+ // Fast path: If no listener is attached, the op can be dropped in one go.
+ if (!rewriteListener) {
+ op->erase();
+ return;
+ }
+
+ // Helper function that erases a single op.
+ auto eraseSingleOp = [&](Operation *op) {
+#ifndef NDEBUG
+ // All nested ops should have been erased already.
+ assert(
+ llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
+ "expected empty regions");
+ // All users should have been erased already if the op is in a region with
+ // SSA dominance.
+ if (!op->use_empty() && op->getParentOp())
+ assert(mayBeGraphRegion(*op->getParentRegion()) &&
+ "expected that op has no uses");
+#endif // NDEBUG
rewriteListener->notifyOperationRemoved(op);
- op->erase();
+
+ // Explicitly drop all uses in case the op is in a graph region.
+ op->dropAllUses();
+ op->erase();
+ };
+
+ // Nested ops must be erased one-by-one, so that listeners have a consistent
+ // view of the IR every time a notification is triggered. Users must be
+ // erased before definitions. I.e., post-order, reverse dominance.
+ std::function<void(Operation *)> eraseTree = [&](Operation *op) {
+ // Erase nested ops.
+ for (Region &r : llvm::reverse(op->getRegions())) {
+ // Erase all blocks in the right order. Successors should be erased
+ // before predecessors because successor blocks may use values defined
+ // in predecessor blocks. A post-order traversal of blocks within a
+ // region visits successors before predecessors. Repeat the traversal
+ // until the region is empty. (The block graph could be disconnected.)
+ while (!r.empty()) {
+ SmallVector<Block *> erasedBlocks;
+ for (Block *b : llvm::post_order(&r.front())) {
+ // Visit ops in reverse order.
+ for (Operation &op :
+ llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
+ eraseTree(&op);
+ // Do not erase the block immediately. This is not supprted by the
+ // post_order iterator.
+ erasedBlocks.push_back(b);
+ }
+ for (Block *b : erasedBlocks) {
+ // Explicitly drop all uses in case there is a cycle in the block
+ // graph.
+ for (BlockArgument bbArg : b->getArguments())
+ bbArg.dropAllUses();
+ b->dropAllUses();
+ b->erase();
+ }
+ }
+ }
+ // Then erase the enclosing op.
+ eraseSingleOp(op);
+ };
+
+ eraseTree(op);
}
void RewriterBase::eraseBlock(Block *block) {
diff --git a/mlir/lib/IR/RegionKindInterface.cpp b/mlir/lib/IR/RegionKindInterface.cpp
index cbef3025a5dd626..007f4cf92dbc7ae 100644
--- a/mlir/lib/IR/RegionKindInterface.cpp
+++ b/mlir/lib/IR/RegionKindInterface.cpp
@@ -18,9 +18,17 @@ using namespace mlir;
#include "mlir/IR/RegionKindInterface.cpp.inc"
bool mlir::mayHaveSSADominance(Region ®ion) {
- auto regionKindOp =
- dyn_cast_if_present<RegionKindInterface>(region.getParentOp());
+ auto regionKindOp = dyn_cast<RegionKindInterface>(region.getParentOp());
if (!regionKindOp)
return true;
return regionKindOp.hasSSADominance(region.getRegionNumber());
}
+
+bool mlir::mayBeGraphRegion(Region ®ion) {
+ if (!region.getParentOp()->isRegistered())
+ return true;
+ auto regionKindOp = dyn_cast<RegionKindInterface>(region.getParentOp());
+ if (!regionKindOp)
+ return false;
+ return !regionKindOp.hasSSADominance(region.getRegionNumber());
+}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index fba4944f130c230..8e2bfe557c555f3 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -421,8 +421,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// If the operation is trivially dead - remove it.
if (isOpTriviallyDead(op)) {
- notifyOperationRemoved(op);
- op->erase();
+ eraseOp(op);
changed = true;
LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
@@ -567,10 +566,8 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
config.listener->notifyOperationRemoved(op);
addOperandsToWorklist(op->getOperands());
- op->walk([this](Operation *operation) {
- worklist.remove(operation);
- folder.notifyRemoval(operation);
- });
+ worklist.remove(op);
+ folder.notifyRemoval(op);
if (config.strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.erase(op);
diff --git a/mlir/test/Transforms/test-strict-pattern-driver.mlir b/mlir/test/Transforms/test-strict-pattern-driver.mlir
index 5df2d6d1fdeeb38..a5ab8f97c74ce33 100644
--- a/mlir/test/Transforms/test-strict-pattern-driver.mlir
+++ b/mlir/test/Transforms/test-strict-pattern-driver.mlir
@@ -12,9 +12,9 @@
// CHECK-EN-LABEL: func @test_erase
// CHECK-EN-SAME: pattern_driver_all_erased = true, pattern_driver_changed = true}
-// CHECK-EN: test.arg0
-// CHECK-EN: test.arg1
-// CHECK-EN-NOT: test.erase_op
+// CHECK-EN: "test.arg0"
+// CHECK-EN: "test.arg1"
+// CHECK-EN-NOT: "test.erase_op"
func.func @test_erase() {
%0 = "test.arg0"() : () -> (i32)
%1 = "test.arg1"() : () -> (i32)
@@ -51,13 +51,13 @@ func.func @test_replace_with_new_op() {
// CHECK-EN-LABEL: func @test_replace_with_erase_op
// CHECK-EN-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
-// CHECK-EN-NOT: test.replace_with_new_op
-// CHECK-EN-NOT: test.erase_op
+// CHECK-EN-NOT: "test.replace_with_new_op"
+// CHECK-EN-NOT: "test.erase_op"
// CHECK-EX-LABEL: func @test_replace_with_erase_op
// CHECK-EX-SAME: {pattern_driver_all_erased = true, pattern_driver_changed = true}
-// CHECK-EX-NOT: test.replace_with_new_op
-// CHECK-EX: test.erase_op
+// CHECK-EX-NOT: "test.replace_with_new_op"
+// CHECK-EX: "test.erase_op"
func.func @test_replace_with_erase_op() {
"test.replace_with_new_op"() {create_erase_op} : () -> ()
return
@@ -83,3 +83,149 @@ func.func @test_trigger_rewrite_through_block() {
// in turn, replaces the successor with bb3.
"test.implicit_change_op"() [^bb1] : () -> ()
}
+
+// -----
+
+// CHECK-AN: notifyOperationRemoved: test.foo_b
+// CHECK-AN: notifyOperationRemoved: test.foo_a
+// CHECK-AN: notifyOperationRemoved: test.graph_region
+// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN-LABEL: func @test_remove_graph_region()
+// CHECK-AN-NEXT: return
+func.func @test_remove_graph_region() {
+ "test.erase_op"() ({
+ test.graph_region {
+ %0 = "test.foo_a"(%1) : (i1) -> (i1)
+ %1 = "test.foo_b"(%0) : (i1) -> (i1)
+ }
+ }) : () -> ()
+ return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.bar
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.foo
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.dummy_op
+// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN-LABEL: func @test_remove_cyclic_blocks()
+// CHECK-AN-NEXT: return
+func.func @test_remove_cyclic_blocks() {
+ "test.erase_op"() ({
+ %x = "test.dummy_op"() : () -> (i1)
+ cf.br ^bb1(%x: i1)
+ ^bb1(%arg0: i1):
+ "test.foo"(%x) : (i1) -> ()
+ cf.br ^bb2(%arg0: i1)
+ ^bb2(%arg1: i1):
+ "test.bar"(%x) : (i1) -> ()
+ cf.br ^bb1(%arg1: i1)
+ }) : () -> ()
+ return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationRemoved: test.dummy_op
+// CHECK-AN: notifyOperationRemoved: test.bar
+// CHECK-AN: notifyOperationRemoved: test.qux
+// CHECK-AN: notifyOperationRemoved: test.qux_unreachable
+// CHECK-AN: notifyOperationRemoved: test.nested_dummy
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.foo
+// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN-LABEL: func @test_remove_dead_blocks()
+// CHECK-AN-NEXT: return
+func.func @test_remove_dead_blocks() {
+ "test.erase_op"() ({
+ "test.dummy_op"() : () -> (i1)
+ // The following blocks are not reachable. Still, ^bb2 should be deleted
+ // befire ^bb1.
+ ^bb1(%arg0: i1):
+ "test.foo"() : () -> ()
+ cf.br ^bb2(%arg0: i1)
+ ^bb2(%arg1: i1):
+ "test.nested_dummy"() ({
+ "test.qux"() : () -> ()
+ // The following block is unreachable.
+ ^bb3:
+ "test.qux_unreachable"() : () -> ()
+ }) : () -> ()
+ "test.bar"() : () -> ()
+ }) : () -> ()
+ return
+}
+
+// -----
+
+// test.nested_* must be deleted before test.foo.
+// test.bar must be deleted before test.foo.
+
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.bar
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.nested_b
+// CHECK-AN: notifyOperationRemoved: test.nested_a
+// CHECK-AN: notifyOperationRemoved: test.nested_d
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.nested_e
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.nested_c
+// CHECK-AN: notifyOperationRemoved: test.foo
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.dummy_op
+// CHECK-AN: notifyOperationRemoved: test.erase_op
+// CHECK-AN-LABEL: func @test_remove_nested_ops()
+// CHECK-AN-NEXT: return
+func.func @test_remove_nested_ops() {
+ "test.erase_op"() ({
+ %x = "test.dummy_op"() : () -> (i1)
+ cf.br ^bb1(%x: i1)
+ ^bb1(%arg0: i1):
+ "test.foo"() ({
+ "test.nested_a"() : () -> ()
+ "test.nested_b"() : () -> ()
+ ^dead1:
+ "test.nested_c"() : () -> ()
+ cf.br ^dead3
+ ^dead2:
+ "test.nested_d"() : () -> ()
+ ^dead3:
+ "test.nested_e"() : () -> ()
+ cf.br ^dead2
+ }) : () -> ()
+ cf.br ^bb2(%arg0: i1)
+ ^bb2(%arg1: i1):
+ "test.bar"(%x) : (i1) -> ()
+ cf.br ^bb1(%arg1: i1)
+ }) : () -> ()
+ return
+}
+
+// -----
+
+// CHECK-AN: notifyOperationRemoved: test.qux
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.foo
+// CHECK-AN: notifyOperationRemoved: cf.br
+// CHECK-AN: notifyOperationRemoved: test.bar
+// CHECK-AN: notifyOperationRemoved: cf.cond_br
+// CHECK-AN-LABEL: func @test_remove_diamond(
+// CHECK-AN-NEXT: return
+func.func @test_remove_diamond(%c: i1) {
+ "test.erase_op"() ({
+ cf.cond_br %c, ^bb1, ^bb2
+ ^bb1:
+ "test.foo"() : () -> ()
+ cf.br ^bb3
+ ^bb2:
+ "test.bar"() : () -> ()
+ cf.br ^bb3
+ ^bb3:
+ "test.qux"() : () -> ()
+ }) : () -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index e23ed105e383390..2e3bc76009ca208 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -239,6 +239,12 @@ struct TestPatternDriver
llvm::cl::init(GreedyRewriteConfig().maxIterations)};
};
+struct DumpNotifications : public RewriterBase::Listener {
+ void notifyOperationRemoved(Operation *op) override {
+ llvm::outs() << "notifyOperationRemoved: " << op->getName() << "\n";
+ }
+};
+
struct TestStrictPatternDriver
: public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> {
public:
@@ -275,7 +281,9 @@ struct TestStrictPatternDriver
}
});
+ DumpNotifications dumpNotifications;
GreedyRewriteConfig config;
+ config.listener = &dumpNotifications;
if (strictMode == "AnyOp") {
config.strictMode = GreedyRewriteStrictness::AnyOp;
} else if (strictMode == "ExistingAndNewOps") {
|
This change is already quite old but never got landed. It was accepted on Phabricator (https://reviews.llvm.org/D144193), but I rewrote large parts after it was accepted. Mehdi asked for another review by River, but because it was already accepted it probably did not show up on the review list and I dropped the ball... |
notifyOperationRemoved
callback for nested ops
River no longer gets involved with reviews these days, so LGTM. |
D144193 (#66771) has been merged.
D144193 (#66771) has been merged.
D144193 (#66771) has been merged.
D144193 (#66771) has been merged.
When cloning an op, the
notifyOperationInserted
callback is triggered for all nested ops. Similarly, thenotifyOperationRemoved
callback should be triggered for all nested ops when removing an op.Listeners may inspect the IR during a
notifyOperationRemoved
callback. Therefore, when multiple ops are removed in a singleRewriterBase::eraseOp
call, the notifications must be triggered in an order in which the ops could have been removed one-by-one:notifyOperationRemoved
callbacks. A callback is triggered right before the respective op is removed.Note: Imported from https://reviews.llvm.org/D144193.