diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 7de26d7cfa84d..c4b85ec4f67d6 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -848,7 +848,7 @@ namespace detail { struct ConversionPatternRewriterImpl : public RewriterBase::Listener { explicit ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config) - : context(ctx), eraseRewriter(ctx), config(config) {} + : context(ctx), config(config) {} //===--------------------------------------------------------------------===// // State Management @@ -981,8 +981,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// no new IR is created between calls to `eraseOp`/`eraseBlock`. struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener { public: - SingleEraseRewriter(MLIRContext *context) - : RewriterBase(context, /*listener=*/this) {} + SingleEraseRewriter( + MLIRContext *context, + std::function opErasedCallback = nullptr) + : RewriterBase(context, /*listener=*/this), + opErasedCallback(opErasedCallback) {} /// Erase the given op (unless it was already erased). void eraseOp(Operation *op) override { @@ -1003,13 +1006,20 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { bool wasErased(void *ptr) const { return erased.contains(ptr); } - void notifyOperationErased(Operation *op) override { erased.insert(op); } + void notifyOperationErased(Operation *op) override { + erased.insert(op); + if (opErasedCallback) + opErasedCallback(op); + } void notifyBlockErased(Block *block) override { erased.insert(block); } private: /// Pointers to all erased operations and blocks. DenseSet erased; + + /// A callback that is invoked when an operation is erased. + std::function opErasedCallback; }; //===--------------------------------------------------------------------===// @@ -1019,11 +1029,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// MLIR context. MLIRContext *context; - /// A rewriter that keeps track of ops/block that were already erased and - /// skips duplicate op/block erasures. This rewriter is used during the - /// "cleanup" phase. - SingleEraseRewriter eraseRewriter; - // Mapping between replaced values that differ in type. This happens when // replacing a value with one of a different type. ConversionValueMapping mapping; @@ -1195,6 +1200,11 @@ void ConversionPatternRewriterImpl::applyRewrites() { rewrites[i]->commit(rewriter); // Clean up all rewrites. + SingleEraseRewriter eraseRewriter( + context, /*opErasedCallback=*/[&](Operation *op) { + if (auto castOp = dyn_cast(op)) + unresolvedMaterializations.erase(castOp); + }); for (auto &rewrite : rewrites) rewrite->cleanup(eraseRewriter); } @@ -2714,11 +2724,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { SmallVector allCastOps; const DenseMap &materializations = rewriterImpl.unresolvedMaterializations; - for (auto it : materializations) { - if (rewriterImpl.eraseRewriter.wasErased(it.first)) - continue; + for (auto it : materializations) allCastOps.push_back(it.first); - } // Reconcile all UnrealizedConversionCastOps that were inserted by the // dialect conversion frameworks. (Not the one that were inserted by