Skip to content

Commit a758232

Browse files
committed
[LV] Add support for partial reductions without a binary op
Consider IR such as this: for.body: %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ] %accum = phi i32 [ 0, %entry ], [ %add, %for.body ] %gep.a = getelementptr i8, ptr %a, i64 %iv %load.a = load i8, ptr %gep.a, align 1 %ext.a = zext i8 %load.a to i32 %add = add i32 %ext.a, %accum %iv.next = add i64 %iv, 1 %exitcond.not = icmp eq i64 %iv.next, 1025 br i1 %exitcond.not, label %for.exit, label %for.body Conceptually we can vectorise this using partial reductions too, although the current loop vectoriser implementation requires the accumulation of a multiply. For AArch64 this is easily done with a udot or sdot with an identity operand, i.e. a vector of (i16 1). In order to do this I had to teach getScaledReductions that the accumulated value may come from a unary op, hence there is only one extension to consider. Similarly, I updated the vplan and AArch64 TTI cost model to understand the possible unary op.
1 parent 84f6dd3 commit a758232

File tree

8 files changed

+256
-200
lines changed

8 files changed

+256
-200
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

+15-3
Original file line numberDiff line numberDiff line change
@@ -1299,9 +1299,21 @@ class TargetTransformInfo {
12991299
/// \return The cost of a partial reduction, which is a reduction from a
13001300
/// vector to another vector with fewer elements of larger size. They are
13011301
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
1302-
/// takes an accumulator and a binary operation operand that itself is fed by
1303-
/// two extends. An example of an operation that uses a partial reduction is a
1304-
/// dot product, which reduces two vectors to another of 4 times fewer and 4
1302+
/// takes an accumulator of type \p AccumType and a second vector operand to
1303+
/// be accumulated, whose element count is specified by \p VF. The type of
1304+
/// reduction is specified by \p Opcode. The second operand passed to the
1305+
/// intrinsic could be the result of an extend, such as sext or zext. In
1306+
/// this case \p BinOp is nullopt, \p InputTypeA represents the type being
1307+
/// extended and \p OpAExtend the operation, i.e. sign- or zero-extend.
1308+
/// Also, \p InputTypeB should be nullptr and OpBExtend should be None.
1309+
/// Alternatively, the second operand could be the result of a binary
1310+
/// operation performed on two extends, i.e.
1311+
/// mul(zext i8 %a -> i32, zext i8 %b -> i32).
1312+
/// In this case \p BinOp may specify the opcode of the binary operation,
1313+
/// \p InputTypeA and \p InputTypeB the types being extended, and
1314+
/// \p OpAExtend, \p OpBExtend the form of extensions. An example of an
1315+
/// operation that uses a partial reduction is a dot product, which reduces
1316+
/// two vectors in binary mul operation to another of 4 times fewer and 4
13051317
/// times larger elements.
13061318
InstructionCost
13071319
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

+12-12
Original file line numberDiff line numberDiff line change
@@ -5177,11 +5177,21 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
51775177

51785178
// Sub opcodes currently only occur in chained cases.
51795179
// Independent partial reduction subtractions are still costed as an add
5180-
if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
5180+
if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
5181+
OpAExtend == TTI::PR_None)
51815182
return Invalid;
51825183

5183-
if (InputTypeA != InputTypeB)
5184+
// We only support multiply binary operations for now, and for muls we
5185+
// require the types being extended to be the same.
5186+
// NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
5187+
// only if the i8mm or sve/streaming features are available.
5188+
if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
5189+
OpBExtend == TTI::PR_None ||
5190+
(OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
5191+
!ST->isSVEorStreamingSVEAvailable())))
51845192
return Invalid;
5193+
assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
5194+
"Unexpected values for OpBExtend or InputTypeB");
51855195

51865196
EVT InputEVT = EVT::getEVT(InputTypeA);
51875197
EVT AccumEVT = EVT::getEVT(AccumType);
@@ -5228,16 +5238,6 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
52285238
} else
52295239
return Invalid;
52305240

