@@ -707,10 +707,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
707
707
UnresolvedMaterializationRewrite (
708
708
ConversionPatternRewriterImpl &rewriterImpl,
709
709
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr ,
710
- MaterializationKind kind = MaterializationKind::Target,
711
- Type origOutputType = nullptr )
710
+ MaterializationKind kind = MaterializationKind::Target)
712
711
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
713
- converterAndKind (converter, kind), origOutputType(origOutputType) {}
712
+ converterAndKind (converter, kind) {}
714
713
715
714
static bool classof (const IRRewrite *rewrite) {
716
715
return rewrite->getKind () == Kind::UnresolvedMaterialization;
@@ -734,17 +733,11 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
734
733
return converterAndKind.getInt ();
735
734
}
736
735
737
- // / Return the original illegal output type of the input values.
738
- Type getOrigOutputType () const { return origOutputType; }
739
-
740
736
private:
741
737
// / The corresponding type converter to use when resolving this
742
738
// / materialization, and the kind of this materialization.
743
739
llvm::PointerIntPair<const TypeConverter *, 1 , MaterializationKind>
744
740
converterAndKind;
745
-
746
- // / The original output type. This is only used for argument conversions.
747
- Type origOutputType;
748
741
};
749
742
} // namespace
750
743
@@ -860,12 +853,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
860
853
Block *insertBlock,
861
854
Block::iterator insertPt, Location loc,
862
855
ValueRange inputs, Type outputType,
863
- Type origOutputType,
864
856
const TypeConverter *converter);
865
857
866
858
Value buildUnresolvedArgumentMaterialization (Block *block, Location loc,
867
859
ValueRange inputs,
868
- Type origOutputType,
869
860
Type outputType,
870
861
const TypeConverter *converter);
871
862
@@ -1388,20 +1379,28 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1388
1379
if (replArgs.size () == 1 &&
1389
1380
(!converter || replArgs[0 ].getType () == origArg.getType ())) {
1390
1381
newArg = replArgs.front ();
1382
+ mapping.map (origArg, newArg);
1391
1383
} else {
1392
- Type origOutputType = origArg.getType ();
1393
-
1394
- // Legalize the argument output type.
1395
- Type outputType = origOutputType;
1396
- if (Type legalOutputType = converter->convertType (outputType))
1397
- outputType = legalOutputType;
1398
-
1399
- newArg = buildUnresolvedArgumentMaterialization (
1400
- newBlock, origArg.getLoc (), replArgs, origOutputType, outputType,
1401
- converter);
1384
+ // Build argument materialization: new block arguments -> old block
1385
+ // argument type.
1386
+ Value argMat = buildUnresolvedArgumentMaterialization (
1387
+ newBlock, origArg.getLoc (), replArgs, origArg.getType (), converter);
1388
+ mapping.map (origArg, argMat);
1389
+
1390
+ // Build target materialization: old block argument type -> legal type.
1391
+ // Note: This function returns an "empty" type if no valid conversion to
1392
+ // a legal type exists. In that case, we continue the conversion with the
1393
+ // original block argument type.
1394
+ Type legalOutputType = converter->convertType (origArg.getType ());
1395
+ if (legalOutputType && legalOutputType != origArg.getType ()) {
1396
+ newArg = buildUnresolvedTargetMaterialization (
1397
+ origArg.getLoc (), argMat, legalOutputType, converter);
1398
+ mapping.map (argMat, newArg);
1399
+ } else {
1400
+ newArg = argMat;
1401
+ }
1402
1402
}
1403
1403
1404
- mapping.map (origArg, newArg);
1405
1404
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
1406
1405
argInfo[i] = ConvertedArgInfo (inputMap->inputNo , inputMap->size , newArg);
1407
1406
}
@@ -1424,7 +1423,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
1424
1423
// / of input operands.
1425
1424
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization (
1426
1425
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
1427
- Location loc, ValueRange inputs, Type outputType, Type origOutputType,
1426
+ Location loc, ValueRange inputs, Type outputType,
1428
1427
const TypeConverter *converter) {
1429
1428
// Avoid materializing an unnecessary cast.
1430
1429
if (inputs.size () == 1 && inputs.front ().getType () == outputType)
@@ -1435,16 +1434,15 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
1435
1434
OpBuilder builder (insertBlock, insertPt);
1436
1435
auto convertOp =
1437
1436
builder.create <UnrealizedConversionCastOp>(loc, outputType, inputs);
1438
- appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1439
- origOutputType);
1437
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
1440
1438
return convertOp.getResult (0 );
1441
1439
}
1442
1440
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization (
1443
- Block *block, Location loc, ValueRange inputs, Type origOutputType ,
1444
- Type outputType, const TypeConverter *converter) {
1441
+ Block *block, Location loc, ValueRange inputs, Type outputType ,
1442
+ const TypeConverter *converter) {
1445
1443
return buildUnresolvedMaterialization (MaterializationKind::Argument, block,
1446
1444
block->begin (), loc, inputs, outputType,
1447
- origOutputType, converter);
1445
+ converter);
1448
1446
}
1449
1447
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization (
1450
1448
Location loc, Value input, Type outputType,
@@ -1456,7 +1454,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
1456
1454
1457
1455
return buildUnresolvedMaterialization (MaterializationKind::Target,
1458
1456
insertBlock, insertPt, loc, input,
1459
- outputType, outputType, converter);
1457
+ outputType, converter);
1460
1458
}
1461
1459
1462
1460
// ===----------------------------------------------------------------------===//
@@ -2672,19 +2670,28 @@ static void computeNecessaryMaterializations(
2672
2670
ConversionPatternRewriterImpl &rewriterImpl,
2673
2671
DenseMap<Value, SmallVector<Value>> &inverseMapping,
2674
2672
SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
2673
+ // Helper function to check if the given value or a not yet materialized
2674
+ // replacement of the given value is live.
2675
+ // Note: `inverseMapping` maps from replaced values to original values.
2675
2676
auto isLive = [&](Value value) {
2676
2677
auto findFn = [&](Operation *user) {
2677
2678
auto matIt = materializationOps.find (user);
2678
2679
if (matIt != materializationOps.end ())
2679
2680
return !necessaryMaterializations.count (matIt->second );
2680
2681
return rewriterImpl.isOpIgnored (user);
2681
2682
};
2682
- // This value may be replacing another value that has a live user.
2683
- for (Value inv : inverseMapping.lookup (value))
2684
- if (llvm::find_if_not (inv.getUsers (), findFn) != inv.user_end ())
2683
+ // A worklist is needed because a value may have gone through a chain of
2684
+ // replacements and each of the replaced values may have live users.
2685
+ SmallVector<Value> worklist;
2686
+ worklist.push_back (value);
2687
+ while (!worklist.empty ()) {
2688
+ Value next = worklist.pop_back_val ();
2689
+ if (llvm::find_if_not (next.getUsers (), findFn) != next.user_end ())
2685
2690
return true ;
2686
- // Or have live users itself.
2687
- return llvm::find_if_not (value.getUsers (), findFn) != value.user_end ();
2691
+ // This value may be replacing another value that has a live user.
2692
+ llvm::append_range (worklist, inverseMapping.lookup (next));
2693
+ }
2694
+ return false ;
2688
2695
};
2689
2696
2690
2697
llvm::unique_function<Value (Value, Value, Type)> lookupRemappedValue =
@@ -2844,18 +2851,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
2844
2851
switch (mat.getMaterializationKind ()) {
2845
2852
case MaterializationKind::Argument:
2846
2853
// Try to materialize an argument conversion.
2847
- // FIXME: The current argument materialization hook expects the original
2848
- // output type, even though it doesn't use that as the actual output type
2849
- // of the generated IR. The output type is just used as an indicator of
2850
- // the type of materialization to do. This behavior is really awkward in
2851
- // that it diverges from the behavior of the other hooks, and can be
2852
- // easily misunderstood. We should clean up the argument hooks to better
2853
- // represent the desired invariants we actually care about.
2854
2854
newMaterialization = converter->materializeArgumentConversion (
2855
- rewriter, op->getLoc (), mat. getOrigOutputType () , inputOperands);
2855
+ rewriter, op->getLoc (), outputType , inputOperands);
2856
2856
if (newMaterialization)
2857
2857
break ;
2858
-
2859
2858
// If an argument materialization failed, fallback to trying a target
2860
2859
// materialization.
2861
2860
[[fallthrough]];
@@ -2865,6 +2864,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
2865
2864
break ;
2866
2865
}
2867
2866
if (newMaterialization) {
2867
+ assert (newMaterialization.getType () == outputType &&
2868
+ " materialization callback produced value of incorrect type" );
2868
2869
replaceMaterialization (rewriterImpl, opResult, newMaterialization,
2869
2870
inverseMapping);
2870
2871
return success ();
0 commit comments