diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 1e0afee2373a9..0b552a7e1ca3b 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -432,34 +432,14 @@ class MoveBlockRewrite : public BlockRewrite { Block *insertBeforeBlock; }; -/// This structure contains the information pertaining to an argument that has -/// been converted. -struct ConvertedArgInfo { - ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize, - Value castValue = nullptr) - : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {} - - /// The start index of in the new argument list that contains arguments that - /// replace the original. - unsigned newArgIdx; - - /// The number of arguments that replaced the original argument. - unsigned newArgSize; - - /// The cast value that was created to cast from the new arguments to the - /// old. This only used if 'newArgSize' > 1. - Value castValue; -}; - /// Block type conversion. This rewrite is partially reflected in the IR. class BlockTypeConversionRewrite : public BlockRewrite { public: - BlockTypeConversionRewrite( - ConversionPatternRewriterImpl &rewriterImpl, Block *block, - Block *origBlock, SmallVector, 1> argInfo, - const TypeConverter *converter) + BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl, + Block *block, Block *origBlock, + const TypeConverter *converter) : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block), - origBlock(origBlock), argInfo(argInfo), converter(converter) {} + origBlock(origBlock), converter(converter) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::BlockTypeConversion; @@ -479,10 +459,6 @@ class BlockTypeConversionRewrite : public BlockRewrite { /// The original block that was requested to have its signature converted. Block *origBlock; - /// The conversion information for each of the arguments. The information is - /// std::nullopt if the argument was dropped during conversion. - SmallVector, 1> argInfo; - /// The type converter used to convert the arguments. const TypeConverter *converter; }; @@ -691,12 +667,16 @@ class CreateOperationRewrite : public OperationRewrite { /// The type of materialization. enum MaterializationKind { /// This materialization materializes a conversion for an illegal block - /// argument type, to a legal one. + /// argument type, to the original one. Argument, /// This materialization materializes a conversion from an illegal type to a /// legal one. - Target + Target, + + /// This materialization materializes a conversion from a legal type back to + /// an illegal one. + Source }; /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast" @@ -736,7 +716,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite { private: /// The corresponding type converter to use when resolving this /// materialization, and the kind of this materialization. - llvm::PointerIntPair + llvm::PointerIntPair converterAndKind; }; } // namespace @@ -855,11 +835,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { ValueRange inputs, Type outputType, const TypeConverter *converter); - Value buildUnresolvedArgumentMaterialization(Block *block, Location loc, - ValueRange inputs, - Type outputType, - const TypeConverter *converter); - Value buildUnresolvedTargetMaterialization(Location loc, Value input, Type outputType, const TypeConverter *converter); @@ -989,28 +964,6 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) { dyn_cast_or_null(rewriter.getListener())) for (Operation *op : block->getUsers()) listener->notifyOperationModified(op); - - // Process the remapping for each of the original arguments. - for (auto [origArg, info] : - llvm::zip_equal(origBlock->getArguments(), argInfo)) { - // Handle the case of a 1->0 value mapping. - if (!info) { - if (Value newArg = - rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType())) - rewriter.replaceAllUsesWith(origArg, newArg); - continue; - } - - // Otherwise this is a 1->1+ value mapping. - Value castValue = info->castValue; - assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping"); - - // If the argument is still used, replace it with the generated cast. - if (!origArg.use_empty()) { - rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault( - castValue, origArg.getType())); - } - } } void BlockTypeConversionRewrite::rollback() { @@ -1035,14 +988,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions( continue; Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg); - bool isDroppedArg = replacementValue == origArg; - if (!isDroppedArg) - builder.setInsertionPointAfterValue(replacementValue); + assert(replacementValue && "replacement value not found"); Value newArg; if (converter) { + builder.setInsertionPointAfterValue(replacementValue); newArg = converter->materializeSourceConversion( - builder, origArg.getLoc(), origArg.getType(), - isDroppedArg ? ValueRange() : ValueRange(replacementValue)); + builder, origArg.getLoc(), origArg.getType(), replacementValue); assert((!newArg || newArg.getType() == origArg.getType()) && "materialization hook did not provide a value of the expected " "type"); @@ -1053,8 +1004,6 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions( << "failed to materialize conversion for block argument #" << it.index() << " that remained live after conversion, type was " << origArg.getType(); - if (!isDroppedArg) - diag << ", with target type " << replacementValue.getType(); diag.attachNote(liveUser->getLoc()) << "see existing live user here: " << *liveUser; return failure(); @@ -1340,73 +1289,64 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( // Replace all uses of the old block with the new block. block->replaceAllUsesWith(newBlock); - // Remap each of the original arguments as determined by the signature - // conversion. - SmallVector, 1> argInfo; - argInfo.resize(origArgCount); - for (unsigned i = 0; i != origArgCount; ++i) { - auto inputMap = signatureConversion.getInputMapping(i); - if (!inputMap) - continue; BlockArgument origArg = block->getArgument(i); + Type origArgType = origArg.getType(); + + std::optional inputMap = + signatureConversion.getInputMapping(i); + if (!inputMap) { + // This block argument was dropped and no replacement value was provided. + // Materialize a replacement value "out of thin air". + Value repl = buildUnresolvedMaterialization( + MaterializationKind::Source, newBlock, newBlock->begin(), + origArg.getLoc(), /*inputs=*/ValueRange(), + /*outputType=*/origArgType, converter); + mapping.map(origArg, repl); + appendRewrite(block, origArg); + continue; + } - // If inputMap->replacementValue is not nullptr, then the argument is - // dropped and a replacement value is provided to be the remappedValue. - if (inputMap->replacementValue) { + if (Value repl = inputMap->replacementValue) { + // This block argument was dropped and a replacement value was provided. assert(inputMap->size == 0 && "invalid to provide a replacement value when the argument isn't " "dropped"); - mapping.map(origArg, inputMap->replacementValue); + mapping.map(origArg, repl); appendRewrite(block, origArg); continue; } - // Otherwise, this is a 1->1+ mapping. + // This is a 1->1+ mapping. 1->N mappings are not fully supported in the + // dialect conversion. Therefore, we need an argument materialization to + // turn the replacement block arguments into a single SSA value that can be + // used as a replacement. auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); - Value newArg; + Value argMat = buildUnresolvedMaterialization( + MaterializationKind::Argument, newBlock, newBlock->begin(), + origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter); + mapping.map(origArg, argMat); + appendRewrite(block, origArg); - // If this is a 1->1 mapping and the types of new and replacement arguments - // match (i.e. it's an identity map), then the argument is mapped to its - // original type. // FIXME: We simply pass through the replacement argument if there wasn't a // converter, which isn't great as it allows implicit type conversions to // appear. We should properly restructure this code to handle cases where a // converter isn't provided and also to properly handle the case where an // argument materialization is actually a temporary source materialization // (e.g. in the case of 1->N). - if (replArgs.size() == 1 && - (!converter || replArgs[0].getType() == origArg.getType())) { - newArg = replArgs.front(); - mapping.map(origArg, newArg); - } else { - // Build argument materialization: new block arguments -> old block - // argument type. - Value argMat = buildUnresolvedArgumentMaterialization( - newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter); - mapping.map(origArg, argMat); - - // Build target materialization: old block argument type -> legal type. - // Note: This function returns an "empty" type if no valid conversion to - // a legal type exists. In that case, we continue the conversion with the - // original block argument type. - Type legalOutputType = converter->convertType(origArg.getType()); - if (legalOutputType && legalOutputType != origArg.getType()) { - newArg = buildUnresolvedTargetMaterialization( - origArg.getLoc(), argMat, legalOutputType, converter); - mapping.map(argMat, newArg); - } else { - newArg = argMat; - } + Type legalOutputType; + if (converter) + legalOutputType = converter->convertType(origArgType); + if (legalOutputType && legalOutputType != origArgType) { + Value targetMat = buildUnresolvedTargetMaterialization( + origArg.getLoc(), argMat, legalOutputType, converter); + mapping.map(argMat, targetMat); } - appendRewrite(block, origArg); - argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); } - appendRewrite(newBlock, block, argInfo, - converter); + appendRewrite(newBlock, block, converter); // Erase the old block. (It is just unlinked for now and will be erased during // cleanup.) @@ -1437,13 +1377,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( appendRewrite(convertOp, converter, kind); return convertOp.getResult(0); } -Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization( - Block *block, Location loc, ValueRange inputs, Type outputType, - const TypeConverter *converter) { - return buildUnresolvedMaterialization(MaterializationKind::Argument, block, - block->begin(), loc, inputs, outputType, - converter); -} Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization( Location loc, Value input, Type outputType, const TypeConverter *converter) { @@ -2862,6 +2795,10 @@ static LogicalResult legalizeUnresolvedMaterialization( newMaterialization = converter->materializeTargetConversion( rewriter, op->getLoc(), outputType, inputOperands); break; + case MaterializationKind::Source: + newMaterialization = converter->materializeSourceConversion( + rewriter, op->getLoc(), outputType, inputOperands); + break; } if (newMaterialization) { assert(newMaterialization.getType() == outputType && @@ -2874,8 +2811,8 @@ static LogicalResult legalizeUnresolvedMaterialization( InFlightDiagnostic diag = op->emitError() << "failed to legalize unresolved materialization " - "from " - << inputOperands.getTypes() << " to " << outputType + "from (" + << inputOperands.getTypes() << ") to " << outputType << " that remained live after conversion"; if (Operation *liveUser = findLiveUser(op->getUsers())) { diag.attachNote(liveUser->getLoc()) diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir index b35cda8e724f6..8254be68912c8 100644 --- a/mlir/test/Transforms/test-legalize-type-conversion.mlir +++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir @@ -2,9 +2,8 @@ func.func @test_invalid_arg_materialization( - // expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}} + // expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}} %arg0: i16) { - // expected-note@below {{see existing live user here}} "foo.return"(%arg0) : (i16) -> () } @@ -104,9 +103,8 @@ func.func @test_block_argument_not_converted() { // Make sure argument type changes aren't implicitly forwarded. func.func @test_signature_conversion_no_converter() { "test.signature_conversion_no_converter"() ({ - // expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}} + // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}} ^bb0(%arg0: f32): - // expected-note@below {{see existing live user here}} "test.type_consumer"(%arg0) : (f32) -> () "test.return"(%arg0) : (f32) -> () }) : () -> ()