Skip to content

Commit c254f78

Browse files
[mlir][Transforms] Dialect conversion: Simplify handling of dropped arguments
This commit simplifies the handling of dropped arguments and updates some dialect conversion documentation that is outdated. When converting a block signature, a BlockTypeConversionRewrite object and potentially multiple ReplaceBlockArgRewrite are created. During the "commit" phase, uses of the old block arguments are replaced with the new block arguments, but the old implementation was written in an inconsistent way: some block arguments were replaced in BlockTypeConversionRewrite::commit and some were replaced in ReplaceBlockArgRewrite::commit. The new BlockTypeConversionRewrite::commit implementation is much simpler and no longer modifies any IR; that is done only in ReplaceBlockArgRewrite now. The ConvertedArgInfo data structure is no longer needed. To that end, materializations of dropped arguments are now built in applySignatureConversion instead of materializeLiveConversions; the latter function no longer has to deal with dropped arguments. Other minor improvements: Improve variable name: origOutputType -> origArgType. Add an assertion to check that this field is only used for argument materializations. Add more comments to applySignatureConversion. Note: Error messages around failed materializations for dropped basic block arguments changed slightly. That is because those materializations are now built in legalizeUnresolvedMaterialization instead of legalizeConvertedArgumentTypes. This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion. This is a re-upload of #96207.
1 parent d8a0ebe commit c254f78

File tree

2 files changed

+70
-127
lines changed

2 files changed

+70
-127
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 68 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -432,34 +432,14 @@ class MoveBlockRewrite : public BlockRewrite {
432432
Block *insertBeforeBlock;
433433
};
434434

435-
/// This structure contains the information pertaining to an argument that has
436-
/// been converted.
437-
struct ConvertedArgInfo {
438-
ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
439-
Value castValue = nullptr)
440-
: newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
441-
442-
/// The start index of in the new argument list that contains arguments that
443-
/// replace the original.
444-
unsigned newArgIdx;
445-
446-
/// The number of arguments that replaced the original argument.
447-
unsigned newArgSize;
448-
449-
/// The cast value that was created to cast from the new arguments to the
450-
/// old. This only used if 'newArgSize' > 1.
451-
Value castValue;
452-
};
453-
454435
/// Block type conversion. This rewrite is partially reflected in the IR.
455436
class BlockTypeConversionRewrite : public BlockRewrite {
456437
public:
457-
BlockTypeConversionRewrite(
458-
ConversionPatternRewriterImpl &rewriterImpl, Block *block,
459-
Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
460-
const TypeConverter *converter)
438+
BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
439+
Block *block, Block *origBlock,
440+
const TypeConverter *converter)
461441
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
462-
origBlock(origBlock), argInfo(argInfo), converter(converter) {}
442+
origBlock(origBlock), converter(converter) {}
463443

464444
static bool classof(const IRRewrite *rewrite) {
465445
return rewrite->getKind() == Kind::BlockTypeConversion;
@@ -479,10 +459,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
479459
/// The original block that was requested to have its signature converted.
480460
Block *origBlock;
481461

482-
/// The conversion information for each of the arguments. The information is
483-
/// std::nullopt if the argument was dropped during conversion.
484-
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
485-
486462
/// The type converter used to convert the arguments.
487463
const TypeConverter *converter;
488464
};
@@ -691,12 +667,16 @@ class CreateOperationRewrite : public OperationRewrite {
691667
/// The type of materialization.
692668
enum MaterializationKind {
693669
/// This materialization materializes a conversion for an illegal block
694-
/// argument type, to a legal one.
670+
/// argument type, to the original one.
695671
Argument,
696672

697673
/// This materialization materializes a conversion from an illegal type to a
698674
/// legal one.
699-
Target
675+
Target,
676+
677+
/// This materialization materializes a conversion from a legal type back to
678+
/// an illegal one.
679+
Source
700680
};
701681

702682
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
@@ -736,7 +716,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
736716
private:
737717
/// The corresponding type converter to use when resolving this
738718
/// materialization, and the kind of this materialization.
739-
llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
719+
llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
740720
converterAndKind;
741721
};
742722
} // namespace
@@ -855,11 +835,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
855835
ValueRange inputs, Type outputType,
856836
const TypeConverter *converter);
857837

