@@ -848,7 +848,7 @@ namespace detail {
848
848
struct ConversionPatternRewriterImpl : public RewriterBase ::Listener {
849
849
explicit ConversionPatternRewriterImpl (MLIRContext *ctx,
850
850
const ConversionConfig &config)
851
- : context(ctx), eraseRewriter(ctx), config(config) {}
851
+ : context(ctx), config(config) {}
852
852
853
853
// ===--------------------------------------------------------------------===//
854
854
// State Management
@@ -981,8 +981,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
981
981
// / no new IR is created between calls to `eraseOp`/`eraseBlock`.
982
982
struct SingleEraseRewriter : public RewriterBase , RewriterBase::Listener {
983
983
public:
984
- SingleEraseRewriter (MLIRContext *context)
985
- : RewriterBase(context, /* listener=*/ this ) {}
984
+ SingleEraseRewriter (
985
+ MLIRContext *context,
986
+ std::function<void (Operation *)> opErasedCallback = nullptr )
987
+ : RewriterBase(context, /* listener=*/ this ),
988
+ opErasedCallback (opErasedCallback) {}
986
989
987
990
// / Erase the given op (unless it was already erased).
988
991
void eraseOp (Operation *op) override {
@@ -1003,13 +1006,20 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
1003
1006
1004
1007
bool wasErased (void *ptr) const { return erased.contains (ptr); }
1005
1008
1006
- void notifyOperationErased (Operation *op) override { erased.insert (op); }
1009
+ void notifyOperationErased (Operation *op) override {
1010
+ erased.insert (op);
1011
+ if (opErasedCallback)
1012
+ opErasedCallback (op);
1013
+ }
1007
1014
1008
1015
void notifyBlockErased (Block *block) override { erased.insert (block); }
1009
1016
1010
1017
private:
1011
1018
// / Pointers to all erased operations and blocks.
1012
1019
DenseSet<void *> erased;
1020
+
1021
+ // / A callback that is invoked when an operation is erased.
1022
+ std::function<void (Operation *)> opErasedCallback;
1013
1023
};
1014
1024
1015
1025
// ===--------------------------------------------------------------------===//
@@ -1019,11 +1029,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
1019
1029
// / MLIR context.
1020
1030
MLIRContext *context;
1021
1031
1022
- // / A rewriter that keeps track of ops/block that were already erased and
1023
- // / skips duplicate op/block erasures. This rewriter is used during the
1024
- // / "cleanup" phase.
1025
- SingleEraseRewriter eraseRewriter;
1026
-
1027
1032
// Mapping between replaced values that differ in type. This happens when
1028
1033
// replacing a value with one of a different type.
1029
1034
ConversionValueMapping mapping;
@@ -1195,6 +1200,11 @@ void ConversionPatternRewriterImpl::applyRewrites() {
1195
1200
rewrites[i]->commit (rewriter);
1196
1201
1197
1202
// Clean up all rewrites.
1203
+ SingleEraseRewriter eraseRewriter (
1204
+ context, /* opErasedCallback=*/ [&](Operation *op) {
1205
+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
1206
+ unresolvedMaterializations.erase (castOp);
1207
+ });
1198
1208
for (auto &rewrite : rewrites)
1199
1209
rewrite->cleanup (eraseRewriter);
1200
1210
}
@@ -2714,11 +2724,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
2714
2724
SmallVector<UnrealizedConversionCastOp> allCastOps;
2715
2725
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
2716
2726
&materializations = rewriterImpl.unresolvedMaterializations ;
2717
- for (auto it : materializations) {
2718
- if (rewriterImpl.eraseRewriter .wasErased (it.first ))
2719
- continue ;
2727
+ for (auto it : materializations)
2720
2728
allCastOps.push_back (it.first );
2721
- }
2722
2729
2723
2730
// Reconcile all UnrealizedConversionCastOps that were inserted by the
2724
2731
// dialect conversion frameworks. (Not the one that were inserted by
0 commit comments