@@ -432,34 +432,14 @@ class MoveBlockRewrite : public BlockRewrite {
432
432
Block *insertBeforeBlock;
433
433
};
434
434
435
- // / This structure contains the information pertaining to an argument that has
436
- // / been converted.
437
- struct ConvertedArgInfo {
438
- ConvertedArgInfo (unsigned newArgIdx, unsigned newArgSize,
439
- Value castValue = nullptr )
440
- : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
441
-
442
- // / The start index of in the new argument list that contains arguments that
443
- // / replace the original.
444
- unsigned newArgIdx;
445
-
446
- // / The number of arguments that replaced the original argument.
447
- unsigned newArgSize;
448
-
449
- // / The cast value that was created to cast from the new arguments to the
450
- // / old. This only used if 'newArgSize' > 1.
451
- Value castValue;
452
- };
453
-
454
435
// / Block type conversion. This rewrite is partially reflected in the IR.
455
436
class BlockTypeConversionRewrite : public BlockRewrite {
456
437
public:
457
- BlockTypeConversionRewrite (
458
- ConversionPatternRewriterImpl &rewriterImpl, Block *block,
459
- Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1 > argInfo,
460
- const TypeConverter *converter)
438
+ BlockTypeConversionRewrite (ConversionPatternRewriterImpl &rewriterImpl,
439
+ Block *block, Block *origBlock,
440
+ const TypeConverter *converter)
461
441
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
462
- origBlock (origBlock), argInfo(argInfo), converter(converter) {}
442
+ origBlock (origBlock), converter(converter) {}
463
443
464
444
static bool classof (const IRRewrite *rewrite) {
465
445
return rewrite->getKind () == Kind::BlockTypeConversion;
@@ -479,10 +459,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
479
459
// / The original block that was requested to have its signature converted.
480
460
Block *origBlock;
481
461
482
- // / The conversion information for each of the arguments. The information is
483
- // / std::nullopt if the argument was dropped during conversion.
484
- SmallVector<std::optional<ConvertedArgInfo>, 1 > argInfo;
485
-
486
462
// / The type converter used to convert the arguments.
487
463
const TypeConverter *converter;
488
464
};
@@ -691,12 +667,16 @@ class CreateOperationRewrite : public OperationRewrite {
691
667
// / The type of materialization.
692
668
enum MaterializationKind {
693
669
// / This materialization materializes a conversion for an illegal block
694
- // / argument type, to a legal one.
670
+ // / argument type, to the original one.
695
671
Argument,
696
672
697
673
// / This materialization materializes a conversion from an illegal type to a
698
674
// / legal one.
699
- Target
675
+ Target,
676
+
677
+ // / This materialization materializes a conversion from a legal type back to
678
+ // / an illegal one.
679
+ Source
700
680
};
701
681
702
682
// / An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
@@ -736,7 +716,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
736
716
private:
737
717
// / The corresponding type converter to use when resolving this
738
718
// / materialization, and the kind of this materialization.
739
- llvm::PointerIntPair<const TypeConverter *, 1 , MaterializationKind>
719
+ llvm::PointerIntPair<const TypeConverter *, 2 , MaterializationKind>
740
720
converterAndKind;
741
721
};
742
722
} // namespace
@@ -855,11 +835,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
855
835
ValueRange inputs, Type outputType,
856
836
const TypeConverter *converter);
857
837
858
- Value buildUnresolvedArgumentMaterialization (Block *block, Location loc,
859
- ValueRange inputs,
860
- Type outputType,
861
- const TypeConverter *converter);
862
-
863
838
Value buildUnresolvedTargetMaterialization (Location loc, Value input,
864
839
Type outputType,
865
840
const TypeConverter *converter);
@@ -989,28 +964,6 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
989
964
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener ()))
990
965
for (Operation *op : block->getUsers ())
991
966
listener->notifyOperationModified (op);
992
-
993
- // Process the remapping for each of the original arguments.
994
- for (auto [origArg, info] :
995
- llvm::zip_equal (origBlock->getArguments (), argInfo)) {
996
- // Handle the case of a 1->0 value mapping.
997
- if (!info) {
998
- if (Value newArg =
999
- rewriterImpl.mapping .lookupOrNull (origArg, origArg.getType ()))
1000
- rewriter.replaceAllUsesWith (origArg, newArg);
1001
- continue ;
1002
- }
1003
-
1004
- // Otherwise this is a 1->1+ value mapping.
1005
- Value castValue = info->castValue ;
1006
- assert (info->newArgSize >= 1 && castValue && " expected 1->1+ mapping" );
1007
-
1008
- // If the argument is still used, replace it with the generated cast.
1009
- if (!origArg.use_empty ()) {
1010
- rewriter.replaceAllUsesWith (origArg, rewriterImpl.mapping .lookupOrDefault (
1011
- castValue, origArg.getType ()));
1012
- }
1013
- }
1014
967
}
1015
968
1016
969
void BlockTypeConversionRewrite::rollback () {
@@ -1035,14 +988,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
1035
988
continue ;
1036
989
1037
990
Value replacementValue = rewriterImpl.mapping .lookupOrDefault (origArg);
1038
- bool isDroppedArg = replacementValue == origArg;
1039
- if (!isDroppedArg)
1040
- builder.setInsertionPointAfterValue (replacementValue);
991
+ assert (replacementValue && " replacement value not found" );
1041
992
Value newArg;
1042
993
if (converter) {
994
+ builder.setInsertionPointAfterValue (replacementValue);
1043
995
newArg = converter->materializeSourceConversion (
1044
- builder, origArg.getLoc (), origArg.getType (),
1045
- isDroppedArg ? ValueRange () : ValueRange (replacementValue));
996
+ builder, origArg.getLoc (), origArg.getType (), replacementValue);
1046
997
assert ((!newArg || newArg.getType () == origArg.getType ()) &&
1047
998
" materialization hook did not provide a value of the expected "
1048
999
" type" );
@@ -1053,8 +1004,6 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
1053
1004
<< " failed to materialize conversion for block argument #"
1054
1005
<< it.index () << " that remained live after conversion, type was "
1055
1006
<< origArg.getType ();
1056
- if (!isDroppedArg)
1057
- diag << " , with target type " << replacementValue.getType ();
1058
1007
diag.attachNote (liveUser->getLoc ())
1059
1008
<< " see existing live user here: " << *liveUser;
1060
1009
return failure ();
@@ -1340,72 +1289,71 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1340
1289
// Replace all uses of the old block with the new block.
1341
1290
block->replaceAllUsesWith (newBlock);
1342
1291
1343
- // Remap each of the original arguments as determined by the signature
1344
- // conversion.
1345
- SmallVector<std::optional<ConvertedArgInfo>, 1 > argInfo;
1346
- argInfo.resize (origArgCount);
1347
-
1348
1292
for (unsigned i = 0 ; i != origArgCount; ++i) {
1349
- auto inputMap = signatureConversion.getInputMapping (i);
1350
- if (!inputMap)
1351
- continue ;
1352
1293
BlockArgument origArg = block->getArgument (i);
1294
+ Type origArgType = origArg.getType ();
1295
+
1296
+ // Helper function that tries to legalize the given type. Returns the given
1297
+ // type if it could not be legalized.
1298
+ // FIXME: We simply pass through the replacement argument if there wasn't a
1299
+ // converter, which isn't great as it allows implicit type conversions to
1300
+ // appear. We should properly restructure this code to handle cases where a
1301
+ // converter isn't provided and also to properly handle the case where an
1302
+ // argument materialization is actually a temporary source materialization
1303
+ // (e.g. in the case of 1->N).
1304
+ auto tryLegalizeType = [&](Type type) {
1305
+ if (converter)
1306
+ if (Type t = converter->convertType (type))
1307
+ return t;
1308
+ return type;
1309
+ };
1310
+
1311
+ std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
1312
+ signatureConversion.getInputMapping (i);
1313
+ if (!inputMap) {
1314
+ // This block argument was dropped and no replacement value was provided.
1315
+ // Materialize a replacement value "out of thin air".
1316
+ Value repl = buildUnresolvedMaterialization (
1317
+ MaterializationKind::Source, newBlock, newBlock->begin (),
1318
+ origArg.getLoc (), /* inputs=*/ ValueRange (),
1319
+ /* outputType=*/ origArgType, converter);
1320
+ mapping.map (origArg, repl);
1321
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1322
+ continue ;
1323
+ }
1353
1324
1354
- // If inputMap->replacementValue is not nullptr, then the argument is
1355
- // dropped and a replacement value is provided to be the remappedValue.
1356
- if (inputMap->replacementValue ) {
1325
+ if (Value repl = inputMap->replacementValue ) {
1326
+ // This block argument was dropped and a replacement value was provided.
1357
1327
assert (inputMap->size == 0 &&
1358
1328
" invalid to provide a replacement value when the argument isn't "
1359
1329
" dropped" );
1360
- mapping.map (origArg, inputMap-> replacementValue );
1330
+ mapping.map (origArg, repl );
1361
1331
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1362
1332
continue ;
1363
1333
}
1364
1334
1365
- // Otherwise, this is a 1->1+ mapping.
1335
+ // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
1336
+ // dialect conversion. Therefore, we need an argument materialization to
1337
+ // turn the replacement block arguments into a single SSA value that can be
1338
+ // used as a replacement.
1366
1339
auto replArgs =
1367
1340
newBlock->getArguments ().slice (inputMap->inputNo , inputMap->size );
1368
- Value newArg;
1341
+ Value argMat = buildUnresolvedMaterialization (
1342
+ MaterializationKind::Argument, newBlock, newBlock->begin (),
1343
+ origArg.getLoc (), /* inputs=*/ replArgs, origArgType, converter);
1344
+ mapping.map (origArg, argMat);
1345
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1369
1346
1370
- // If this is a 1->1 mapping and the types of new and replacement arguments
1371
- // match (i.e. it's an identity map), then the argument is mapped to its
1372
- // original type.
1373
- // FIXME: We simply pass through the replacement argument if there wasn't a
1374
- // converter, which isn't great as it allows implicit type conversions to
1375
- // appear. We should properly restructure this code to handle cases where a
1376
- // converter isn't provided and also to properly handle the case where an
1377
- // argument materialization is actually a temporary source materialization
1378
- // (e.g. in the case of 1->N).
1379
- if (replArgs.size () == 1 &&
1380
- (!converter || replArgs[0 ].getType () == origArg.getType ())) {
1381
- newArg = replArgs.front ();
1382
- mapping.map (origArg, newArg);
1383
- } else {
1384
- // Build argument materialization: new block arguments -> old block
1385
- // argument type.
1386
- Value argMat = buildUnresolvedArgumentMaterialization (
1387
- newBlock, origArg.getLoc (), replArgs, origArg.getType (), converter);
1388
- mapping.map (origArg, argMat);
1389
-
1390
- // Build target materialization: old block argument type -> legal type.
1391
- // Note: This function returns an "empty" type if no valid conversion to
1392
- // a legal type exists. In that case, we continue the conversion with the
1393
- // original block argument type.
1394
- if (Type legalOutputType = converter->convertType (origArg.getType ())) {
1395
- newArg = buildUnresolvedTargetMaterialization (
1396
- origArg.getLoc (), argMat, legalOutputType, converter);
1397
- mapping.map (argMat, newArg);
1398
- } else {
1399
- newArg = argMat;
1400
- }
1347
+ Type legalOutputType = tryLegalizeType (origArgType);
1348
+ if (legalOutputType != origArgType) {
1349
+ Value targetMat = buildUnresolvedTargetMaterialization (
1350
+ origArg.getLoc (), argMat, legalOutputType, converter);
1351
+ mapping.map (argMat, targetMat);
1401
1352
}
1402
-
1403
1353
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1404
- argInfo[i] = ConvertedArgInfo (inputMap->inputNo , inputMap->size , newArg);
1405
1354
}
1406
1355
1407
- appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
1408
- converter);
1356
+ appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
1409
1357
1410
1358
// Erase the old block. (It is just unlinked for now and will be erased during
1411
1359
// cleanup.)
@@ -1436,13 +1384,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1436
1384
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
1437
1385
return convertOp.getResult (0 );
1438
1386
}
1439
- Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization (
1440
- Block *block, Location loc, ValueRange inputs, Type outputType,
1441
- const TypeConverter *converter) {
1442
- return buildUnresolvedMaterialization (MaterializationKind::Argument, block,
1443
- block->begin (), loc, inputs, outputType,
1444
- converter);
1445
- }
1446
1387
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization (
1447
1388
Location loc, Value input, Type outputType,
1448
1389
const TypeConverter *converter) {
@@ -2861,6 +2802,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
2861
2802
newMaterialization = converter->materializeTargetConversion (
2862
2803
rewriter, op->getLoc (), outputType, inputOperands);
2863
2804
break ;
2805
+ case MaterializationKind::Source:
2806
+ newMaterialization = converter->materializeSourceConversion (
2807
+ rewriter, op->getLoc (), outputType, inputOperands);
2808
+ break ;
2864
2809
}
2865
2810
if (newMaterialization) {
2866
2811
assert (newMaterialization.getType () == outputType &&
@@ -2873,8 +2818,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
2873
2818
2874
2819
InFlightDiagnostic diag = op->emitError ()
2875
2820
<< " failed to legalize unresolved materialization "
2876
- " from "
2877
- << inputOperands.getTypes () << " to " << outputType
2821
+ " from ( "
2822
+ << inputOperands.getTypes () << " ) to " << outputType
2878
2823
<< " that remained live after conversion" ;
2879
2824
if (Operation *liveUser = findLiveUser (op->getUsers ())) {
2880
2825
diag.attachNote (liveUser->getLoc ())
0 commit comments