Skip to content

Commit 1640de0

Browse files
committed
[LoopVectorizer] Add support for partial reductions
1 parent 506c84a commit 1640de0

File tree

12 files changed

+387
-10
lines changed

12 files changed

+387
-10
lines changed

llvm/include/llvm/IR/DerivedTypes.h

+10
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,16 @@ class VectorType : public Type {
512512
EltCnt.divideCoefficientBy(2));
513513
}
514514

515+
/// This static method returns a VectorType with quarter as many elements as the
516+
/// input type and the same element type.
517+
static VectorType *getQuarterElementsVectorType(VectorType *VTy) {
518+
auto EltCnt = VTy->getElementCount();
519+
assert(EltCnt.isKnownEven() &&
520+
"Cannot halve vector with odd number of elements.");
521+
return VectorType::get(VTy->getElementType(),
522+
EltCnt.divideCoefficientBy(4));
523+
}
524+
515525
/// This static method returns a VectorType with twice as many elements as the
516526
/// input type and the same element type.
517527
static VectorType *getDoubleElementsVectorType(VectorType *VTy) {

llvm/include/llvm/IR/Intrinsics.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ namespace Intrinsic {
131131
ExtendArgument,
132132
TruncArgument,
133133
HalfVecArgument,
134+
QuarterVecArgument,
134135
SameVecWidthArgument,
135136
VecOfAnyPtrsToElt,
136137
VecElementArgument,
@@ -160,15 +161,15 @@ namespace Intrinsic {
160161

161162
unsigned getArgumentNumber() const {
162163
assert(Kind == Argument || Kind == ExtendArgument ||
163-
Kind == TruncArgument || Kind == HalfVecArgument ||
164+
Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
164165
Kind == SameVecWidthArgument || Kind == VecElementArgument ||
165166
Kind == Subdivide2Argument || Kind == Subdivide4Argument ||
166167
Kind == VecOfBitcastsToInt);
167168
return Argument_Info >> 3;
168169
}
169170
ArgKind getArgumentKind() const {
170171
assert(Kind == Argument || Kind == ExtendArgument ||
171-
Kind == TruncArgument || Kind == HalfVecArgument ||
172+
Kind == TruncArgument || Kind == HalfVecArgument || Kind == QuarterVecArgument ||
172173
Kind == SameVecWidthArgument ||
173174
Kind == VecElementArgument || Kind == Subdivide2Argument ||
174175
Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);

llvm/include/llvm/IR/Intrinsics.td

+10
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,7 @@ def IIT_I4 : IIT_Int<4, 58>;
321321
def IIT_AARCH64_SVCOUNT : IIT_VT<aarch64svcount, 59>;
322322
def IIT_V6 : IIT_Vec<6, 60>;
323323
def IIT_V10 : IIT_Vec<10, 61>;
324+
def IIT_QUARTER_VEC_ARG : IIT_Base<62>;
324325
}
325326

326327
defvar IIT_all_FixedTypes = !filter(iit, IIT_all,
@@ -457,6 +458,9 @@ class LLVMVectorElementType<int num> : LLVMMatchType<num, IIT_VEC_ELEMENT>;
457458
class LLVMHalfElementsVectorType<int num>
458459
: LLVMMatchType<num, IIT_HALF_VEC_ARG>;
459460

461+
// Match the type of another intrinsic parameter that is expected to be a
// vector type, but with a quarter as many elements.
class LLVMQuarterElementsVectorType<int num>
  : LLVMMatchType<num, IIT_QUARTER_VEC_ARG>;
463+
460464
// Match the type of another intrinsic parameter that is expected to be a
461465
// vector type (i.e. <N x iM>) but with each element subdivided to
462466
// form a vector with more elements that are smaller than the original.
@@ -2605,6 +2609,12 @@ def int_experimental_vector_deinterleave2 : DefaultAttrsIntrinsic<[LLVMHalfEleme
26052609
[llvm_anyvector_ty],
26062610
[IntrNoMem]>;
26072611

2612+
//===-------------- Intrinsics to perform partial reduction ---------------===//
2613+
2614+
// Takes any vector and returns a vector with the same element type and a
// quarter as many elements as the operand.
def int_experimental_vector_partial_reduce_add
    : DefaultAttrsIntrinsic<[LLVMQuarterElementsVectorType<0>],
                            [llvm_anyvector_ty],
                            [IntrNoMem]>;
2617+
26082618
//===----------------- Pointer Authentication Intrinsics ------------------===//
26092619
//
26102620

llvm/lib/IR/Function.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -1240,6 +1240,12 @@ static void DecodeIITType(unsigned &NextElt, ArrayRef<unsigned char> Infos,
12401240
ArgInfo));
12411241
return;
12421242
}
1243+
case IIT_QUARTER_VEC_ARG: {
1244+
unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
1245+
OutputTable.push_back(IITDescriptor::get(IITDescriptor::QuarterVecArgument,
1246+
ArgInfo));
1247+
return;
1248+
}
12431249
case IIT_SAME_VEC_WIDTH_ARG: {
12441250
unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]);
12451251
OutputTable.push_back(IITDescriptor::get(IITDescriptor::SameVecWidthArgument,
@@ -1404,6 +1410,9 @@ static Type *DecodeFixedType(ArrayRef<Intrinsic::IITDescriptor> &Infos,
14041410
case IITDescriptor::HalfVecArgument:
14051411
return VectorType::getHalfElementsVectorType(cast<VectorType>(
14061412
Tys[D.getArgumentNumber()]));
1413+
case IITDescriptor::QuarterVecArgument: {
1414+
return VectorType::getQuarterElementsVectorType(cast<VectorType>(Tys[D.getArgumentNumber()]));
1415+
}
14071416
case IITDescriptor::SameVecWidthArgument: {
14081417
Type *EltTy = DecodeFixedType(Infos, Tys, Context);
14091418
Type *Ty = Tys[D.getArgumentNumber()];
@@ -1619,6 +1628,13 @@ static bool matchIntrinsicType(
16191628
return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
16201629
VectorType::getHalfElementsVectorType(
16211630
cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
1631+
case IITDescriptor::QuarterVecArgument: {
1632+
if (D.getArgumentNumber() >= ArgTys.size())
1633+
return IsDeferredCheck || DeferCheck(Ty);
1634+
return !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
1635+
VectorType::getQuarterElementsVectorType(
1636+
cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
1637+
}
16221638
case IITDescriptor::SameVecWidthArgument: {
16231639
if (D.getArgumentNumber() >= ArgTys.size()) {
16241640
// Defer check and subsequent check for the vector element type.

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

+122
Original file line numberDiff line numberDiff line change
@@ -2203,6 +2203,92 @@ static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) {
22032203
Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
22042204
}
22052205

2206+
/// Append the values forming the partial reduction rooted at \p Instr to
/// \p Chain, in this order: the multiply (operand 0 of \p Instr), the
/// multiply's two zero-extended operands, and operand 1 of \p Instr.
static void getPartialReductionInstrChain(Instruction *Instr,
                                          SmallVector<Value *, 4> &Chain) {
  auto *Product = cast<Instruction>(Instr->getOperand(0));
  // cast<> asserts that both factors of the multiply are zero-extends.
  auto *LHSExtend = cast<ZExtInst>(Product->getOperand(0));
  auto *RHSExtend = cast<ZExtInst>(Product->getOperand(1));

  Value *const Links[] = {Product, LHSExtend, RHSExtend, Instr->getOperand(1)};
  Chain.append(std::begin(Links), std::end(Links));
}
2216+
2217+
2218+
/// @param Instr The root instruction to scan
2219+
static bool isInstrPartialReduction(Instruction *Instr) {
2220+
Value *ExpectedPhi;
2221+
Value *A, *B;
2222+
Value *InductionA, *InductionB;
2223+
2224+
using namespace llvm::PatternMatch;
2225+
auto Pattern = m_Add(
2226+
m_OneUse(m_Mul(
2227+
m_OneUse(m_ZExt(
2228+
m_OneUse(m_Load(
2229+
m_GEP(
2230+
m_Value(A),
2231+
m_Value(InductionA)))))),
2232+
m_OneUse(m_ZExt(
2233+
m_OneUse(m_Load(
2234+
m_GEP(
2235+
m_Value(B),
2236+
m_Value(InductionB))))))
2237+
)), m_Value(ExpectedPhi));
2238+
2239+
bool Matches = match(Instr, Pattern);
2240+
2241+
if(!Matches)
2242+
return false;
2243+
2244+
// Check that the two induction variable uses are to the same induction variable
2245+
if(InductionA != InductionB) {
2246+
LLVM_DEBUG(dbgs() << "Loop uses different induction variables for each input variable, cannot create a partial reduction.\n");
2247+
return false;
2248+
}
2249+
2250+
Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
2251+
Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
2252+
Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
2253+
2254+
// Check that the extends extend to i32
2255+
if(!Ext0->getType()->isIntegerTy(32) || !Ext1->getType()->isIntegerTy(32)) {
2256+
LLVM_DEBUG(dbgs() << "Extends don't extend to the correct width, cannot create a partial reduction.\n");
2257+
return false;
2258+
}
2259+
2260+
// Check that the loads are loading i8
2261+
LoadInst *Load0 = cast<LoadInst>(Ext0->getOperand(0));
2262+
LoadInst *Load1 = cast<LoadInst>(Ext1->getOperand(0));
2263+
if(!Load0->getType()->isIntegerTy(8) || !Load1->getType()->isIntegerTy(8)) {
2264+
LLVM_DEBUG(dbgs() << "Loads don't load the correct width, cannot create a partial reduction\n");
2265+
return false;
2266+
}
2267+
2268+
// Check that the add feeds into ExpectedPhi
2269+
PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
2270+
if(!PhiNode) {
2271+
LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a partial reduction.\n");
2272+
return false;
2273+
}
2274+
2275+
// Check that the first phi value is a zero initializer
2276+
ConstantInt *ZeroInit = dyn_cast<ConstantInt>(PhiNode->getIncomingValue(0));
2277+
if(!ZeroInit || !ZeroInit->isZero()) {
2278+
LLVM_DEBUG(dbgs() << "First PHI value is not a constant zero, cannot create a partial reduction.\n");
2279+
return false;
2280+
}
2281+
2282+
// Check that the second phi value is the instruction we're looking at
2283+
Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
2284+
if(!MaybeAdd || MaybeAdd != Instr) {
2285+
LLVM_DEBUG(dbgs() << "Second PHI value is not the root add, cannot create a partial reduction.\n");
2286+
return false;
2287+
}
2288+
2289+
return true;
2290+
}
2291+
22062292
// Return true if \p OuterLp is an outer loop annotated with hints for explicit
22072293
// vectorization. The loop needs to be annotated with #pragma omp simd
22082294
// simdlen(#) or #pragma clang vectorize(enable) vectorize_width(#). If the
@@ -5084,6 +5170,13 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
50845170
return false;
50855171
}
50865172

5173+
// Prevent epilogue vectorization if a partial reduction is involved
5174+
// TODO Is there a cleaner way to check this?
5175+
if(any_of(Legal->getReductionVars(), [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
5176+
return isInstrPartialReduction(Reduction.second.getLoopExitInstr());
5177+
}))
5178+
return false;
5179+
50875180
// Epilogue vectorization code has not been auditted to ensure it handles
50885181
// non-latch exits properly. It may be fine, but it needs auditted and
50895182
// tested.
@@ -7182,6 +7275,17 @@ void LoopVectorizationCostModel::collectValuesToIgnore() {
71827275
const SmallVectorImpl<Instruction *> &Casts = IndDes.getCastInsts();
71837276
VecValuesToIgnore.insert(Casts.begin(), Casts.end());
71847277
}
7278+
7279+
// Ignore any values that we know will be flattened
7280+
for(auto Reduction : this->Legal->getReductionVars()) {
7281+
auto &Recurrence = Reduction.second;
7282+
if(isInstrPartialReduction(Recurrence.getLoopExitInstr())) {
7283+
SmallVector<Value*, 4> PartialReductionValues;
7284+
getPartialReductionInstrChain(Recurrence.getLoopExitInstr(), PartialReductionValues);
7285+
ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
7286+
VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
7287+
}
7288+
}
71857289
}
71867290

71877291
void LoopVectorizationCostModel::collectInLoopReductions() {
@@ -8536,9 +8640,24 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
85368640
*CI);
85378641
}
85388642

8643+
if(auto *PartialReduce = tryToCreatePartialReduction(Range, Instr, Operands))
8644+
return PartialReduce;
8645+
85398646
return tryToWiden(Instr, Operands, VPBB);
85408647
}
85418648

8649+
/// Try to build a VPPartialReductionRecipe for \p Instr using \p Operands.
/// Returns nullptr when \p Instr is not the root of a partial reduction or
/// when \p Range does not contain the scalable VF of 16.
VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
    VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
  if (!isInstrPartialReduction(Instr))
    return nullptr;

  // The recipe is only created when this specific VF is part of the range.
  const ElementCount RequiredVF = ElementCount::getScalable(16);
  if (!is_contained(Range, RequiredVF))
    return nullptr;

  return new VPPartialReductionRecipe(
      *Instr, make_range(Operands.begin(), Operands.end()));
}
8660+
85428661
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
85438662
ElementCount MaxVF) {
85448663
assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -8746,6 +8865,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
87468865
VPBB->appendRecipe(Recipe);
87478866
}
87488867

8868+
for(auto &Recipe : *VPBB)
8869+
Recipe.postInsertionOp();
8870+
87498871
VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB);
87508872
VPBB = cast<VPBasicBlock>(VPBB->getSingleSuccessor());
87518873
}

llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h

+2
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ class VPRecipeBuilder {
116116
ArrayRef<VPValue *> Operands,
117117
VFRange &Range, VPBasicBlock *VPBB);
118118

119+
VPRecipeBase* tryToCreatePartialReduction(VFRange &Range, Instruction* Instr, ArrayRef<VPValue*> Operands);
120+
119121
/// Set the recipe created for given ingredient.
120122
void setRecipe(Instruction *I, VPRecipeBase *R) {
121123
assert(!Ingredient2Recipe.contains(I) &&

llvm/lib/Transforms/Vectorize/VPlan.h

+40-3
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,8 @@ class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock>,
767767
/// \returns an iterator pointing to the element after the erased one
768768
iplist<VPRecipeBase>::iterator eraseFromParent();
769769

770+
/// Hook that recipes may override; the default implementation does nothing.
/// NOTE(review): appears to be invoked on every recipe of a block once all of
/// the block's recipes have been appended -- confirm against the callers.
virtual void postInsertionOp() {}
771+
770772
/// Method to support type inquiry through isa, cast, and dyn_cast.
771773
static inline bool classof(const VPDef *D) {
772774
// All VPDefs are also VPRecipeBases.
@@ -1881,14 +1883,19 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
18811883
/// The phi is part of an ordered reduction. Requires IsInLoop to be true.
18821884
bool IsOrdered;
18831885

1886+
/// The amount that the VF should be divided by during ::execute
1887+
unsigned VFScaleFactor = 1;
1888+
18841889
public:
1890+
18851891
/// Create a new VPReductionPHIRecipe for the reduction \p Phi described by \p
18861892
/// RdxDesc.
18871893
VPReductionPHIRecipe(PHINode *Phi, const RecurrenceDescriptor &RdxDesc,
18881894
VPValue &Start, bool IsInLoop = false,
1889-
bool IsOrdered = false)
1895+
bool IsOrdered = false, unsigned VFScaleFactor = 1)
18901896
: VPHeaderPHIRecipe(VPDef::VPReductionPHISC, Phi, &Start),
1891-
RdxDesc(RdxDesc), IsInLoop(IsInLoop), IsOrdered(IsOrdered) {
1897+
RdxDesc(RdxDesc), IsInLoop(IsInLoop), IsOrdered(IsOrdered),
1898+
VFScaleFactor(VFScaleFactor) {
18921899
assert((!IsOrdered || IsInLoop) && "IsOrdered requires IsInLoop");
18931900
}
18941901

@@ -1897,7 +1904,7 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
18971904
VPReductionPHIRecipe *clone() override {
18981905
auto *R =
18991906
new VPReductionPHIRecipe(cast<PHINode>(getUnderlyingInstr()), RdxDesc,
1900-
*getOperand(0), IsInLoop, IsOrdered);
1907+
*getOperand(0), IsInLoop, IsOrdered, VFScaleFactor);
19011908
R->addOperand(getBackedgeValue());
19021909
return R;
19031910
}
@@ -1908,6 +1915,10 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
19081915
return R->getVPDefID() == VPDef::VPReductionPHISC;
19091916
}
19101917

1918+
/// Set the amount that the VF should be divided by during ::execute.
void SetVFScaleFactor(unsigned ScaleFactor) { VFScaleFactor = ScaleFactor; }
1921+
19111922
/// Generate the phi/select nodes.
19121923
void execute(VPTransformState &State) override;
19131924

@@ -1928,6 +1939,32 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe {
19281939
bool isInLoop() const { return IsInLoop; }
19291940
};
19301941

1942+
class VPPartialReductionRecipe : public VPRecipeWithIRFlags {
1943+
unsigned Opcode;
1944+
public:
1945+
template <typename IterT>
1946+
VPPartialReductionRecipe(Instruction &I,
1947+
iterator_range<IterT> Operands) : VPRecipeWithIRFlags(
1948+
VPDef::VPPartialReductionSC, Operands, I), Opcode(I.getOpcode())
1949+
{}
1950+
~VPPartialReductionRecipe() override = default;
1951+
VPPartialReductionRecipe *clone() override {
1952+
auto *R = new VPPartialReductionRecipe(*getUnderlyingInstr(), operands());
1953+
R->transferFlags(*this);
1954+
return R;
1955+
}
1956+
VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)
1957+
/// Generate the reduction in the loop
1958+
void execute(VPTransformState &State) override;
1959+
void postInsertionOp() override;
1960+
unsigned getOpcode() { return Opcode; }
1961+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
1962+
/// Print the recipe.
1963+
void print(raw_ostream &O, const Twine &Indent,
1964+
VPSlotTracker &SlotTracker) const override;
1965+
#endif
1966+
};
1967+
19311968
/// A recipe for vectorizing a phi-node as a sequence of mask-based select
19321969
/// instructions.
19331970
class VPBlendRecipe : public VPSingleDefRecipe {

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,10 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
208208
llvm_unreachable("Unhandled opcode");
209209
}
210210

211+
/// The scalar type of a partial reduction recipe is the type of the
/// instruction it was created from.
Type *VPTypeAnalysis::inferScalarTypeForRecipe(
    const VPPartialReductionRecipe *R) {
  const Instruction *Underlying = R->getUnderlyingInstr();
  return Underlying->getType();
}
214+
211215
Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
212216
if (Type *CachedTy = CachedTypes.lookup(V))
213217
return CachedTy;
@@ -238,7 +242,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
238242
return inferScalarType(R->getOperand(0));
239243
})
240244
.Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe,
241-
VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe>(
245+
VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe, VPPartialReductionRecipe>(
242246
[this](const auto *R) { return inferScalarTypeForRecipe(R); })
243247
.Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
244248
// TODO: Use info from interleave group.

llvm/lib/Transforms/Vectorize/VPlanAnalysis.h

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class VPWidenIntOrFpInductionRecipe;
2323
class VPWidenMemoryRecipe;
2424
struct VPWidenSelectRecipe;
2525
class VPReplicateRecipe;
26+
class VPPartialReductionRecipe;
2627
class Type;
2728

2829
/// An analysis for type-inference for VPValues.
@@ -49,6 +50,7 @@ class VPTypeAnalysis {
4950
Type *inferScalarTypeForRecipe(const VPWidenMemoryRecipe *R);
5051
Type *inferScalarTypeForRecipe(const VPWidenSelectRecipe *R);
5152
Type *inferScalarTypeForRecipe(const VPReplicateRecipe *R);
53+
Type *inferScalarTypeForRecipe(const VPPartialReductionRecipe *R);
5254

5355
public:
5456
VPTypeAnalysis(Type *CanonicalIVTy, LLVMContext &Ctx)

0 commit comments

Comments
 (0)