Skip to content

Commit 6eca32f

Browse files
matthias-springerAlexisPerry
authored andcommitted
[mlir][Transforms][NFC] Dialect Conversion: Move argument materialization logic (llvm#96329)
This commit moves the argument materialization logic from `legalizeConvertedArgumentTypes` to `legalizeUnresolvedMaterializations`. Before this change: - Argument materializations were created in `legalizeConvertedArgumentTypes` (which used to call `materializeLiveConversions`). After this change: - `legalizeConvertedArgumentTypes` creates a "placeholder" `unrealized_conversion_cast`. - The placeholder `unrealized_conversion_cast` is replaced with an argument materialization (using the type converter) in `legalizeUnresolvedMaterializations`. - All argument and target materializations now take place in the same location (`legalizeUnresolvedMaterializations`). This commit brings us closer towards creating all source/target/argument materializations in one central step, which can then be made optional (and delegated to the user) in the future. (There is one more source materialization step that has not been moved yet.) This commit also consolidates all `build*UnresolvedMaterialization` functions into a single `buildUnresolvedMaterialization` function.
1 parent d55dded commit 6eca32f

File tree

1 file changed

+52
-81
lines changed

1 file changed

+52
-81
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 52 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
5353
});
5454
}
5555

56+
/// Helper function that computes an insertion point where the given value is
57+
/// defined and can be used without a dominance violation.
58+
static OpBuilder::InsertPoint computeInsertPoint(Value value) {
59+
Block *insertBlock = value.getParentBlock();
60+
Block::iterator insertPt = insertBlock->begin();
61+
if (OpResult inputRes = dyn_cast<OpResult>(value))
62+
insertPt = ++inputRes.getOwner()->getIterator();
63+
return OpBuilder::InsertPoint(insertBlock, insertPt);
64+
}
65+
5666
//===----------------------------------------------------------------------===//
5767
// ConversionValueMapping
5868
//===----------------------------------------------------------------------===//
@@ -445,11 +455,9 @@ class BlockTypeConversionRewrite : public BlockRewrite {
445455
return rewrite->getKind() == Kind::BlockTypeConversion;
446456
}
447457

448-
/// Materialize any necessary conversions for converted arguments that have
449-
/// live users, using the provided `findLiveUser` to search for a user that
450-
/// survives the conversion process.
451-
LogicalResult
452-
materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
458+
Block *getOrigBlock() const { return origBlock; }
459+
460+
const TypeConverter *getConverter() const { return converter; }
453461

454462
void commit(RewriterBase &rewriter) override;
455463

@@ -841,14 +849,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
841849
/// Build an unresolved materialization operation given an output type and set
842850
/// of input operands.
843851
Value buildUnresolvedMaterialization(MaterializationKind kind,
844-
Block *insertBlock,
845-
Block::iterator insertPt, Location loc,
852+
OpBuilder::InsertPoint ip, Location loc,
846853
ValueRange inputs, Type outputType,
847854
Type origOutputType,
848855
const TypeConverter *converter);
849-
Value buildUnresolvedTargetMaterialization(Location loc, Value input,
850-
Type outputType,
851-
const TypeConverter *converter);
852856

853857
//===--------------------------------------------------------------------===//
854858
// Rewriter Notification Hooks
@@ -981,49 +985,6 @@ void BlockTypeConversionRewrite::rollback() {
981985
block->replaceAllUsesWith(origBlock);
982986
}
983987