5231-
// AArch64 supports lowering mixed extensions to a usdot but only if the
5232-
// i8mm or sve/streaming features are available.
5233-
if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
5234-
(OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
5235-
!ST->isSVEorStreamingSVEAvailable()))
5236-
return Invalid;
5237-
5238-
if (!BinOp || *BinOp != Instruction::Mul)
5239-
return Invalid;
5240-
52415241
return Cost;
52425242
}
52435243

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

+57-28
Original file line numberDiff line numberDiff line change
@@ -8765,15 +8765,15 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
87658765
// something that isn't another partial reduction. This is because the
87668766
// extends are intended to be lowered along with the reduction itself.
87678767

8768-
// Build up a set of partial reduction bin ops for efficient use checking.
8769-
SmallSet<User *, 4> PartialReductionBinOps;
8768+
// Build up a set of partial reduction ops for efficient use checking.
8769+
SmallSet<User *, 4> PartialReductionOps;
87708770
for (const auto &[PartialRdx, _] : PartialReductionChains)
8771-
PartialReductionBinOps.insert(PartialRdx.BinOp);
8771+
PartialReductionOps.insert(PartialRdx.ExtendUser);
87728772

87738773
auto ExtendIsOnlyUsedByPartialReductions =
8774-
[&PartialReductionBinOps](Instruction *Extend) {
8774+
[&PartialReductionOps](Instruction *Extend) {
87758775
return all_of(Extend->users(), [&](const User *U) {
8776-
return PartialReductionBinOps.contains(U);
8776+
return PartialReductionOps.contains(U);
87778777
});
87788778
};
87798779

@@ -8782,15 +8782,14 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
87828782
for (auto Pair : PartialReductionChains) {
87838783
PartialReductionChain Chain = Pair.first;
87848784
if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
8785-
ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
8785+
(!Chain.ExtendB || ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
87868786
ScaledReductionMap.insert(std::make_pair(Chain.Reduction, Pair.second));
87878787
}
87888788
}
87898789

87908790
bool VPRecipeBuilder::getScaledReductions(
87918791
Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
87928792
SmallVectorImpl<std::pair<PartialReductionChain, unsigned>> &Chains) {
8793-
87948793
if (!CM.TheLoop->contains(RdxExitInstr))
87958794
return false;
87968795

@@ -8819,40 +8818,70 @@ bool VPRecipeBuilder::getScaledReductions(
88198818
if (PhiOp != PHI)
88208819
return false;
88218820

8822-
auto *BinOp = dyn_cast<BinaryOperator>(Op);
8823-
if (!BinOp || !BinOp->hasOneUse())
8824-
return false;
8825-
88268821
using namespace llvm::PatternMatch;
8827-
// Use the side-effect of match to replace BinOp only if the pattern is
8828-
// matched, we don't care at this point whether it actually matched.
8829-
match(BinOp, m_Neg(m_BinOp(BinOp)));
88308822

8831-
Value *A, *B;
8832-
if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
8833-
!match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
8834-
return false;
8823+
// If the update is a binary operator, check both of its operands to see if
8824+
// they are extends. Otherwise, see if the update comes directly from an
8825+
// extend.
8826+
Instruction *Exts[2] = {nullptr};
8827+
BinaryOperator *ExtendUser = dyn_cast<BinaryOperator>(Op);
8828+
std::optional<unsigned> BinOpc;
8829+
Type *ExtOpTypes[2] = {nullptr};
8830+
8831+
auto CollectExtInfo = [&Exts,
8832+
&ExtOpTypes](SmallVectorImpl<Value *> &Ops) -> bool {
8833+
unsigned I = 0;
8834+
for (Value *OpI : Ops) {
8835+
Value *ExtOp;
8836+
if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp))))
8837+
return false;
8838+
Exts[I] = cast<Instruction>(OpI);
8839+
ExtOpTypes[I] = ExtOp->getType();
8840+
I++;
8841+
}
8842+
return true;
8843+
};
8844+
8845+
if (ExtendUser) {
8846+
if (!ExtendUser->hasOneUse())
8847+
return false;
8848+
8849+
// Use the side-effect of match to replace BinOp only if the pattern is
8850+
// matched, we don't care at this point whether it actually matched.
8851+
match(ExtendUser, m_Neg(m_BinOp(ExtendUser)));
88358852

8836-
Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
8837-
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
8853+
SmallVector<Value *> Ops(ExtendUser->operands());
8854+
if (!CollectExtInfo(Ops))
8855+
return false;
8856+
8857+
BinOpc = std::make_optional(ExtendUser->getOpcode());
8858+
} else if (match(Update, m_Add(m_Value(), m_Value()))) {
8859+
// We already know the operands for Update are Op and PhiOp.
8860+
SmallVector<Value *> Ops({Op});
8861+
if (!CollectExtInfo(Ops))
8862+
return false;
8863+
8864+
ExtendUser = Update;
8865+
BinOpc = std::nullopt;
8866+
} else
8867+
return false;
88388868

