@@ -3002,13 +3002,6 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
3002
3002
// instruction cost.
3003
3003
return 0 ;
3004
3004
case Instruction::Call: {
3005
- if (!isSingleScalar ()) {
3006
- // TODO: Handle remaining call costs here as well.
3007
- if (VF.isScalable ())
3008
- return InstructionCost::getInvalid ();
3009
- break ;
3010
- }
3011
-
3012
3005
auto *CalledFn =
3013
3006
cast<Function>(getOperand (getNumOperands () - 1 )->getLiveInIRValue ());
3014
3007
if (CalledFn->isIntrinsic ())
@@ -3017,8 +3010,43 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
3017
3010
SmallVector<Type *, 4 > Tys;
3018
3011
for (VPValue *ArgOp : drop_end (operands ()))
3019
3012
Tys.push_back (Ctx.Types .inferScalarType (ArgOp));
3013
+
3020
3014
Type *ResultTy = Ctx.Types .inferScalarType (this );
3021
- return Ctx.TTI .getCallInstrCost (CalledFn, ResultTy, Tys, Ctx.CostKind );
3015
+ InstructionCost ScalarCallCost =
3016
+ Ctx.TTI .getCallInstrCost (CalledFn, ResultTy, Tys, Ctx.CostKind );
3017
+ if (isSingleScalar ())
3018
+ return ScalarCallCost;
3019
+
3020
+ if (VF.isScalable ())
3021
+ return InstructionCost::getInvalid ();
3022
+
3023
+ // Compute the cost of scalarizing the result and operands if needed.
3024
+ InstructionCost ScalarizationCost = 0 ;
3025
+ if (VF.isVector ()) {
3026
+ if (!ResultTy->isVoidTy ()) {
3027
+ for (Type *VectorTy : getContainedTypes (toVectorizedTy (ResultTy, VF))) {
3028
+ ScalarizationCost += Ctx.TTI .getScalarizationOverhead (
3029
+ cast<VectorType>(VectorTy), APInt::getAllOnes (VF.getFixedValue ()),
3030
+ /* Insert=*/ true ,
3031
+ /* Extract=*/ false , Ctx.CostKind );
3032
+ }
3033
+ }
3034
+ // Skip operands that do not require extraction/scalarization and do not
3035
+ // incur any overhead.
3036
+ SmallVector<Type *> Tys;
3037
+ SmallPtrSet<const VPValue *, 4 > UniqueOperands;
3038
+ for (auto *Op : drop_end (operands ())) {
3039
+ if (Op->isLiveIn () || isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op) ||
3040
+ !UniqueOperands.insert (Op).second )
3041
+ continue ;
3042
+ Tys.push_back (toVectorizedTy (Ctx.Types .inferScalarType (Op), VF));
3043
+ }
3044
+ ScalarizationCost +=
3045
+ Ctx.TTI .getOperandsScalarizationOverhead (Tys, Ctx.CostKind );
3046
+ }
3047
+
3048
+ return ScalarCallCost * (isSingleScalar () ? 1 : VF.getFixedValue ()) +
3049
+ ScalarizationCost;
3022
3050
}
3023
3051
case Instruction::Add:
3024
3052
case Instruction::Sub:
0 commit comments