984-
LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
985-
function_ref<Operation *(Value)> findLiveUser) {
986-
// Process the remapping for each of the original arguments.
987-
for (auto it : llvm::enumerate(origBlock->getArguments())) {
988-
BlockArgument origArg = it.value();
989-
// Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
990-
OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
991-
builder.setInsertionPointToStart(block);
992-
993-
// If the type of this argument changed and the argument is still live, we
994-
// need to materialize a conversion.
995-
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
996-
continue;
997-
Operation *liveUser = findLiveUser(origArg);
998-
if (!liveUser)
999-
continue;
1000-
1001-
Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
1002-
assert(replacementValue && "replacement value not found");
1003-
Value newArg;
1004-
if (converter) {
1005-
builder.setInsertionPointAfterValue(replacementValue);
1006-
newArg = converter->materializeSourceConversion(
1007-
builder, origArg.getLoc(), origArg.getType(), replacementValue);
1008-
assert((!newArg || newArg.getType() == origArg.getType()) &&
1009-
"materialization hook did not provide a value of the expected "
1010-
"type");
1011-
}
1012-
if (!newArg) {
1013-
InFlightDiagnostic diag =
1014-
emitError(origArg.getLoc())
1015-
<< "failed to materialize conversion for block argument #"
1016-
<< it.index() << " that remained live after conversion, type was "
1017-
<< origArg.getType();
1018-
diag.attachNote(liveUser->getLoc())
1019-
<< "see existing live user here: " << *liveUser;
1020-
return failure();
1021-
}
1022-
rewriterImpl.mapping.map(origArg, newArg);
1023-
}
1024-
return success();
1025-
}
1026-
1027988
void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1028989
Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
1029990
if (!repl)
@@ -1196,8 +1157,10 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11961157
Type newOperandType = newOperand.getType();
11971158
if (currentTypeConverter && desiredType && newOperandType != desiredType) {
11981159
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
1199-
Value castValue = buildUnresolvedTargetMaterialization(
1200-
operandLoc, newOperand, desiredType, currentTypeConverter);
1160+
Value castValue = buildUnresolvedMaterialization(
1161+
MaterializationKind::Target, computeInsertPoint(newOperand),
1162+
operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
1163+
/*origArgType=*/{}, currentTypeConverter);
12011164
mapping.map(mapping.lookupOrDefault(newOperand), castValue);
12021165
newOperand = castValue;
12031166
}
@@ -1325,8 +1288,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13251288
// This block argument was dropped and no replacement value was provided.
13261289
// Materialize a replacement value "out of thin air".
13271290
Value repl = buildUnresolvedMaterialization(
1328-
MaterializationKind::Source, newBlock, newBlock->begin(),
1329-
origArg.getLoc(), /*inputs=*/ValueRange(),
1291+
MaterializationKind::Source,
1292+
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1293+
/*inputs=*/ValueRange(),
13301294
/*outputType=*/origArgType, /*origArgType=*/{}, converter);
13311295
mapping.map(origArg, repl);
13321296
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1351,8 +1315,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13511315
auto replArgs =
13521316
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
13531317
Value repl = buildUnresolvedMaterialization(
1354-
MaterializationKind::Argument, newBlock, newBlock->begin(),
1355-
origArg.getLoc(), /*inputs=*/replArgs,
1318+
MaterializationKind::Argument,
1319+
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1320+
/*inputs=*/replArgs,
13561321
/*outputType=*/tryLegalizeType(origArgType), origArgType, converter);
13571322
mapping.map(origArg, repl);
13581323
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1374,34 +1339,22 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13741339
/// Build an unresolved materialization operation given an output type and set
13751340
/// of input operands.
13761341
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1377-
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
1378-
Location loc, ValueRange inputs, Type outputType, Type origArgType,
1342+
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1343+
ValueRange inputs, Type outputType, Type origArgType,
13791344
const TypeConverter *converter) {
13801345
// Avoid materializing an unnecessary cast.
13811346
if (inputs.size() == 1 && inputs.front().getType() == outputType)
13821347
return inputs.front();
13831348

13841349
// Create an unresolved materialization. We use a new OpBuilder to avoid
13851350
// tracking the materialization like we do for other operations.
1386-
OpBuilder builder(insertBlock, insertPt);
1351+
OpBuilder builder(ip.getBlock(), ip.getPoint());
13871352
auto convertOp =
13881353
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
13891354
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
13901355
origArgType);
13911356
return convertOp.getResult(0);
13921357
}
1393-
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
1394-
Location loc, Value input, Type outputType,
1395-
const TypeConverter *converter) {
1396-
Block *insertBlock = input.getParentBlock();
1397-
Block::iterator insertPt = insertBlock->begin();
1398-
if (OpResult inputRes = dyn_cast<OpResult>(input))
1399-
insertPt = ++inputRes.getOwner()->getIterator();
1400-
1401-
return buildUnresolvedMaterialization(
1402-
MaterializationKind::Target, insertBlock, insertPt, loc, input,
1403-
outputType, /*origArgType=*/{}, converter);
1404-
}
14051358

14061359
//===----------------------------------------------------------------------===//
14071360
// Rewriter Notification Hooks
@@ -2515,9 +2468,9 @@ LogicalResult
25152468
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
25162469
std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
25172470
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
2518-
if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2519-
inverseMapping)) ||
2520-
failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
2471+
if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) ||
2472+
failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
2473+
inverseMapping)))
25212474
return failure();
25222475

25232476
// Process requested operation replacements.
@@ -2573,10 +2526,28 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
25732526
++i) {
25742527
auto &rewrite = rewriterImpl.rewrites[i];
25752528
if (auto *blockTypeConversionRewrite =
2576-
dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
2577-
if (failed(blockTypeConversionRewrite->materializeLiveConversions(
2578-
findLiveUser)))
2579-
return failure();
2529+
dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
2530+
// Process the remapping for each of the original arguments.
2531+
for (Value origArg :
2532+
blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
2533+
// If the type of this argument changed and the argument is still live,
2534+
// we need to materialize a conversion.
2535+
if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
2536+
continue;
2537+
Operation *liveUser = findLiveUser(origArg);
2538+
if (!liveUser)
2539+
continue;
2540+
2541+
Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
2542+
assert(replacementValue && "replacement value not found");
2543+
Value repl = rewriterImpl.buildUnresolvedMaterialization(
2544+
MaterializationKind::Source, computeInsertPoint(replacementValue),
2545+
origArg.getLoc(), /*inputs=*/replacementValue,
2546+
/*outputType=*/origArg.getType(), /*origArgType=*/{},
2547+
blockTypeConversionRewrite->getConverter());
2548+
rewriterImpl.mapping.map(origArg, repl);
2549+
}
2550+
}
25802551
}
25812552
return success();
25822553
}

0 commit comments

Comments
 (0)