[LV] Add support for partial reductions without a binary op #133922
base: main
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter.
LGTM with a couple of nits.
llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
  std::optional<unsigned> BinOpc;
  Type *ExtOpTypes[2] = {nullptr};

  auto collectExtInfo = [&Exts,
-  auto collectExtInfo = [&Exts,
+  auto CollectExtInfo = [&Exts,
Done
@@ -2108,6 +2108,97 @@ for.exit: ; preds = %for.body
  ret i32 %result
}

define i32 @zext_add_reduc_i8_i32(ptr %a) {
should those tests go in a new file? There's no dot product in the new tests?
Done. I've also moved the non-dot-product tests from partial-reduce-dot-product.ll to a new partial-reduce.ll file.
@@ -5046,6 +5056,7 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
    if (VFMinValue == Scale)
      return Invalid;
  }

stray new line.
Done
    return Invalid;

  assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
The documentation for the interface should probably also be updated, documenting that Opcode and the second type are optional?
Yep, absolutely right! I've tried to amend the documentation for the interface. Please take a look and see if it makes sense.
@@ -30,18 +30,18 @@ struct VFRange;
/// accumulator).
struct PartialReductionChain {
  PartialReductionChain(Instruction *Reduction, Instruction *ExtendA,
                        Instruction *ExtendB, Instruction *BinOp)
Comment above needs updating
Done
Consider IR such as this:

for.body:
  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
  %accum = phi i32 [ 0, %entry ], [ %add, %for.body ]
  %gep.a = getelementptr i8, ptr %a, i64 %iv
  %load.a = load i8, ptr %gep.a, align 1
  %ext.a = zext i8 %load.a to i32
  %add = add i32 %ext.a, %accum
  %iv.next = add i64 %iv, 1
  %exitcond.not = icmp eq i64 %iv.next, 1025
  br i1 %exitcond.not, label %for.exit, label %for.body

Conceptually we can vectorise this using partial reductions too, although the current loop vectoriser implementation requires the accumulation of a multiply. For AArch64 this is easily done with a udot or sdot with an identity operand, i.e. a vector of (i16 1).

In order to do this I had to teach getScaledReductions that the accumulated value may come from a unary op, hence there is only one extension to consider. Similarly, I updated the vplan and AArch64 TTI cost model to understand the possible unary op.
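For reference, below is a rough sketch of the vector body this change allows the vectoriser to emit for the loop above, pieced together from the CHECK lines in the new tests (VF of 16; value names are illustrative). The key point is that the zext feeds the partial reduction intrinsic directly, with no multiply in between:

vector.body:
  %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
  %vec.phi = phi <4 x i32> [ zeroinitializer, %vector.ph ], [ %partial.reduce, %vector.body ]
  %gep.a = getelementptr i8, ptr %a, i64 %index
  %wide.load = load <16 x i8>, ptr %gep.a, align 1
  %ext.a = zext <16 x i8> %wide.load to <16 x i32>
  ; Sixteen i32 lanes are accumulated into four; no binary op feeds the intrinsic.
  %partial.reduce = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %vec.phi, <16 x i32> %ext.a)
  %index.next = add nuw i64 %index, 16
  %done = icmp eq i64 %index.next, 1024
  br i1 %done, label %middle.block, label %vector.body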
@llvm/pr-subscribers-vectorizers
@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-llvm-analysis

Author: David Sherwood (david-arm)

Patch is 144.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133922.diff

9 Files Affected:
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 4835c66a7a3bc..5f3c8ff3bdfb4 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1299,9 +1299,21 @@ class TargetTransformInfo {
/// \return The cost of a partial reduction, which is a reduction from a
/// vector to another vector with fewer elements of larger size. They are
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
- /// takes an accumulator and a binary operation operand that itself is fed by
- /// two extends. An example of an operation that uses a partial reduction is a
- /// dot product, which reduces two vectors to another of 4 times fewer and 4
+ /// takes an accumulator of type \p AccumType and a second vector operand to
+ /// be accumulated, whose element count is specified by \p VF. The type of
+ /// reduction is specified by \p Opcode. The second operand passed to the
+ /// intrinsic could be the result of an extend, such as sext or zext. In
+ /// this case \p BinOp is nullopt, \p InputTypeA represents the type being
+ /// extended and \p OpAExtend the operation, i.e. sign- or zero-extend.
+ /// Also, \p InputTypeB should be nullptr and OpBExtend should be None.
+ /// Alternatively, the second operand could be the result of a binary
+ /// operation performed on two extends, i.e.
+ /// mul(zext i8 %a -> i32, zext i8 %b -> i32).
+ /// In this case \p BinOp may specify the opcode of the binary operation,
+ /// \p InputTypeA and \p InputTypeB the types being extended, and
+ /// \p OpAExtend, \p OpBExtend the form of extensions. An example of an
+ /// operation that uses a partial reduction is a dot product, which reduces
+ /// two vectors in binary mul operation to another of 4 times fewer and 4
/// times larger elements.
InstructionCost
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 77be41b78bc7f..48424185c68de 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5177,11 +5177,21 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
// Sub opcodes currently only occur in chained cases.
// Independent partial reduction subtractions are still costed as an add
- if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
+ if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
+ OpAExtend == TTI::PR_None)
return Invalid;
- if (InputTypeA != InputTypeB)
+ // We only support multiply binary operations for now, and for muls we
+ // require the types being extended to be the same.
+ // NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
+ // only if the i8mm or sve/streaming features are available.
+ if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
+ OpBExtend == TTI::PR_None ||
+ (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
+ !ST->isSVEorStreamingSVEAvailable())))
return Invalid;
+ assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
+ "Unexpected values for OpBExtend or InputTypeB");
EVT InputEVT = EVT::getEVT(InputTypeA);
EVT AccumEVT = EVT::getEVT(AccumType);
@@ -5228,16 +5238,6 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
} else
return Invalid;
- // AArch64 supports lowering mixed extensions to a usdot but only if the
- // i8mm or sve/streaming features are available.
- if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
- (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
- !ST->isSVEorStreamingSVEAvailable()))
- return Invalid;
-
- if (!BinOp || *BinOp != Instruction::Mul)
- return Invalid;
-
return Cost;
}
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 0291a8bfd9674..654f3ecacf51b 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8765,15 +8765,15 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
// something that isn't another partial reduction. This is because the
// extends are intended to be lowered along with the reduction itself.
- // Build up a set of partial reduction bin ops for efficient use checking.
- SmallSet<User *, 4> PartialReductionBinOps;
+ // Build up a set of partial reduction ops for efficient use checking.
+ SmallSet<User *, 4> PartialReductionOps;
for (const auto &[PartialRdx, _] : PartialReductionChains)
- PartialReductionBinOps.insert(PartialRdx.BinOp);
+ PartialReductionOps.insert(PartialRdx.ExtendUser);
auto ExtendIsOnlyUsedByPartialReductions =
- [&PartialReductionBinOps](Instruction *Extend) {
+ [&PartialReductionOps](Instruction *Extend) {
return all_of(Extend->users(), [&](const User *U) {
- return PartialReductionBinOps.contains(U);
+ return PartialReductionOps.contains(U);
});
};
@@ -8782,7 +8782,7 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
for (auto Pair : PartialReductionChains) {
PartialReductionChain Chain = Pair.first;
if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
- ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
+ (!Chain.ExtendB || ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
ScaledReductionMap.insert(std::make_pair(Chain.Reduction, Pair.second));
}
}
@@ -8790,7 +8790,6 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
bool VPRecipeBuilder::getScaledReductions(
Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
SmallVectorImpl<std::pair<PartialReductionChain, unsigned>> &Chains) {
-
if (!CM.TheLoop->contains(RdxExitInstr))
return false;
@@ -8819,40 +8818,70 @@ bool VPRecipeBuilder::getScaledReductions(
if (PhiOp != PHI)
return false;
- auto *BinOp = dyn_cast<BinaryOperator>(Op);
- if (!BinOp || !BinOp->hasOneUse())
- return false;
-
using namespace llvm::PatternMatch;
- // Use the side-effect of match to replace BinOp only if the pattern is
- // matched, we don't care at this point whether it actually matched.
- match(BinOp, m_Neg(m_BinOp(BinOp)));
- Value *A, *B;
- if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
- !match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
- return false;
+ // If the update is a binary operator, check both of its operands to see if
+ // they are extends. Otherwise, see if the update comes directly from an
+ // extend.
+ Instruction *Exts[2] = {nullptr};
+ BinaryOperator *ExtendUser = dyn_cast<BinaryOperator>(Op);
+ std::optional<unsigned> BinOpc;
+ Type *ExtOpTypes[2] = {nullptr};
+
+ auto CollectExtInfo = [&Exts,
+ &ExtOpTypes](SmallVectorImpl<Value *> &Ops) -> bool {
+ unsigned I = 0;
+ for (Value *OpI : Ops) {
+ Value *ExtOp;
+ if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp))))
+ return false;
+ Exts[I] = cast<Instruction>(OpI);
+ ExtOpTypes[I] = ExtOp->getType();
+ I++;
+ }
+ return true;
+ };
+
+ if (ExtendUser) {
+ if (!ExtendUser->hasOneUse())
+ return false;
+
+ // Use the side-effect of match to replace BinOp only if the pattern is
+ // matched, we don't care at this point whether it actually matched.
+ match(ExtendUser, m_Neg(m_BinOp(ExtendUser)));
- Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
- Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
+ SmallVector<Value *> Ops(ExtendUser->operands());
+ if (!CollectExtInfo(Ops))
+ return false;
+
+ BinOpc = std::make_optional(ExtendUser->getOpcode());
+ } else if (match(Update, m_Add(m_Value(), m_Value()))) {
+ // We already know the operands for Update are Op and PhiOp.
+ SmallVector<Value *> Ops({Op});
+ if (!CollectExtInfo(Ops))
+ return false;
+
+ ExtendUser = Update;
+ BinOpc = std::nullopt;
+ } else
+ return false;
TTI::PartialReductionExtendKind OpAExtend =
- TargetTransformInfo::getPartialReductionExtendKind(ExtA);
+ TargetTransformInfo::getPartialReductionExtendKind(Exts[0]);
TTI::PartialReductionExtendKind OpBExtend =
- TargetTransformInfo::getPartialReductionExtendKind(ExtB);
-
- PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp);
+ Exts[1] ? TargetTransformInfo::getPartialReductionExtendKind(Exts[1])
+ : TargetTransformInfo::PR_None;
+ PartialReductionChain Chain(RdxExitInstr, Exts[0], Exts[1], ExtendUser);
unsigned TargetScaleFactor =
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
- A->getType()->getPrimitiveSizeInBits());
+ ExtOpTypes[0]->getPrimitiveSizeInBits());
if (LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) {
InstructionCost Cost = TTI->getPartialReductionCost(
- Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
- VF, OpAExtend, OpBExtend,
- std::make_optional(BinOp->getOpcode()));
+ Update->getOpcode(), ExtOpTypes[0], ExtOpTypes[1],
+ PHI->getType(), VF, OpAExtend, OpBExtend, BinOpc);
return Cost.isValid();
},
Range)) {
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index 334cfbad8bd7c..8d2d187231303 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -26,13 +26,14 @@ struct HistogramInfo;
struct VFRange;
/// A chain of instructions that form a partial reduction.
-/// Designed to match: reduction_bin_op (bin_op (extend (A), (extend (B))),
-/// accumulator).
+/// Designed to match either:
+/// reduction_bin_op (extend (A), accumulator), or
+/// reduction_bin_op (bin_op (extend (A), (extend (B))), accumulator).
struct PartialReductionChain {
PartialReductionChain(Instruction *Reduction, Instruction *ExtendA,
- Instruction *ExtendB, Instruction *BinOp)
- : Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB), BinOp(BinOp) {
- }
+ Instruction *ExtendB, Instruction *ExtendUser)
+ : Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB),
+ ExtendUser(ExtendUser) {}
/// The top-level binary operation that forms the reduction to a scalar
/// after the loop body.
Instruction *Reduction;
@@ -40,8 +41,8 @@ struct PartialReductionChain {
Instruction *ExtendA;
Instruction *ExtendB;
- /// The binary operation using the extends that is then reduced.
- Instruction *BinOp;
+ /// The user of the extend that is then reduced.
+ Instruction *ExtendUser;
};
/// Helper class to create VPRecipies from IR instructions.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index b16a8fc563f4c..ad27e9435669f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -281,31 +281,18 @@ bool VPRecipeBase::isPhi() const {
InstructionCost
VPPartialReductionRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
- std::optional<unsigned> Opcode = std::nullopt;
- VPValue *BinOp = getOperand(0);
+ // If the input operand is an extend then use the opcode for this recipe.
+ std::optional<unsigned> Opcode;
+ VPValue *Op = getOperand(0);
+ VPRecipeBase *OpR = Op->getDefiningRecipe();
// If the partial reduction is predicated, a select will be operand 0 rather
- // than the binary op
+ // than the extend user.
using namespace llvm::VPlanPatternMatch;
- if (match(getOperand(0), m_Select(m_VPValue(), m_VPValue(), m_VPValue())))
- BinOp = BinOp->getDefiningRecipe()->getOperand(1);
-
- // If BinOp is a negation, use the side effect of match to assign the actual
- // binary operation to BinOp
- match(BinOp, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(BinOp)));
- VPRecipeBase *BinOpR = BinOp->getDefiningRecipe();
-
- if (auto *WidenR = dyn_cast<VPWidenRecipe>(BinOpR))
- Opcode = std::make_optional(WidenR->getOpcode());
-
- VPRecipeBase *ExtAR = BinOpR->getOperand(0)->getDefiningRecipe();
- VPRecipeBase *ExtBR = BinOpR->getOperand(1)->getDefiningRecipe();
-
- auto *PhiType = Ctx.Types.inferScalarType(getOperand(1));
- auto *InputTypeA = Ctx.Types.inferScalarType(ExtAR ? ExtAR->getOperand(0)
- : BinOpR->getOperand(0));
- auto *InputTypeB = Ctx.Types.inferScalarType(ExtBR ? ExtBR->getOperand(0)
- : BinOpR->getOperand(1));
+ if (match(getOperand(0), m_Select(m_VPValue(), m_VPValue(), m_VPValue()))) {
+ Op = OpR->getOperand(1);
+ OpR = Op->getDefiningRecipe();
+ }
auto GetExtendKind = [](VPRecipeBase *R) {
// The extend could come from outside the plan.
@@ -321,9 +308,38 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
return TargetTransformInfo::PR_None;
};
+ Type *InputTypeA, *InputTypeB;
+ TTI::PartialReductionExtendKind ExtAType, ExtBType;
+
+ // The input may come straight from a zext or sext.
+ if (isa<VPWidenCastRecipe>(OpR)) {
+ Opcode = std::nullopt;
+ InputTypeA = Ctx.Types.inferScalarType(OpR->getOperand(0));
+ InputTypeB = nullptr;
+ ExtAType = GetExtendKind(OpR);
+ ExtBType = TargetTransformInfo::PR_None;
+ } else {
+ // If BinOp is a negation, use the side effect of match to assign the actual
+ // binary operation to BinOp
+ match(Op, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Op)));
+ OpR = Op->getDefiningRecipe();
+ Opcode = std::make_optional(cast<VPWidenRecipe>(OpR)->getOpcode());
+
+ VPRecipeBase *ExtAR = OpR->getOperand(0)->getDefiningRecipe();
+ VPRecipeBase *ExtBR = OpR->getOperand(1)->getDefiningRecipe();
+
+ InputTypeA = Ctx.Types.inferScalarType(ExtAR ? ExtAR->getOperand(0)
+ : OpR->getOperand(0));
+ InputTypeB = Ctx.Types.inferScalarType(ExtBR ? ExtBR->getOperand(0)
+ : OpR->getOperand(1));
+ ExtAType = GetExtendKind(ExtAR);
+ ExtBType = GetExtendKind(ExtBR);
+ }
+
+ auto *PhiType = Ctx.Types.inferScalarType(getOperand(1));
return Ctx.TTI.getPartialReductionCost(getOpcode(), InputTypeA, InputTypeB,
- PhiType, VF, GetExtendKind(ExtAR),
- GetExtendKind(ExtBR), Opcode);
+ PhiType, VF, ExtAType, ExtBType,
+ Opcode);
}
void VPPartialReductionRecipe::execute(VPTransformState &State) {
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
index a229ca8c6e6db..075742ff95b04 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
@@ -1030,6 +1030,472 @@ for.body: ; preds = %for.body.preheader,
br i1 %exitcond.not, label %for.cond.cleanup, label %for.body, !loop !1
}
+
+define i32 @chained_partial_reduce_madd_extadd(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
+; CHECK-NEON-LABEL: define i32 @chained_partial_reduce_madd_extadd(
+; CHECK-NEON-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEON-NEXT: entry:
+; CHECK-NEON-NEXT: [[CMP28_NOT:%.*]] = icmp ult i32 [[N]], 2
+; CHECK-NEON-NEXT: [[DIV27:%.*]] = lshr i32 [[N]], 1
+; CHECK-NEON-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[DIV27]] to i64
+; CHECK-NEON-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], 16
+; CHECK-NEON-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK-NEON: vector.ph:
+; CHECK-NEON-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], 16
+; CHECK-NEON-NEXT: [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
+; CHECK-NEON-NEXT: br label [[VECTOR_BODY:%.*]]
+; CHECK-NEON: vector.body:
+; CHECK-NEON-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE3:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT: [[TMP1:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEON-NEXT: [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEON-NEXT: [[TMP3:%.*]] = getelementptr inbounds nuw i8, ptr [[C]], i64 [[INDEX]]
+; CHECK-NEON-NEXT: [[TMP4:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP1]], i32 0
+; CHECK-NEON-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP4]], align 1
+; CHECK-NEON-NEXT: [[TMP5:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP2]], i32 0
+; CHECK-NEON-NEXT: [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP5]], align 1
+; CHECK-NEON-NEXT: [[TMP6:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP3]], i32 0
+; CHECK-NEON-NEXT: [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP6]], align 1
+; CHECK-NEON-NEXT: [[TMP7:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEON-NEXT: [[TMP8:%.*]] = sext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
+; CHECK-NEON-NEXT: [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NEON-NEXT: [[TMP10:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP8]]
+; CHECK-NEON-NEXT: [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP10]])
+; CHECK-NEON-NEXT: [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP9]])
+; CHECK-NEON-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
+; CHECK-NEON-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEON-NEXT: br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP16:![0-9]+]]
+; CHECK-NEON: middle.block:
+; CHECK-NEON-NEXT: [[TMP13:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE3]])
+; CHECK-NEON-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
+; CHECK-NEON-NEXT: br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
+; CHECK-NEON: scalar.ph:
+;
+; CHECK-SVE-LABEL: define i32 @chained_partial_reduce_madd_extadd(
+; CHECK-SVE-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-SVE-NEXT: entry:
+; CHECK-SVE-NEXT: [[CMP28_NOT:%.*]] = icmp ult i32 [[N]], 2
+; CHECK-SVE-NEXT: [[DIV27:%.*]] = lshr i32 [[N]], 1
+; CHECK-SVE-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[DIV27]] to i64
+; CHECK-SVE-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-SVE-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 4
+; CHECK-SVE-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], [[TMP1]]
+; CHECK-SVE-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK-SVE: vector.ph:
+; CHECK-SVE-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-SVE-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 4
+; CHECK-SVE-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], [[TMP3]]
+; CHECK-SVE-NEXT: [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
+; CHECK-SVE-NEXT: [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-SVE-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], 4
+; CHECK-SVE-NEXT: br label [[VECTOR_BODY:%.*]]
+; CHECK-SVE: vector.body:
+; CHECK-SVE-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP18:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-NEXT: [[TMP7:%.*]] = getelementptr inbounds nuw i8, ptr [[A]...
[truncated]
Gentle ping. :)
Curious if/how this will interact with #136173?
Would it potentially help to simplify computeCost & co?
I think we'll want to turn a partial reduction without a bin op into a …
@fhahn @SamTebbs33 Do you think that #136173 is likely to land soon? Ideally I'd like to progress this patch within the next few weeks. It looks like #136173 also depends upon #113903, which looks like it might still take some time.
If both of those PRs are close to landing I could try to check them out and see how my PR interacts with them, but I'd rather do that once I know they're close(ish) to the finish line.
I think that #136173 is very very close to being approved and there shouldn't be any big changes to it on the way.
Consider IR such as this:
for.body:
%iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
%accum = phi i32 [ 0, %entry ], [ %add, %for.body ]
%gep.a = getelementptr i8, ptr %a, i64 %iv
%load.a = load i8, ptr %gep.a, align 1
%ext.a = zext i8 %load.a to i32
%add = add i32 %ext.a, %accum
%iv.next = add i64 %iv, 1
%exitcond.not = icmp eq i64 %iv.next, 1025
br i1 %exitcond.not, label %for.exit, label %for.body
Conceptually we can vectorise this using partial reductions too,
although the current loop vectoriser implementation requires the
accumulation of a multiply. For AArch64 this is easily done with
a udot or sdot with an identity operand, i.e. a vector of (i16 1).
In order to do this I had to teach getScaledReductions that the
accumulated value may come from a unary op, hence there is only
one extension to consider. Similarly, I updated the vplan and
AArch64 TTI cost model to understand the possible unary op.
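
As a sketch of why the identity-operand lowering mentioned above works (illustrative IR, not taken from the patch; splat is shorthand for an all-ones vector constant), a partial reduction of a plain extend computes the same value as a partial reduction of that extend multiplied by ones, and the multiply-by-ones form is exactly the mul-of-extends shape that maps onto a udot/sdot:

%ext = zext <16 x i8> %x to <16 x i32>
; Direct form matched after this patch.
%r1 = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %ext)
; Equivalent multiply-by-identity form: same result, and mul(zext(%x), ones)
; is the pattern a udot/sdot with an all-ones operand implements directly.
%mul = mul <16 x i32> %ext, splat (i32 1)
%r2 = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mul)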