@@ -8765,15 +8765,15 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8765
8765
// something that isn't another partial reduction. This is because the
8766
8766
// extends are intended to be lowered along with the reduction itself.
8767
8767
8768
- // Build up a set of partial reduction bin ops for efficient use checking.
8769
- SmallSet<User *, 4> PartialReductionBinOps ;
8768
+ // Build up a set of partial reduction ops for efficient use checking.
8769
+ SmallSet<User *, 4> PartialReductionOps ;
8770
8770
for (const auto &[PartialRdx, _] : PartialReductionChains)
8771
- PartialReductionBinOps .insert(PartialRdx.BinOp );
8771
+ PartialReductionOps .insert(PartialRdx.ExtendUser );
8772
8772
8773
8773
auto ExtendIsOnlyUsedByPartialReductions =
8774
- [&PartialReductionBinOps ](Instruction *Extend) {
8774
+ [&PartialReductionOps ](Instruction *Extend) {
8775
8775
return all_of(Extend->users(), [&](const User *U) {
8776
- return PartialReductionBinOps .contains(U);
8776
+ return PartialReductionOps .contains(U);
8777
8777
});
8778
8778
};
8779
8779
@@ -8782,15 +8782,14 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8782
8782
for (auto Pair : PartialReductionChains) {
8783
8783
PartialReductionChain Chain = Pair.first;
8784
8784
if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
8785
- ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
8785
+ (!Chain.ExtendB || ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB) ))
8786
8786
ScaledReductionMap.insert(std::make_pair(Chain.Reduction, Pair.second));
8787
8787
}
8788
8788
}
8789
8789
8790
8790
bool VPRecipeBuilder::getScaledReductions(
8791
8791
Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
8792
8792
SmallVectorImpl<std::pair<PartialReductionChain, unsigned>> &Chains) {
8793
-
8794
8793
if (!CM.TheLoop->contains(RdxExitInstr))
8795
8794
return false;
8796
8795
@@ -8819,40 +8818,70 @@ bool VPRecipeBuilder::getScaledReductions(
8819
8818
if (PhiOp != PHI)
8820
8819
return false;
8821
8820
8822
- auto *BinOp = dyn_cast<BinaryOperator>(Op);
8823
- if (!BinOp || !BinOp->hasOneUse())
8824
- return false;
8825
-
8826
8821
using namespace llvm::PatternMatch;
8827
- // Use the side-effect of match to replace BinOp only if the pattern is
8828
- // matched, we don't care at this point whether it actually matched.
8829
- match(BinOp, m_Neg(m_BinOp(BinOp)));
8830
8822
8831
- Value *A, *B;
8832
- if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
8833
- !match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
8834
- return false;
8823
+ // If the update is a binary operator, check both of its operands to see if
8824
+ // they are extends. Otherwise, see if the update comes directly from an
8825
+ // extend.
8826
+ Instruction *Exts[2] = {nullptr};
8827
+ BinaryOperator *ExtendUser = dyn_cast<BinaryOperator>(Op);
8828
+ std::optional<unsigned> BinOpc;
8829
+ Type *ExtOpTypes[2] = {nullptr};
8830
+
8831
+ auto CollectExtInfo = [&Exts,
8832
+ &ExtOpTypes](SmallVectorImpl<Value *> &Ops) -> bool {
8833
+ unsigned I = 0;
8834
+ for (Value *OpI : Ops) {
8835
+ Value *ExtOp;
8836
+ if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp))))
8837
+ return false;
8838
+ Exts[I] = cast<Instruction>(OpI);
8839
+ ExtOpTypes[I] = ExtOp->getType();
8840
+ I++;
8841
+ }
8842
+ return true;
8843
+ };
8844
+
8845
+ if (ExtendUser) {
8846
+ if (!ExtendUser->hasOneUse())
8847
+ return false;
8848
+
8849
+ // Use the side-effect of match to replace ExtendUser only if the pattern is
8850
+ // matched, we don't care at this point whether it actually matched.
8851
+ match(ExtendUser, m_Neg(m_BinOp(ExtendUser)));
8835
8852
8836
- Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
8837
- Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
8853
+ SmallVector<Value *> Ops(ExtendUser->operands());
8854
+ if (!CollectExtInfo(Ops))
8855
+ return false;
8856
+
8857
+ BinOpc = std::make_optional(ExtendUser->getOpcode());
8858
+ } else if (match(Update, m_Add(m_Value(), m_Value()))) {
8859
+ // We already know the operands for Update are Op and PhiOp.
8860
+ SmallVector<Value *> Ops({Op});
8861
+ if (!CollectExtInfo(Ops))
8862
+ return false;
8863
+
8864
+ ExtendUser = Update;
8865
+ BinOpc = std::nullopt;
8866
+ } else
8867
+ return false;
8838
8868
8839
8869
TTI::PartialReductionExtendKind OpAExtend =
8840
- TargetTransformInfo::getPartialReductionExtendKind(ExtA );
8870
+ TargetTransformInfo::getPartialReductionExtendKind(Exts[0] );
8841
8871
TTI::PartialReductionExtendKind OpBExtend =
8842
- TargetTransformInfo::getPartialReductionExtendKind(ExtB);
8843
-
8844
- PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp );
8872
+ Exts[1] ? TargetTransformInfo::getPartialReductionExtendKind(Exts[1])
8873
+ : TargetTransformInfo::PR_None;
8874
+ PartialReductionChain Chain(RdxExitInstr, Exts[0], Exts[1], ExtendUser );
8845
8875
8846
8876
unsigned TargetScaleFactor =
8847
8877
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
8848
- A->getType() ->getPrimitiveSizeInBits());
8878
+ ExtOpTypes[0] ->getPrimitiveSizeInBits());
8849
8879
8850
8880
if (LoopVectorizationPlanner::getDecisionAndClampRange(
8851
8881
[&](ElementCount VF) {
8852
8882
InstructionCost Cost = TTI->getPartialReductionCost(
8853
- Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
8854
- VF, OpAExtend, OpBExtend,
8855
- std::make_optional(BinOp->getOpcode()));
8883
+ Update->getOpcode(), ExtOpTypes[0], ExtOpTypes[1],
8884
+ PHI->getType(), VF, OpAExtend, OpBExtend, BinOpc);
8856
8885
return Cost.isValid();
8857
8886
},
8858
8887
Range)) {
0 commit comments