@@ -707,10 +707,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
707
707
UnresolvedMaterializationRewrite (
708
708
ConversionPatternRewriterImpl &rewriterImpl,
709
709
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr ,
710
- MaterializationKind kind = MaterializationKind::Target,
711
- Type origOutputType = nullptr )
710
+ MaterializationKind kind = MaterializationKind::Target)
712
711
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
713
- converterAndKind (converter, kind), origOutputType(origOutputType) {}
712
+ converterAndKind (converter, kind) {}
714
713
715
714
static bool classof (const IRRewrite *rewrite) {
716
715
return rewrite->getKind () == Kind::UnresolvedMaterialization;
@@ -734,17 +733,11 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
734
733
return converterAndKind.getInt ();
735
734
}
736
735
737
- // / Return the original illegal output type of the input values.
738
- Type getOrigOutputType () const { return origOutputType; }
739
-
740
736
private:
741
737
// / The corresponding type converter to use when resolving this
742
738
// / materialization, and the kind of this materialization.
743
739
llvm::PointerIntPair<const TypeConverter *, 1 , MaterializationKind>
744
740
converterAndKind;
745
-
746
- // / The original output type. This is only used for argument conversions.
747
- Type origOutputType;
748
741
};
749
742
} // namespace
750
743
@@ -860,12 +853,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
860
853
Block *insertBlock,
861
854
Block::iterator insertPt, Location loc,
862
855
ValueRange inputs, Type outputType,
863
- Type origOutputType,
864
856
const TypeConverter *converter);
865
857
866
858
Value buildUnresolvedArgumentMaterialization (Block *block, Location loc,
867
859
ValueRange inputs,
868
- Type origOutputType,
869
860
Type outputType,
870
861
const TypeConverter *converter);
871
862
@@ -1388,20 +1379,24 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1388
1379
if (replArgs.size () == 1 &&
1389
1380
(!converter || replArgs[0 ].getType () == origArg.getType ())) {
1390
1381
newArg = replArgs.front ();
1382
+ mapping.map (origArg, newArg);
1391
1383
} else {
1392
- Type origOutputType = origArg.getType ();
1393
-
1394
- // Legalize the argument output type.
1395
- Type outputType = origOutputType;
1396
- if (Type legalOutputType = converter->convertType (outputType))
1397
- outputType = legalOutputType;
1398
-
1399
- newArg = buildUnresolvedArgumentMaterialization (
1400
- newBlock, origArg.getLoc (), replArgs, origOutputType, outputType,
1401
- converter);
1384
+ // Build argument materialization: new block arguments -> old block
1385
+ // argument type.
1386
+ Value argMat = buildUnresolvedArgumentMaterialization (
1387
+ newBlock, origArg.getLoc (), replArgs, origArg.getType (), converter);
1388
+ mapping.map (origArg, argMat);
1389
+
1390
+ // Build target materialization: old block argument type -> legal type.
1391
+ if (Type legalOutputType = converter->convertType (origArg.getType ())) {
1392
+ newArg = buildUnresolvedTargetMaterialization (
1393
+ origArg.getLoc (), argMat, legalOutputType, converter);
1394
+ mapping.map (argMat, newArg);
1395
+ } else {
1396
+ newArg = argMat;
1397
+ }
1402
1398
}
1403
1399
1404
- mapping.map (origArg, newArg);
1405
1400
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1406
1401
argInfo[i] = ConvertedArgInfo (inputMap->inputNo , inputMap->size , newArg);
1407
1402
}
@@ -1424,7 +1419,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1424
1419
// / of input operands.
1425
1420
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization (
1426
1421
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
1427
- Location loc, ValueRange inputs, Type outputType, Type origOutputType,
1422
+ Location loc, ValueRange inputs, Type outputType,
1428
1423
const TypeConverter *converter) {
1429
1424
// Avoid materializing an unnecessary cast.
1430
1425
if (inputs.size () == 1 && inputs.front ().getType () == outputType)
@@ -1435,16 +1430,15 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1435
1430
OpBuilder builder (insertBlock, insertPt);
1436
1431
auto convertOp =
1437
1432
builder.create <UnrealizedConversionCastOp>(loc, outputType, inputs);
1438
- appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1439
- origOutputType);
1433
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
1440
1434
return convertOp.getResult (0 );
1441
1435
}
1442
1436
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization (
1443
- Block *block, Location loc, ValueRange inputs, Type origOutputType ,
1444
- Type outputType, const TypeConverter *converter) {
1437
+ Block *block, Location loc, ValueRange inputs, Type outputType ,
1438
+ const TypeConverter *converter) {
1445
1439
return buildUnresolvedMaterialization (MaterializationKind::Argument, block,
1446
1440
block->begin (), loc, inputs, outputType,
1447
- origOutputType, converter);
1441
+ converter);
1448
1442
}
1449
1443
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization (
1450
1444
Location loc, Value input, Type outputType,
@@ -1456,7 +1450,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
1456
1450
1457
1451
return buildUnresolvedMaterialization (MaterializationKind::Target,
1458
1452
insertBlock, insertPt, loc, input,
1459
- outputType, outputType, converter);
1453
+ outputType, converter);
1460
1454
}
1461
1455
1462
1456
// ===----------------------------------------------------------------------===//
@@ -2672,19 +2666,28 @@ static void computeNecessaryMaterializations(
2672
2666
ConversionPatternRewriterImpl &rewriterImpl,
2673
2667
DenseMap<Value, SmallVector<Value>> &inverseMapping,
2674
2668
SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
2669
+ // Helper function to check if the given value or a not yet materialized
2670
+ // replacement of the given value is live.
2671
+ // Note: `inverseMapping` maps from replaced values to original values.
2675
2672
auto isLive = [&](Value value) {
2676
2673
auto findFn = [&](Operation *user) {
2677
2674
auto matIt = materializationOps.find (user);
2678
2675
if (matIt != materializationOps.end ())
2679
2676
return !necessaryMaterializations.count (matIt->second );
2680
2677
return rewriterImpl.isOpIgnored (user);
2681
2678
};
2682
- // This value may be replacing another value that has a live user.
2683
- for (Value inv : inverseMapping.lookup (value))
2684
- if (llvm::find_if_not (inv.getUsers (), findFn) != inv.user_end ())
2679
+ // A worklist is needed because a value may have gone through a chain of
2680
+ // replacements and each of the replaced values may have live users.
2681
+ SmallVector<Value> worklist;
2682
+ worklist.push_back (value);
2683
+ while (!worklist.empty ()) {
2684
+ Value next = worklist.pop_back_val ();
2685
+ if (llvm::find_if_not (next.getUsers (), findFn) != next.user_end ())
2685
2686
return true ;
2686
- // Or have live users itself.
2687
- return llvm::find_if_not (value.getUsers (), findFn) != value.user_end ();
2687
+ // This value may be replacing another value that has a live user.
2688
+ llvm::append_range (worklist, inverseMapping.lookup (next));
2689
+ }
2690
+ return false ;
2688
2691
};
2689
2692
2690
2693
llvm::unique_function<Value (Value, Value, Type)> lookupRemappedValue =
@@ -2844,18 +2847,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
2844
2847
switch (mat.getMaterializationKind ()) {
2845
2848
case MaterializationKind::Argument:
2846
2849
// Try to materialize an argument conversion.
2847
- // FIXME: The current argument materialization hook expects the original
2848
- // output type, even though it doesn't use that as the actual output type
2849
- // of the generated IR. The output type is just used as an indicator of
2850
- // the type of materialization to do. This behavior is really awkward in
2851
- // that it diverges from the behavior of the other hooks, and can be
2852
- // easily misunderstood. We should clean up the argument hooks to better
2853
- // represent the desired invariants we actually care about.
2854
2850
newMaterialization = converter->materializeArgumentConversion (
2855
- rewriter, op->getLoc (), mat. getOrigOutputType () , inputOperands);
2851
+ rewriter, op->getLoc (), outputType , inputOperands);
2856
2852
if (newMaterialization)
2857
2853
break ;
2858
-
2859
2854
// If an argument materialization failed, fallback to trying a target
2860
2855
// materialization.
2861
2856
[[fallthrough]];
@@ -2865,6 +2860,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
2865
2860
break ;
2866
2861
}
2867
2862
if (newMaterialization) {
2863
+ assert (newMaterialization.getType () == opResult.getType () &&
2864
+ " materialization callback produced value of incorrect type" );
2868
2865
replaceMaterialization (rewriterImpl, opResult, newMaterialization,
2869
2866
inverseMapping);
2870
2867
return success ();
0 commit comments