88398869
TTI::PartialReductionExtendKind OpAExtend =
8840-
TargetTransformInfo::getPartialReductionExtendKind(ExtA);
8870+
TargetTransformInfo::getPartialReductionExtendKind(Exts[0]);
88418871
TTI::PartialReductionExtendKind OpBExtend =
8842-
TargetTransformInfo::getPartialReductionExtendKind(ExtB);
8843-
8844-
PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp);
8872+
Exts[1] ? TargetTransformInfo::getPartialReductionExtendKind(Exts[1])
8873+
: TargetTransformInfo::PR_None;
8874+
PartialReductionChain Chain(RdxExitInstr, Exts[0], Exts[1], ExtendUser);
88458875

88468876
unsigned TargetScaleFactor =
88478877
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
8848-
A->getType()->getPrimitiveSizeInBits());
8878+
ExtOpTypes[0]->getPrimitiveSizeInBits());
88498879

88508880
if (LoopVectorizationPlanner::getDecisionAndClampRange(
88518881
[&](ElementCount VF) {
88528882
InstructionCost Cost = TTI->getPartialReductionCost(
8853-
Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
8854-
VF, OpAExtend, OpBExtend,
8855-
std::make_optional(BinOp->getOpcode()));
8883+
Update->getOpcode(), ExtOpTypes[0], ExtOpTypes[1],
8884+
PHI->getType(), VF, OpAExtend, OpBExtend, BinOpc);
88568885
return Cost.isValid();
88578886
},
88588887
Range)) {

llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h

+8-7
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,23 @@ struct HistogramInfo;
2626
struct VFRange;
2727

2828
/// A chain of instructions that form a partial reduction.
29-
/// Designed to match: reduction_bin_op (bin_op (extend (A), (extend (B))),
30-
/// accumulator).
29+
/// Designed to match either:
30+
/// reduction_bin_op (extend (A), accumulator), or
31+
/// reduction_bin_op (bin_op (extend (A), (extend (B))), accumulator).
3132
struct PartialReductionChain {
3233
PartialReductionChain(Instruction *Reduction, Instruction *ExtendA,
33-
Instruction *ExtendB, Instruction *BinOp)
34-
: Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB), BinOp(BinOp) {
35-
}
34+
Instruction *ExtendB, Instruction *ExtendUser)
35+
: Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB),
36+
ExtendUser(ExtendUser) {}
3637
/// The top-level binary operation that forms the reduction to a scalar
3738
/// after the loop body.
3839
Instruction *Reduction;
3940
/// The extension of each of the inner binary operation's operands.
4041
Instruction *ExtendA;
4142
Instruction *ExtendB;
4243

43-
/// The binary operation using the extends that is then reduced.
44-
Instruction *BinOp;
44+
/// The user of the extend that is then reduced.
45+
Instruction *ExtendUser;
4546
};
4647

4748
/// Helper class to create VPRecipies from IR instructions.

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