858-
Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
859-
ValueRange inputs,
860-
Type outputType,
861-
const TypeConverter *converter);
862-
863838
Value buildUnresolvedTargetMaterialization(Location loc, Value input,
864839
Type outputType,
865840
const TypeConverter *converter);
@@ -989,28 +964,6 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
989964
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
990965
for (Operation *op : block->getUsers())
991966
listener->notifyOperationModified(op);
992-
993-
// Process the remapping for each of the original arguments.
994-
for (auto [origArg, info] :
995-
llvm::zip_equal(origBlock->getArguments(), argInfo)) {
996-
// Handle the case of a 1->0 value mapping.
997-
if (!info) {
998-
if (Value newArg =
999-
rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
1000-
rewriter.replaceAllUsesWith(origArg, newArg);
1001-
continue;
1002-
}
1003-
1004-
// Otherwise this is a 1->1+ value mapping.
1005-
Value castValue = info->castValue;
1006-
assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
1007-
1008-
// If the argument is still used, replace it with the generated cast.
1009-
if (!origArg.use_empty()) {
1010-
rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
1011-
castValue, origArg.getType()));
1012-
}
1013-
}
1014967
}
1015968

1016969
void BlockTypeConversionRewrite::rollback() {
@@ -1035,14 +988,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
1035988
continue;
1036989

1037990
Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
1038-
bool isDroppedArg = replacementValue == origArg;
1039-
if (!isDroppedArg)
1040-
builder.setInsertionPointAfterValue(replacementValue);
991+
assert(replacementValue && "replacement value not found");
1041992
Value newArg;
1042993
if (converter) {
994+
builder.setInsertionPointAfterValue(replacementValue);
1043995
newArg = converter->materializeSourceConversion(
1044-
builder, origArg.getLoc(), origArg.getType(),
1045-
isDroppedArg ? ValueRange() : ValueRange(replacementValue));
996+
builder, origArg.getLoc(), origArg.getType(), replacementValue);
1046997
assert((!newArg || newArg.getType() == origArg.getType()) &&
1047998
"materialization hook did not provide a value of the expected "
1048999
"type");
@@ -1053,8 +1004,6 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
10531004
<< "failed to materialize conversion for block argument #"
10541005
<< it.index() << " that remained live after conversion, type was "
10551006
<< origArg.getType();
1056-
if (!isDroppedArg)
1057-
diag << ", with target type " << replacementValue.getType();
10581007
diag.attachNote(liveUser->getLoc())
10591008
<< "see existing live user here: " << *liveUser;
10601009
return failure();
@@ -1340,72 +1289,71 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13401289
// Replace all uses of the old block with the new block.
13411290
block->replaceAllUsesWith(newBlock);
13421291

1343-
// Remap each of the original arguments as determined by the signature
1344-
// conversion.
1345-
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
1346-
argInfo.resize(origArgCount);
1347-
13481292
for (unsigned i = 0; i != origArgCount; ++i) {
1349-
auto inputMap = signatureConversion.getInputMapping(i);
1350-
if (!inputMap)
1351-
continue;
13521293
BlockArgument origArg = block->getArgument(i);
1294+
Type origArgType = origArg.getType();
1295+
1296+
// Helper function that tries to legalize the given type. Returns the given
1297+
// type if it could not be legalized.
1298+
// FIXME: We simply pass through the replacement argument if there wasn't a
1299+
// converter, which isn't great as it allows implicit type conversions to
1300+
// appear. We should properly restructure this code to handle cases where a
1301+
// converter isn't provided and also to properly handle the case where an
1302+
// argument materialization is actually a temporary source materialization
1303+
// (e.g. in the case of 1->N).
1304+
auto tryLegalizeType = [&](Type type) {
1305+
if (converter)
1306+
if (Type t = converter->convertType(type))
1307+
return t;
1308+
return type;
1309+
};
1310+
1311+
std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1312+
signatureConversion.getInputMapping(i);
1313+
if (!inputMap) {
1314+
// This block argument was dropped and no replacement value was provided.
1315+
// Materialize a replacement value "out of thin air".
1316+
Value repl = buildUnresolvedMaterialization(
1317+
MaterializationKind::Source, newBlock, newBlock->begin(),
1318+
origArg.getLoc(), /*inputs=*/ValueRange(),
1319+
/*outputType=*/origArgType, converter);
1320+
mapping.map(origArg, repl);
1321+
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1322+
continue;
1323+
}
13531324

1354-
// If inputMap->replacementValue is not nullptr, then the argument is
1355-
// dropped and a replacement value is provided to be the remappedValue.
1356-
if (inputMap->replacementValue) {
1325+
if (Value repl = inputMap->replacementValue) {
1326+
// This block argument was dropped and a replacement value was provided.
13571327
assert(inputMap->size == 0 &&
13581328
"invalid to provide a replacement value when the argument isn't "
13591329
"dropped");
1360-
mapping.map(origArg, inputMap->replacementValue);
1330+
mapping.map(origArg, repl);
13611331
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
13621332
continue;
13631333
}
13641334

1365-
// Otherwise, this is a 1->1+ mapping.
1335+
// This is a 1->1+ mapping. 1->N mappings are not fully supported in the
1336+
// dialect conversion. Therefore, we need an argument materialization to
1337+
// turn the replacement block arguments into a single SSA value that can be
1338+
// used as a replacement.
13661339
auto replArgs =
13671340
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1368-
Value newArg;
1341+
Value argMat = buildUnresolvedMaterialization(
1342+
MaterializationKind::Argument, newBlock, newBlock->begin(),
1343+
origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
1344+
mapping.map(origArg, argMat);
1345+
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
13691346

1370-
// If this is a 1->1 mapping and the types of new and replacement arguments
1371-
// match (i.e. it's an identity map), then the argument is mapped to its
1372-
// original type.
1373-
// FIXME: We simply pass through the replacement argument if there wasn't a
1374-
// converter, which isn't great as it allows implicit type conversions to
1375-
// appear. We should properly restructure this code to handle cases where a
1376-
// converter isn't provided and also to properly handle the case where an
1377-
// argument materialization is actually a temporary source materialization
1378-
// (e.g. in the case of 1->N).
1379-
if (replArgs.size() == 1 &&
1380-
(!converter || replArgs[0].getType() == origArg.getType())) {
1381-
newArg = replArgs.front();
1382-
mapping.map(origArg, newArg);
1383-
} else {
1384-
// Build argument materialization: new block arguments -> old block
1385-
// argument type.
1386-
Value argMat = buildUnresolvedArgumentMaterialization(
1387-
newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
1388-
mapping.map(origArg, argMat);
1389-
1390-
// Build target materialization: old block argument type -> legal type.
1391-
// Note: This function returns an "empty" type if no valid conversion to
1392-
// a legal type exists. In that case, we continue the conversion with the
1393-
// original block argument type.
1394-
if (Type legalOutputType = converter->convertType(origArg.getType())) {
1395-
newArg = buildUnresolvedTargetMaterialization(
1396-
origArg.getLoc(), argMat, legalOutputType, converter);
1397-
mapping.map(argMat, newArg);
1398-
} else {
1399-
newArg = argMat;
1400-
}
1347+
Type legalOutputType = tryLegalizeType(origArgType);
1348+
if (legalOutputType != origArgType) {
1349+
Value targetMat = buildUnresolvedTargetMaterialization(
1350+
origArg.getLoc(), argMat, legalOutputType, converter);
1351+
mapping.map(argMat, targetMat);
14011352
}
1402-
14031353
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1404-
argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
14051354
}
14061355

1407-
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
1408-
converter);
1356+
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
14091357

14101358
// Erase the old block. (It is just unlinked for now and will be erased during
14111359
// cleanup.)
@@ -1436,13 +1384,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
14361384
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
14371385
return convertOp.getResult(0);
14381386
}
1439-
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
1440-
Block *block, Location loc, ValueRange inputs, Type outputType,
1441-
const TypeConverter *converter) {
1442-
return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
1443-
block->begin(), loc, inputs, outputType,
1444-
converter);
1445-
}
14461387
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
14471388
Location loc, Value input, Type outputType,
14481389
const TypeConverter *converter) {
@@ -2861,6 +2802,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
28612802
newMaterialization = converter->materializeTargetConversion(
28622803
rewriter, op->getLoc(), outputType, inputOperands);
28632804
break;
2805+
case MaterializationKind::Source:
2806+
newMaterialization = converter->materializeSourceConversion(
2807+
rewriter, op->getLoc(), outputType, inputOperands);
2808+
break;
28642809
}
28652810
if (newMaterialization) {
28662811
assert(newMaterialization.getType() == outputType &&
@@ -2873,8 +2818,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
28732818

28742819
InFlightDiagnostic diag = op->emitError()
28752820
<< "failed to legalize unresolved materialization "
2876-
"from "
2877-
<< inputOperands.getTypes() << " to " << outputType
2821+
"from ("
2822+
<< inputOperands.getTypes() << ") to " << outputType
28782823
<< " that remained live after conversion";
28792824
if (Operation *liveUser = findLiveUser(op->getUsers())) {
28802825
diag.attachNote(liveUser->getLoc())

mlir/test/Transforms/test-legalize-type-conversion.mlir

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22

33

44
func.func @test_invalid_arg_materialization(
5-
// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
5+
// expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
66
%arg0: i16) {
7-
// expected-note@below {{see existing live user here}}
87
"foo.return"(%arg0) : (i16) -> ()
98
}
109

@@ -104,9 +103,8 @@ func.func @test_block_argument_not_converted() {
104103
// Make sure argument type changes aren't implicitly forwarded.
105104
func.func @test_signature_conversion_no_converter() {
106105
"test.signature_conversion_no_converter"() ({
107-
// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
106+
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
108107
^bb0(%arg0: f32):
109-
// expected-note@below {{see existing live user here}}
110108
"test.type_consumer"(%arg0) : (f32) -> ()
111109
"test.return"(%arg0) : (f32) -> ()
112110
}) : () -> ()

0 commit comments

Comments
 (0)