Skip to content

Commit ede5b9e

Browse files
[mlir][Transforms][NFC] Dialect Conversion: Manually push rewrite onto stack
1 parent c6ca8c8 commit ede5b9e

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -869,9 +869,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
869869
/// Append a rewrite. Rewrites are committed upon success and rolled back upon
870870
/// failure.
871871
template <typename RewriteTy, typename... Args>
872-
void appendRewrite(Args &&...args) {
872+
RewriteTy *appendRewrite(Args &&...args) {
873873
rewrites.push_back(
874874
std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
875+
return static_cast<RewriteTy *>(rewrites.back().get());
875876
}
876877

877878
/// Undo the rewrites (motions, splits) one by one in reverse order until
@@ -1181,7 +1182,6 @@ UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
11811182
mappedValues(std::move(mappedValues)) {
11821183
assert((!originalType || kind == MaterializationKind::Target) &&
11831184
"original type is valid only for target materializations");
1184-
rewriterImpl.unresolvedMaterializations[op] = this;
11851185
}
11861186

11871187
void UnresolvedMaterializationRewrite::rollback() {
@@ -1471,8 +1471,9 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
14711471
mapping.map(valuesToMap, convertOp.getResults());
14721472
if (castOp)
14731473
*castOp = convertOp;
1474-
appendRewrite<UnresolvedMaterializationRewrite>(
1475-
convertOp, converter, kind, originalType, std::move(valuesToMap));
1474+
unresolvedMaterializations[convertOp] =
1475+
appendRewrite<UnresolvedMaterializationRewrite>(
1476+
convertOp, converter, kind, originalType, std::move(valuesToMap));
14761477
return convertOp.getResults();
14771478
}
14781479

0 commit comments

Comments
 (0)