+40-24
Original file line numberDiff line numberDiff line change
@@ -281,31 +281,18 @@ bool VPRecipeBase::isPhi() const {
281281
InstructionCost
282282
VPPartialReductionRecipe::computeCost(ElementCount VF,
283283
VPCostContext &Ctx) const {
284-
std::optional<unsigned> Opcode = std::nullopt;
285-
VPValue *BinOp = getOperand(0);
284+
// If the input operand is an extend then use the opcode for this recipe.
285+
std::optional<unsigned> Opcode;
286+
VPValue *Op = getOperand(0);
287+
VPRecipeBase *OpR = Op->getDefiningRecipe();
286288

287289
// If the partial reduction is predicated, a select will be operand 0 rather
288-
// than the binary op
290+
// than the extend user.
289291
using namespace llvm::VPlanPatternMatch;
290-
if (match(getOperand(0), m_Select(m_VPValue(), m_VPValue(), m_VPValue())))
291-
BinOp = BinOp->getDefiningRecipe()->getOperand(1);
292-
293-
// If BinOp is a negation, use the side effect of match to assign the actual
294-
// binary operation to BinOp
295-
match(BinOp, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(BinOp)));
296-
VPRecipeBase *BinOpR = BinOp->getDefiningRecipe();
297-
298-
if (auto *WidenR = dyn_cast<VPWidenRecipe>(BinOpR))
299-
Opcode = std::make_optional(WidenR->getOpcode());
300-
301-
VPRecipeBase *ExtAR = BinOpR->getOperand(0)->getDefiningRecipe();
302-
VPRecipeBase *ExtBR = BinOpR->getOperand(1)->getDefiningRecipe();
303-
304-
auto *PhiType = Ctx.Types.inferScalarType(getOperand(1));
305-
auto *InputTypeA = Ctx.Types.inferScalarType(ExtAR ? ExtAR->getOperand(0)
306-
: BinOpR->getOperand(0));
307-
auto *InputTypeB = Ctx.Types.inferScalarType(ExtBR ? ExtBR->getOperand(0)
308-
: BinOpR->getOperand(1));
292+
if (match(getOperand(0), m_Select(m_VPValue(), m_VPValue(), m_VPValue()))) {
293+
Op = OpR->getOperand(1);
294+
OpR = Op->getDefiningRecipe();
295+
}
309296

310297
auto GetExtendKind = [](VPRecipeBase *R) {
311298
// The extend could come from outside the plan.
@@ -321,9 +308,38 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
321308
return TargetTransformInfo::PR_None;
322309
};
323310

311+
Type *InputTypeA, *InputTypeB;
312+
TTI::PartialReductionExtendKind ExtAType, ExtBType;
313+
314+
// The input may come straight from a zext or sext.
315+
if (isa<VPWidenCastRecipe>(OpR)) {
316+
Opcode = std::nullopt;
317+
InputTypeA = Ctx.Types.inferScalarType(OpR->getOperand(0));
318+
InputTypeB = nullptr;
319+
ExtAType = GetExtendKind(OpR);
320+
ExtBType = TargetTransformInfo::PR_None;
321+
} else {
322+
// If BinOp is a negation, use the side effect of match to assign the actual
323+
// binary operation to BinOp
324+
match(Op, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Op)));
325+
OpR = Op->getDefiningRecipe();
326+
Opcode = std::make_optional(cast<VPWidenRecipe>(OpR)->getOpcode());
327+
328+
VPRecipeBase *ExtAR = OpR->getOperand(0)->getDefiningRecipe();
329+
VPRecipeBase *ExtBR = OpR->getOperand(1)->getDefiningRecipe();
330+
331+
InputTypeA = Ctx.Types.inferScalarType(ExtAR ? ExtAR->getOperand(0)
332+
: OpR->getOperand(0));
333+
InputTypeB = Ctx.Types.inferScalarType(ExtBR ? ExtBR->getOperand(0)
334+
: OpR->getOperand(1));
335+
ExtAType = GetExtendKind(ExtAR);
336+
ExtBType = GetExtendKind(ExtBR);
337+
}
338+
339+
auto *PhiType = Ctx.Types.inferScalarType(getOperand(1));
324340
return Ctx.TTI.getPartialReductionCost(getOpcode(), InputTypeA, InputTypeB,
325-
PhiType, VF, GetExtendKind(ExtAR),
326-
GetExtendKind(ExtBR), Opcode);
341+
PhiType, VF, ExtAType, ExtBType,
342+
Opcode);
327343
}
328344

329345
void VPPartialReductionRecipe::execute(VPTransformState &State) {

0 commit comments

Comments
 (0)