Skip to content

Commit 93dfdaf

Browse files
committed
[VPlan] Compute cost of replicating calls in VPlan. (NFCI)
Implement computing the scalarization overhead for replicating calls in VPlan, matching the legacy cost model. Depends on llvm#154126.
1 parent 4e6c88b commit 93dfdaf

File tree

1 file changed

+36
-8
lines changed

1 file changed

+36
-8
lines changed

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 36 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -3002,13 +3002,6 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
30023002
// instruction cost.
30033003
return 0;
30043004
case Instruction::Call: {
3005-
if (!isSingleScalar()) {
3006-
// TODO: Handle remaining call costs here as well.
3007-
if (VF.isScalable())
3008-
return InstructionCost::getInvalid();
3009-
break;
3010-
}
3011-
30123005
auto *CalledFn =
30133006
cast<Function>(getOperand(getNumOperands() - 1)->getLiveInIRValue());
30143007
if (CalledFn->isIntrinsic())
@@ -3017,8 +3010,43 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
30173010
SmallVector<Type *, 4> Tys;
30183011
for (VPValue *ArgOp : drop_end(operands()))
30193012
Tys.push_back(Ctx.Types.inferScalarType(ArgOp));
3013+
30203014
Type *ResultTy = Ctx.Types.inferScalarType(this);
3021-
return Ctx.TTI.getCallInstrCost(CalledFn, ResultTy, Tys, Ctx.CostKind);
3015+
InstructionCost ScalarCallCost =
3016+
Ctx.TTI.getCallInstrCost(CalledFn, ResultTy, Tys, Ctx.CostKind);
3017+
if (isSingleScalar())
3018+
return ScalarCallCost;
3019+
3020+
if (VF.isScalable())
3021+
return InstructionCost::getInvalid();
3022+
3023+
// Compute the cost of scalarizing the result and operands if needed.
3024+
InstructionCost ScalarizationCost = 0;
3025+
if (VF.isVector()) {
3026+
if (!ResultTy->isVoidTy()) {
3027+
for (Type *VectorTy : getContainedTypes(toVectorizedTy(ResultTy, VF))) {
3028+
ScalarizationCost += Ctx.TTI.getScalarizationOverhead(
3029+
cast<VectorType>(VectorTy), APInt::getAllOnes(VF.getFixedValue()),
3030+
/*Insert=*/true,
3031+
/*Extract=*/false, Ctx.CostKind);
3032+
}
3033+
}
3034+
// Skip operands that do not require extraction/scalarization and do not
3035+
// incur any overhead.
3036+
SmallVector<Type *> Tys;
3037+
SmallPtrSet<const VPValue *, 4> UniqueOperands;
3038+
for (auto *Op : drop_end(operands())) {
3039+
if (Op->isLiveIn() || isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op) ||
3040+
!UniqueOperands.insert(Op).second)
3041+
continue;
3042+
Tys.push_back(toVectorizedTy(Ctx.Types.inferScalarType(Op), VF));
3043+
}
3044+
ScalarizationCost +=
3045+
Ctx.TTI.getOperandsScalarizationOverhead(Tys, Ctx.CostKind);
3046+
}
3047+
3048+
return ScalarCallCost * (isSingleScalar() ? 1 : VF.getFixedValue()) +
3049+
ScalarizationCost;
30223050
}
30233051
case Instruction::Add:
30243052
case Instruction::Sub:

0 commit comments

Comments
 (0)