Skip to content

Commit c6ca8c8

Browse files
[mlir][Transforms][NFC] Dialect Conversion: Keep unresolvedMaterializations up to date
1 parent 59d6fbb commit c6ca8c8

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,7 @@ namespace detail {
848848
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
849849
explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
850850
const ConversionConfig &config)
851-
: context(ctx), eraseRewriter(ctx), config(config) {}
851+
: context(ctx), config(config) {}
852852

853853
//===--------------------------------------------------------------------===//
854854
// State Management
@@ -981,8 +981,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
981981
/// no new IR is created between calls to `eraseOp`/`eraseBlock`.
982982
struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
983983
public:
984-
SingleEraseRewriter(MLIRContext *context)
985-
: RewriterBase(context, /*listener=*/this) {}
984+
SingleEraseRewriter(
985+
MLIRContext *context,
986+
std::function<void(Operation *)> opErasedCallback = nullptr)
987+
: RewriterBase(context, /*listener=*/this),
988+
opErasedCallback(opErasedCallback) {}
986989

987990
/// Erase the given op (unless it was already erased).
988991
void eraseOp(Operation *op) override {
@@ -1003,13 +1006,20 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
10031006

10041007
bool wasErased(void *ptr) const { return erased.contains(ptr); }
10051008

1006-
void notifyOperationErased(Operation *op) override { erased.insert(op); }
1009+
void notifyOperationErased(Operation *op) override {
1010+
erased.insert(op);
1011+
if (opErasedCallback)
1012+
opErasedCallback(op);
1013+
}
10071014

10081015
void notifyBlockErased(Block *block) override { erased.insert(block); }
10091016

10101017
private:
10111018
/// Pointers to all erased operations and blocks.
10121019
DenseSet<void *> erased;
1020+
1021+
/// A callback that is invoked when an operation is erased.
1022+
std::function<void(Operation *)> opErasedCallback;
10131023
};
10141024

10151025
//===--------------------------------------------------------------------===//
@@ -1019,11 +1029,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
10191029
/// MLIR context.
10201030
MLIRContext *context;
10211031

1022-
/// A rewriter that keeps track of ops/block that were already erased and
1023-
/// skips duplicate op/block erasures. This rewriter is used during the
1024-
/// "cleanup" phase.
1025-
SingleEraseRewriter eraseRewriter;
1026-
10271032
// Mapping between replaced values that differ in type. This happens when
10281033
// replacing a value with one of a different type.
10291034
ConversionValueMapping mapping;
@@ -1195,6 +1200,11 @@ void ConversionPatternRewriterImpl::applyRewrites() {
11951200
rewrites[i]->commit(rewriter);
11961201

11971202
// Clean up all rewrites.
1203+
SingleEraseRewriter eraseRewriter(
1204+
context, /*opErasedCallback=*/[&](Operation *op) {
1205+
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1206+
unresolvedMaterializations.erase(castOp);
1207+
});
11981208
for (auto &rewrite : rewrites)
11991209
rewrite->cleanup(eraseRewriter);
12001210
}
@@ -2714,11 +2724,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
27142724
SmallVector<UnrealizedConversionCastOp> allCastOps;
27152725
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
27162726
&materializations = rewriterImpl.unresolvedMaterializations;
2717-
for (auto it : materializations) {
2718-
if (rewriterImpl.eraseRewriter.wasErased(it.first))
2719-
continue;
2727+
for (auto it : materializations)
27202728
allCastOps.push_back(it.first);
2721-
}
27222729

27232730
// Reconcile all UnrealizedConversionCastOps that were inserted by the
27242731
// dialect conversion frameworks. (Not the one that were inserted by

0 commit comments

Comments
 (0)