[LV] Add support for partial reductions without a binary op #133922

Open
david-arm wants to merge 2 commits into main

Conversation

david-arm (Contributor)

Consider IR such as this:

for.body:
%iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
%accum = phi i32 [ 0, %entry ], [ %add, %for.body ]
%gep.a = getelementptr i8, ptr %a, i64 %iv
%load.a = load i8, ptr %gep.a, align 1
%ext.a = zext i8 %load.a to i32
%add = add i32 %ext.a, %accum
%iv.next = add i64 %iv, 1
%exitcond.not = icmp eq i64 %iv.next, 1025
br i1 %exitcond.not, label %for.exit, label %for.body

Conceptually we can vectorise this using partial reductions too,
although the current loop vectoriser implementation requires the
accumulation of a multiply. For AArch64 this is easily done with
a udot or sdot with an identity operand, i.e. a vector of (i16 1).
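
As a rough sketch (based on the tests added in this patch; value names
are illustrative), the vectorised loop feeds the extend straight into
the partial reduction intrinsic, which the AArch64 backend can then
lower to a udot/sdot against a splat of ones:

vector.body:
  %vec.phi = phi <4 x i32> [ zeroinitializer, %vector.ph ], [ %partial.reduce, %vector.body ]
  %wide.load = load <16 x i8>, ptr %gep, align 1
  %ext = zext <16 x i8> %wide.load to <16 x i32>
  %partial.reduce = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %vec.phi, <16 x i32> %ext)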

In order to do this I had to teach getScaledReductions that the
accumulated value may come from a unary op, in which case there is
only one extension to consider. Similarly, I updated the VPlan and
AArch64 TTI cost model to understand the possible unary op.
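
In scalar IR terms, getScaledReductions now accepts both of the shapes
below (a sketch only, reusing names from the loop above; %ext.b and
%mul are illustrative, and the second form is what this patch adds):

  ; Existing form: a binary op on two extends feeds the accumulating add.
  %mul = mul i32 %ext.a, %ext.b
  %add = add i32 %mul, %accum

  ; New form: a single extend feeds the accumulating add directly.
  %add = add i32 %ext.a, %accum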

david-arm requested review from fhahn, SamTebbs33 and MacDue, April 1, 2025 15:15

github-actions bot commented Apr 1, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

SamTebbs33 (Collaborator) left a comment


LGTM with a couple of nits.

std::optional<unsigned> BinOpc;
Type *ExtOpTypes[2] = {nullptr};

auto collectExtInfo = [&Exts,

Contributor:

Suggested change:
- auto collectExtInfo = [&Exts,
+ auto CollectExtInfo = [&Exts,

Contributor Author (david-arm):

Done

@@ -2108,6 +2108,97 @@ for.exit: ; preds = %for.body
ret i32 %result
}

define i32 @zext_add_reduc_i8_i32(ptr %a) {

Contributor:

Should those tests go in a new file? There's no dot product in the new tests?

Contributor Author (david-arm):

Done. I've also moved the non-dot-product tests from partial-reduce-dot-product.ll to a new partial-reduce.ll file.

@@ -5046,6 +5056,7 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
if (VFMinValue == Scale)
return Invalid;
}

Contributor:

Stray newline.

Contributor Author (david-arm):

Done

return Invalid;
assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&

Contributor:

The documentation for the interface should probably also be updated, documenting that Opcode and the second type are optional?

Contributor Author (david-arm):

Yep, absolutely right! I've tried to amend the documentation for the interface. Please take a look and see if it makes sense.

@@ -30,18 +30,18 @@ struct VFRange;
/// accumulator).
struct PartialReductionChain {
PartialReductionChain(Instruction *Reduction, Instruction *ExtendA,
Instruction *ExtendB, Instruction *BinOp)

Contributor:

Comment above needs updating.

Contributor Author (david-arm):

Done

llvmbot (Member) commented Apr 4, 2025

@llvm/pr-subscribers-vectorizers
@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-transforms

Author: David Sherwood (david-arm)

Changes

Patch is 144.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133922.diff

9 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+15-3)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+12-12)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+57-28)
  • (modified) llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h (+8-7)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+40-24)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll (+466)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll (-506)
  • (added) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-neon.ll (+98)
  • (added) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce.ll (+826)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 4835c66a7a3bc..5f3c8ff3bdfb4 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1299,9 +1299,21 @@ class TargetTransformInfo {
   /// \return The cost of a partial reduction, which is a reduction from a
   /// vector to another vector with fewer elements of larger size. They are
   /// represented by the llvm.experimental.partial.reduce.add intrinsic, which
-  /// takes an accumulator and a binary operation operand that itself is fed by
-  /// two extends. An example of an operation that uses a partial reduction is a
-  /// dot product, which reduces two vectors to another of 4 times fewer and 4
+  /// takes an accumulator of type \p AccumType and a second vector operand to
+  /// be accumulated, whose element count is specified by \p VF. The type of
+  /// reduction is specified by \p Opcode. The second operand passed to the
+  /// intrinsic could be the result of an extend, such as sext or zext. In
+  /// this case \p BinOp is nullopt, \p InputTypeA represents the type being
+  /// extended and \p OpAExtend the operation, i.e. sign- or zero-extend.
+  /// Also, \p InputTypeB should be nullptr and OpBExtend should be None.
+  /// Alternatively, the second operand could be the result of a binary
+  /// operation performed on two extends, i.e.
+  ///   mul(zext i8 %a -> i32, zext i8 %b -> i32).
+  /// In this case \p BinOp may specify the opcode of the binary operation,
+  /// \p InputTypeA and \p InputTypeB the types being extended, and
+  /// \p OpAExtend, \p OpBExtend the form of extensions. An example of an
+  /// operation that uses a partial reduction is a dot product, which reduces
+  /// two vectors in binary mul operation to another of 4 times fewer and 4
   /// times larger elements.
   InstructionCost
   getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 77be41b78bc7f..48424185c68de 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5177,11 +5177,21 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
 
   // Sub opcodes currently only occur in chained cases.
   // Independent partial reduction subtractions are still costed as an add
-  if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
+  if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
+      OpAExtend == TTI::PR_None)
     return Invalid;
 
-  if (InputTypeA != InputTypeB)
+  // We only support multiply binary operations for now, and for muls we
+  // require the types being extended to be the same.
+  // NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
+  // only if the i8mm or sve/streaming features are available.
+  if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
+                OpBExtend == TTI::PR_None ||
+                (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
+                 !ST->isSVEorStreamingSVEAvailable())))
     return Invalid;
+  assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
+         "Unexpected values for OpBExtend or InputTypeB");
 
   EVT InputEVT = EVT::getEVT(InputTypeA);
   EVT AccumEVT = EVT::getEVT(AccumType);
@@ -5228,16 +5238,6 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
   } else
     return Invalid;
 
-  // AArch64 supports lowering mixed extensions to a usdot but only if the
-  // i8mm or sve/streaming features are available.
-  if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
-      (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
-       !ST->isSVEorStreamingSVEAvailable()))
-    return Invalid;
-
-  if (!BinOp || *BinOp != Instruction::Mul)
-    return Invalid;
-
   return Cost;
 }
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 0291a8bfd9674..654f3ecacf51b 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8765,15 +8765,15 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
   // something that isn't another partial reduction. This is because the
   // extends are intended to be lowered along with the reduction itself.
 
-  // Build up a set of partial reduction bin ops for efficient use checking.
-  SmallSet<User *, 4> PartialReductionBinOps;
+  // Build up a set of partial reduction ops for efficient use checking.
+  SmallSet<User *, 4> PartialReductionOps;
   for (const auto &[PartialRdx, _] : PartialReductionChains)
-    PartialReductionBinOps.insert(PartialRdx.BinOp);
+    PartialReductionOps.insert(PartialRdx.ExtendUser);
 
   auto ExtendIsOnlyUsedByPartialReductions =
-      [&PartialReductionBinOps](Instruction *Extend) {
+      [&PartialReductionOps](Instruction *Extend) {
         return all_of(Extend->users(), [&](const User *U) {
-          return PartialReductionBinOps.contains(U);
+          return PartialReductionOps.contains(U);
         });
       };
 
@@ -8782,7 +8782,7 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
   for (auto Pair : PartialReductionChains) {
     PartialReductionChain Chain = Pair.first;
     if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
-        ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
+        (!Chain.ExtendB || ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
       ScaledReductionMap.insert(std::make_pair(Chain.Reduction, Pair.second));
   }
 }
@@ -8790,7 +8790,6 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
 bool VPRecipeBuilder::getScaledReductions(
     Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
     SmallVectorImpl<std::pair<PartialReductionChain, unsigned>> &Chains) {
-
   if (!CM.TheLoop->contains(RdxExitInstr))
     return false;
 
@@ -8819,40 +8818,70 @@ bool VPRecipeBuilder::getScaledReductions(
   if (PhiOp != PHI)
     return false;
 
-  auto *BinOp = dyn_cast<BinaryOperator>(Op);
-  if (!BinOp || !BinOp->hasOneUse())
-    return false;
-
   using namespace llvm::PatternMatch;
-  // Use the side-effect of match to replace BinOp only if the pattern is
-  // matched, we don't care at this point whether it actually matched.
-  match(BinOp, m_Neg(m_BinOp(BinOp)));
 
-  Value *A, *B;
-  if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
-      !match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
-    return false;
+  // If the update is a binary operator, check both of its operands to see if
+  // they are extends. Otherwise, see if the update comes directly from an
+  // extend.
+  Instruction *Exts[2] = {nullptr};
+  BinaryOperator *ExtendUser = dyn_cast<BinaryOperator>(Op);
+  std::optional<unsigned> BinOpc;
+  Type *ExtOpTypes[2] = {nullptr};
+
+  auto CollectExtInfo = [&Exts,
+                         &ExtOpTypes](SmallVectorImpl<Value *> &Ops) -> bool {
+    unsigned I = 0;
+    for (Value *OpI : Ops) {
+      Value *ExtOp;
+      if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp))))
+        return false;
+      Exts[I] = cast<Instruction>(OpI);
+      ExtOpTypes[I] = ExtOp->getType();
+      I++;
+    }
+    return true;
+  };
+
+  if (ExtendUser) {
+    if (!ExtendUser->hasOneUse())
+      return false;
+
+    // Use the side-effect of match to replace BinOp only if the pattern is
+    // matched, we don't care at this point whether it actually matched.
+    match(ExtendUser, m_Neg(m_BinOp(ExtendUser)));
 
-  Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
-  Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
+    SmallVector<Value *> Ops(ExtendUser->operands());
+    if (!CollectExtInfo(Ops))
+      return false;
+
+    BinOpc = std::make_optional(ExtendUser->getOpcode());
+  } else if (match(Update, m_Add(m_Value(), m_Value()))) {
+    // We already know the operands for Update are Op and PhiOp.
+    SmallVector<Value *> Ops({Op});
+    if (!CollectExtInfo(Ops))
+      return false;
+
+    ExtendUser = Update;
+    BinOpc = std::nullopt;
+  } else
+    return false;
 
   TTI::PartialReductionExtendKind OpAExtend =
-      TargetTransformInfo::getPartialReductionExtendKind(ExtA);
+      TargetTransformInfo::getPartialReductionExtendKind(Exts[0]);
   TTI::PartialReductionExtendKind OpBExtend =
-      TargetTransformInfo::getPartialReductionExtendKind(ExtB);
-
-  PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp);
+      Exts[1] ? TargetTransformInfo::getPartialReductionExtendKind(Exts[1])
+              : TargetTransformInfo::PR_None;
+  PartialReductionChain Chain(RdxExitInstr, Exts[0], Exts[1], ExtendUser);
 
   unsigned TargetScaleFactor =
       PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
-          A->getType()->getPrimitiveSizeInBits());
+          ExtOpTypes[0]->getPrimitiveSizeInBits());
 
   if (LoopVectorizationPlanner::getDecisionAndClampRange(
           [&](ElementCount VF) {
             InstructionCost Cost = TTI->getPartialReductionCost(
-                Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
-                VF, OpAExtend, OpBExtend,
-                std::make_optional(BinOp->getOpcode()));
+                Update->getOpcode(), ExtOpTypes[0], ExtOpTypes[1],
+                PHI->getType(), VF, OpAExtend, OpBExtend, BinOpc);
             return Cost.isValid();
           },
           Range)) {
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index 334cfbad8bd7c..8d2d187231303 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -26,13 +26,14 @@ struct HistogramInfo;
 struct VFRange;
 
 /// A chain of instructions that form a partial reduction.
-/// Designed to match: reduction_bin_op (bin_op (extend (A), (extend (B))),
-/// accumulator).
+/// Designed to match either:
+///   reduction_bin_op (extend (A), accumulator), or
+///   reduction_bin_op (bin_op (extend (A), (extend (B))), accumulator).
 struct PartialReductionChain {
   PartialReductionChain(Instruction *Reduction, Instruction *ExtendA,
-                        Instruction *ExtendB, Instruction *BinOp)
-      : Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB), BinOp(BinOp) {
-  }
+                        Instruction *ExtendB, Instruction *ExtendUser)
+      : Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB),
+        ExtendUser(ExtendUser) {}
   /// The top-level binary operation that forms the reduction to a scalar
   /// after the loop body.
   Instruction *Reduction;
@@ -40,8 +41,8 @@ struct PartialReductionChain {
   Instruction *ExtendA;
   Instruction *ExtendB;
 
-  /// The binary operation using the extends that is then reduced.
-  Instruction *BinOp;
+  /// The user of the extend that is then reduced.
+  Instruction *ExtendUser;
 };
 
 /// Helper class to create VPRecipies from IR instructions.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index b16a8fc563f4c..ad27e9435669f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -281,31 +281,18 @@ bool VPRecipeBase::isPhi() const {
 InstructionCost
 VPPartialReductionRecipe::computeCost(ElementCount VF,
                                       VPCostContext &Ctx) const {
-  std::optional<unsigned> Opcode = std::nullopt;
-  VPValue *BinOp = getOperand(0);
+  // If the input operand is an extend then use the opcode for this recipe.
+  std::optional<unsigned> Opcode;
+  VPValue *Op = getOperand(0);
+  VPRecipeBase *OpR = Op->getDefiningRecipe();
 
   // If the partial reduction is predicated, a select will be operand 0 rather
-  // than the binary op
+  // than the extend user.
   using namespace llvm::VPlanPatternMatch;
-  if (match(getOperand(0), m_Select(m_VPValue(), m_VPValue(), m_VPValue())))
-    BinOp = BinOp->getDefiningRecipe()->getOperand(1);
-
-  // If BinOp is a negation, use the side effect of match to assign the actual
-  // binary operation to BinOp
-  match(BinOp, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(BinOp)));
-  VPRecipeBase *BinOpR = BinOp->getDefiningRecipe();
-
-  if (auto *WidenR = dyn_cast<VPWidenRecipe>(BinOpR))
-    Opcode = std::make_optional(WidenR->getOpcode());
-
-  VPRecipeBase *ExtAR = BinOpR->getOperand(0)->getDefiningRecipe();
-  VPRecipeBase *ExtBR = BinOpR->getOperand(1)->getDefiningRecipe();
-
-  auto *PhiType = Ctx.Types.inferScalarType(getOperand(1));
-  auto *InputTypeA = Ctx.Types.inferScalarType(ExtAR ? ExtAR->getOperand(0)
-                                                     : BinOpR->getOperand(0));
-  auto *InputTypeB = Ctx.Types.inferScalarType(ExtBR ? ExtBR->getOperand(0)
-                                                     : BinOpR->getOperand(1));
+  if (match(getOperand(0), m_Select(m_VPValue(), m_VPValue(), m_VPValue()))) {
+    Op = OpR->getOperand(1);
+    OpR = Op->getDefiningRecipe();
+  }
 
   auto GetExtendKind = [](VPRecipeBase *R) {
     // The extend could come from outside the plan.
@@ -321,9 +308,38 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
     return TargetTransformInfo::PR_None;
   };
 
+  Type *InputTypeA, *InputTypeB;
+  TTI::PartialReductionExtendKind ExtAType, ExtBType;
+
+  // The input may come straight from a zext or sext.
+  if (isa<VPWidenCastRecipe>(OpR)) {
+    Opcode = std::nullopt;
+    InputTypeA = Ctx.Types.inferScalarType(OpR->getOperand(0));
+    InputTypeB = nullptr;
+    ExtAType = GetExtendKind(OpR);
+    ExtBType = TargetTransformInfo::PR_None;
+  } else {
+    // If BinOp is a negation, use the side effect of match to assign the actual
+    // binary operation to BinOp
+    match(Op, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Op)));
+    OpR = Op->getDefiningRecipe();
+    Opcode = std::make_optional(cast<VPWidenRecipe>(OpR)->getOpcode());
+
+    VPRecipeBase *ExtAR = OpR->getOperand(0)->getDefiningRecipe();
+    VPRecipeBase *ExtBR = OpR->getOperand(1)->getDefiningRecipe();
+
+    InputTypeA = Ctx.Types.inferScalarType(ExtAR ? ExtAR->getOperand(0)
+                                                 : OpR->getOperand(0));
+    InputTypeB = Ctx.Types.inferScalarType(ExtBR ? ExtBR->getOperand(0)
+                                                 : OpR->getOperand(1));
+    ExtAType = GetExtendKind(ExtAR);
+    ExtBType = GetExtendKind(ExtBR);
+  }
+
+  auto *PhiType = Ctx.Types.inferScalarType(getOperand(1));
   return Ctx.TTI.getPartialReductionCost(getOpcode(), InputTypeA, InputTypeB,
-                                         PhiType, VF, GetExtendKind(ExtAR),
-                                         GetExtendKind(ExtBR), Opcode);
+                                         PhiType, VF, ExtAType, ExtBType,
+                                         Opcode);
 }
 
 void VPPartialReductionRecipe::execute(VPTransformState &State) {
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
index a229ca8c6e6db..075742ff95b04 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
@@ -1030,6 +1030,472 @@ for.body:                                         ; preds = %for.body.preheader,
   br i1 %exitcond.not, label %for.cond.cleanup, label %for.body, !loop !1
 }
 
+
+define i32 @chained_partial_reduce_madd_extadd(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
+; CHECK-NEON-LABEL: define i32 @chained_partial_reduce_madd_extadd(
+; CHECK-NEON-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEON-NEXT:  entry:
+; CHECK-NEON-NEXT:    [[CMP28_NOT:%.*]] = icmp ult i32 [[N]], 2
+; CHECK-NEON-NEXT:    [[DIV27:%.*]] = lshr i32 [[N]], 1
+; CHECK-NEON-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[DIV27]] to i64
+; CHECK-NEON-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], 16
+; CHECK-NEON-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK-NEON:       vector.ph:
+; CHECK-NEON-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], 16
+; CHECK-NEON-NEXT:    [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
+; CHECK-NEON-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK-NEON:       vector.body:
+; CHECK-NEON-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE3:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT:    [[TMP1:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEON-NEXT:    [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEON-NEXT:    [[TMP3:%.*]] = getelementptr inbounds nuw i8, ptr [[C]], i64 [[INDEX]]
+; CHECK-NEON-NEXT:    [[TMP4:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP1]], i32 0
+; CHECK-NEON-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP4]], align 1
+; CHECK-NEON-NEXT:    [[TMP5:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP2]], i32 0
+; CHECK-NEON-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP5]], align 1
+; CHECK-NEON-NEXT:    [[TMP6:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP3]], i32 0
+; CHECK-NEON-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP6]], align 1
+; CHECK-NEON-NEXT:    [[TMP7:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEON-NEXT:    [[TMP8:%.*]] = sext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
+; CHECK-NEON-NEXT:    [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NEON-NEXT:    [[TMP10:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP8]]
+; CHECK-NEON-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP10]])
+; CHECK-NEON-NEXT:    [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP9]])
+; CHECK-NEON-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
+; CHECK-NEON-NEXT:    [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEON-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP16:![0-9]+]]
+; CHECK-NEON:       middle.block:
+; CHECK-NEON-NEXT:    [[TMP13:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE3]])
+; CHECK-NEON-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
+; CHECK-NEON-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
+; CHECK-NEON:       scalar.ph:
+;
+; CHECK-SVE-LABEL: define i32 @chained_partial_reduce_madd_extadd(
+; CHECK-SVE-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-SVE-NEXT:  entry:
+; CHECK-SVE-NEXT:    [[CMP28_NOT:%.*]] = icmp ult i32 [[N]], 2
+; CHECK-SVE-NEXT:    [[DIV27:%.*]] = lshr i32 [[N]], 1
+; CHECK-SVE-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[DIV27]] to i64
+; CHECK-SVE-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-SVE-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 4
+; CHECK-SVE-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], [[TMP1]]
+; CHECK-SVE-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK-SVE:       vector.ph:
+; CHECK-SVE-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-SVE-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 4
+; CHECK-SVE-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], [[TMP3]]
+; CHECK-SVE-NEXT:    [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
+; CHECK-SVE-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-SVE-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 4
+; CHECK-SVE-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK-SVE:       vector.body:
+; CHECK-SVE-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP18:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw i8, ptr [[A]...
[truncated]

llvmbot (Member) commented Apr 4, 2025

@llvm/pr-subscribers-llvm-analysis


david-arm (Contributor Author)

Gentle ping. :)

david-arm requested a review from huntergr-arm, April 17, 2025 10:50

fhahn (Contributor) left a comment


Curious if/how this will interact with #136173?

Would it potentially help to simplify computeCost & co?

SamTebbs33 (Collaborator) commented Apr 22, 2025

> Curious if/how this will interact with #136173?
>
> Would it potentially help to simplify computeCost & co?

I think we'll want to turn a partial reduction without a bin op into a VPExtendedReductionRecipe and then expand it back to a partial reduction.

david-arm (Contributor Author)

> Curious if/how this will interact with #136173?
> Would it potentially help to simplify computeCost & co?
>
> I think we'll want to turn a partial reduction without a bin op into a VPExtendedReductionRecipe and then expand it back to a partial reduction.

@fhahn @SamTebbs33 Do you think that #136173 is likely to land soon? Ideally I'd like to progress this patch within the next few weeks. It looks like #136173 also depends upon #113903, which might still take some time.

david-arm (Contributor Author)

If both of those PRs are close to landing I could try to check them out and see how my PR interacts with them, but I'd rather do that once I know they're close(ish) to the finish line.

SamTebbs33 (Collaborator)

> Curious if/how this will interact with #136173?
> Would it potentially help to simplify computeCost & co?
>
> I think we'll want to turn a partial reduction without a bin op into a VPExtendedReductionRecipe and then expand it back to a partial reduction.
>
> @fhahn @SamTebbs33 Do you think that #136173 is likely to land soon? Ideally I'd like to progress this patch within the next few weeks. It looks like #136173 also depends upon #113903, which might still take some time.

I think that #136173 is very close to being approved and there shouldn't be any big changes to it on the way.
