diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 07ebd687ee2b3..47e03383304af 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -53,6 +53,16 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { }); } +/// Helper function that computes an insertion point where the given value is +/// defined and can be used without a dominance violation. +static OpBuilder::InsertPoint computeInsertPoint(Value value) { + Block *insertBlock = value.getParentBlock(); + Block::iterator insertPt = insertBlock->begin(); + if (OpResult inputRes = dyn_cast(value)) + insertPt = ++inputRes.getOwner()->getIterator(); + return OpBuilder::InsertPoint(insertBlock, insertPt); +} + //===----------------------------------------------------------------------===// // ConversionValueMapping //===----------------------------------------------------------------------===// @@ -445,11 +455,9 @@ class BlockTypeConversionRewrite : public BlockRewrite { return rewrite->getKind() == Kind::BlockTypeConversion; } - /// Materialize any necessary conversions for converted arguments that have - /// live users, using the provided `findLiveUser` to search for a user that - /// survives the conversion process. - LogicalResult - materializeLiveConversions(function_ref findLiveUser); + Block *getOrigBlock() const { return origBlock; } + + const TypeConverter *getConverter() const { return converter; } void commit(RewriterBase &rewriter) override; @@ -841,14 +849,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Build an unresolved materialization operation given an output type and set /// of input operands. Value buildUnresolvedMaterialization(MaterializationKind kind, - Block *insertBlock, - Block::iterator insertPt, Location loc, + OpBuilder::InsertPoint ip, Location loc, ValueRange inputs, Type outputType, Type origOutputType, const TypeConverter *converter); - Value buildUnresolvedTargetMaterialization(Location loc, Value input, - Type outputType, - const TypeConverter *converter); //===--------------------------------------------------------------------===// // Rewriter Notification Hooks @@ -981,49 +985,6 @@ void BlockTypeConversionRewrite::rollback() { block->replaceAllUsesWith(origBlock); } -LogicalResult BlockTypeConversionRewrite::materializeLiveConversions( - function_ref findLiveUser) { - // Process the remapping for each of the original arguments. - for (auto it : llvm::enumerate(origBlock->getArguments())) { - BlockArgument origArg = it.value(); - // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used. - OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl); - builder.setInsertionPointToStart(block); - - // If the type of this argument changed and the argument is still live, we - // need to materialize a conversion. - if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType())) - continue; - Operation *liveUser = findLiveUser(origArg); - if (!liveUser) - continue; - - Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg); - assert(replacementValue && "replacement value not found"); - Value newArg; - if (converter) { - builder.setInsertionPointAfterValue(replacementValue); - newArg = converter->materializeSourceConversion( - builder, origArg.getLoc(), origArg.getType(), replacementValue); - assert((!newArg || newArg.getType() == origArg.getType()) && - "materialization hook did not provide a value of the expected " - "type"); - } - if (!newArg) { - InFlightDiagnostic diag = - emitError(origArg.getLoc()) - << "failed to materialize conversion for block argument #" - << it.index() << " that remained live after conversion, type was " - << origArg.getType(); - diag.attachNote(liveUser->getLoc()) - << "see existing live user here: " << *liveUser; - return failure(); - } - rewriterImpl.mapping.map(origArg, newArg); - } - return success(); -} - void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType()); if (!repl) @@ -1196,8 +1157,10 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( Type newOperandType = newOperand.getType(); if (currentTypeConverter && desiredType && newOperandType != desiredType) { Location operandLoc = inputLoc ? *inputLoc : operand.getLoc(); - Value castValue = buildUnresolvedTargetMaterialization( - operandLoc, newOperand, desiredType, currentTypeConverter); + Value castValue = buildUnresolvedMaterialization( + MaterializationKind::Target, computeInsertPoint(newOperand), + operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType, + /*origArgType=*/{}, currentTypeConverter); mapping.map(mapping.lookupOrDefault(newOperand), castValue); newOperand = castValue; } @@ -1325,8 +1288,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( // This block argument was dropped and no replacement value was provided. // Materialize a replacement value "out of thin air". Value repl = buildUnresolvedMaterialization( - MaterializationKind::Source, newBlock, newBlock->begin(), - origArg.getLoc(), /*inputs=*/ValueRange(), + MaterializationKind::Source, + OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), + /*inputs=*/ValueRange(), /*outputType=*/origArgType, /*origArgType=*/{}, converter); mapping.map(origArg, repl); appendRewrite(block, origArg); @@ -1351,8 +1315,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); Value repl = buildUnresolvedMaterialization( - MaterializationKind::Argument, newBlock, newBlock->begin(), - origArg.getLoc(), /*inputs=*/replArgs, + MaterializationKind::Argument, + OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), + /*inputs=*/replArgs, /*outputType=*/tryLegalizeType(origArgType), origArgType, converter); mapping.map(origArg, repl); appendRewrite(block, origArg); @@ -1374,8 +1339,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /// Build an unresolved materialization operation given an output type and set /// of input operands. Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( - MaterializationKind kind, Block *insertBlock, Block::iterator insertPt, - Location loc, ValueRange inputs, Type outputType, Type origArgType, + MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, + ValueRange inputs, Type outputType, Type origArgType, const TypeConverter *converter) { // Avoid materializing an unnecessary cast. if (inputs.size() == 1 && inputs.front().getType() == outputType) @@ -1383,25 +1348,13 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( // Create an unresolved materialization. We use a new OpBuilder to avoid // tracking the materialization like we do for other operations. - OpBuilder builder(insertBlock, insertPt); + OpBuilder builder(ip.getBlock(), ip.getPoint()); auto convertOp = builder.create(loc, outputType, inputs); appendRewrite(convertOp, converter, kind, origArgType); return convertOp.getResult(0); } -Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization( - Location loc, Value input, Type outputType, - const TypeConverter *converter) { - Block *insertBlock = input.getParentBlock(); - Block::iterator insertPt = insertBlock->begin(); - if (OpResult inputRes = dyn_cast(input)) - insertPt = ++inputRes.getOwner()->getIterator(); - - return buildUnresolvedMaterialization( - MaterializationKind::Target, insertBlock, insertPt, loc, input, - outputType, /*origArgType=*/{}, converter); -} //===----------------------------------------------------------------------===// // Rewriter Notification Hooks @@ -2515,9 +2468,9 @@ LogicalResult OperationConverter::finalize(ConversionPatternRewriter &rewriter) { std::optional>> inverseMapping; ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); - if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl, - inverseMapping)) || - failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl))) + if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) || + failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl, + inverseMapping))) return failure(); // Process requested operation replacements. @@ -2573,10 +2526,28 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes( ++i) { auto &rewrite = rewriterImpl.rewrites[i]; if (auto *blockTypeConversionRewrite = - dyn_cast(rewrite.get())) - if (failed(blockTypeConversionRewrite->materializeLiveConversions( - findLiveUser))) - return failure(); + dyn_cast(rewrite.get())) { + // Process the remapping for each of the original arguments. + for (Value origArg : + blockTypeConversionRewrite->getOrigBlock()->getArguments()) { + // If the type of this argument changed and the argument is still live, + // we need to materialize a conversion. + if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType())) + continue; + Operation *liveUser = findLiveUser(origArg); + if (!liveUser) + continue; + + Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg); + assert(replacementValue && "replacement value not found"); + Value repl = rewriterImpl.buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(replacementValue), + origArg.getLoc(), /*inputs=*/replacementValue, + /*outputType=*/origArg.getType(), /*origArgType=*/{}, + blockTypeConversionRewrite->getConverter()); + rewriterImpl.mapping.map(origArg, repl); + } + } } return success(); }