@@ -2203,6 +2203,92 @@ static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) {
2203
2203
Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
2204
2204
}
2205
2205
2206
+ static void getPartialReductionInstrChain(Instruction *Instr, SmallVector<Value*, 4> &Chain) {
2207
+ Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
2208
+ Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
2209
+ Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
2210
+
2211
+ Chain.push_back(Mul);
2212
+ Chain.push_back(Ext0);
2213
+ Chain.push_back(Ext1);
2214
+ Chain.push_back(Instr->getOperand(1));
2215
+ }
2216
+
2217
+
2218
+ /// @param Instr The root instruction to scan
2219
+ static bool isInstrPartialReduction(Instruction *Instr) {
2220
+ Value *ExpectedPhi;
2221
+ Value *A, *B;
2222
+ Value *InductionA, *InductionB;
2223
+
2224
+ using namespace llvm::PatternMatch;
2225
+ auto Pattern = m_Add(
2226
+ m_OneUse(m_Mul(
2227
+ m_OneUse(m_ZExt(
2228
+ m_OneUse(m_Load(
2229
+ m_GEP(
2230
+ m_Value(A),
2231
+ m_Value(InductionA)))))),
2232
+ m_OneUse(m_ZExt(
2233
+ m_OneUse(m_Load(
2234
+ m_GEP(
2235
+ m_Value(B),
2236
+ m_Value(InductionB))))))
2237
+ )), m_Value(ExpectedPhi));
2238
+
2239
+ bool Matches = match(Instr, Pattern);
2240
+
2241
+ if(!Matches)
2242
+ return false;
2243
+
2244
+ // Check that the two induction variable uses are to the same induction variable
2245
+ if(InductionA != InductionB) {
2246
+ LLVM_DEBUG(dbgs() << "Loop uses different induction variables for each input variable, cannot create a partial reduction.\n");
2247
+ return false;
2248
+ }
2249
+
2250
+ Instruction *Mul = cast<Instruction>(Instr->getOperand(0));
2251
+ Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand(0));
2252
+ Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand(1));
2253
+
2254
+ // Check that the extends extend to i32
2255
+ if(!Ext0->getType()->isIntegerTy(32) || !Ext1->getType()->isIntegerTy(32)) {
2256
+ LLVM_DEBUG(dbgs() << "Extends don't extend to the correct width, cannot create a partial reduction.\n");
2257
+ return false;
2258
+ }
2259
+
2260
+ // Check that the loads are loading i8
2261
+ LoadInst *Load0 = cast<LoadInst>(Ext0->getOperand(0));
2262
+ LoadInst *Load1 = cast<LoadInst>(Ext1->getOperand(0));
2263
+ if(!Load0->getType()->isIntegerTy(8) || !Load1->getType()->isIntegerTy(8)) {
2264
+ LLVM_DEBUG(dbgs() << "Loads don't load the correct width, cannot create a partial reduction\n");
2265
+ return false;
2266
+ }
2267
+
2268
+ // Check that the add feeds into ExpectedPhi
2269
+ PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
2270
+ if(!PhiNode) {
2271
+ LLVM_DEBUG(dbgs() << "Expected Phi node was not a phi, cannot create a partial reduction.\n");
2272
+ return false;
2273
+ }
2274
+
2275
+ // Check that the first phi value is a zero initializer
2276
+ ConstantInt *ZeroInit = dyn_cast<ConstantInt>(PhiNode->getIncomingValue(0));
2277
+ if(!ZeroInit || !ZeroInit->isZero()) {
2278
+ LLVM_DEBUG(dbgs() << "First PHI value is not a constant zero, cannot create a partial reduction.\n");
2279
+ return false;
2280
+ }
2281
+
2282
+ // Check that the second phi value is the instruction we're looking at
2283
+ Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue(1));
2284
+ if(!MaybeAdd || MaybeAdd != Instr) {
2285
+ LLVM_DEBUG(dbgs() << "Second PHI value is not the root add, cannot create a partial reduction.\n");
2286
+ return false;
2287
+ }
2288
+
2289
+ return true;
2290
+ }
2291
+
2206
2292
// Return true if \p OuterLp is an outer loop annotated with hints for explicit
2207
2293
// vectorization. The loop needs to be annotated with #pragma omp simd
2208
2294
// simdlen(#) or #pragma clang vectorize(enable) vectorize_width(#). If the
@@ -5084,6 +5170,13 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
5084
5170
return false;
5085
5171
}
5086
5172
5173
+ // Prevent epilogue vectorization if a partial reduction is involved
5174
+ // TODO Is there a cleaner way to check this?
5175
+ if(any_of(Legal->getReductionVars(), [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
5176
+ return isInstrPartialReduction(Reduction.second.getLoopExitInstr());
5177
+ }))
5178
+ return false;
5179
+
5087
5180
// Epilogue vectorization code has not been auditted to ensure it handles
5088
5181
// non-latch exits properly. It may be fine, but it needs auditted and
5089
5182
// tested.
@@ -7182,6 +7275,17 @@ void LoopVectorizationCostModel::collectValuesToIgnore() {
7182
7275
const SmallVectorImpl<Instruction *> &Casts = IndDes.getCastInsts();
7183
7276
VecValuesToIgnore.insert(Casts.begin(), Casts.end());
7184
7277
}
7278
+
7279
+ // Ignore any values that we know will be flattened
7280
+ for(auto Reduction : this->Legal->getReductionVars()) {
7281
+ auto &Recurrence = Reduction.second;
7282
+ if(isInstrPartialReduction(Recurrence.getLoopExitInstr())) {
7283
+ SmallVector<Value*, 4> PartialReductionValues;
7284
+ getPartialReductionInstrChain(Recurrence.getLoopExitInstr(), PartialReductionValues);
7285
+ ValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
7286
+ VecValuesToIgnore.insert(PartialReductionValues.begin(), PartialReductionValues.end());
7287
+ }
7288
+ }
7185
7289
}
7186
7290
7187
7291
void LoopVectorizationCostModel::collectInLoopReductions() {
@@ -8536,9 +8640,24 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8536
8640
*CI);
8537
8641
}
8538
8642
8643
+ if(auto *PartialReduce = tryToCreatePartialReduction(Range, Instr, Operands))
8644
+ return PartialReduce;
8645
+
8539
8646
return tryToWiden(Instr, Operands, VPBB);
8540
8647
}
8541
8648
8649
+ VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction(
8650
+ VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
8651
+
8652
+ if(isInstrPartialReduction(Instr)) {
8653
+ auto EC = ElementCount::getScalable(16);
8654
+ if(std::find(Range.begin(), Range.end(), EC) == Range.end())
8655
+ return nullptr;
8656
+ return new VPPartialReductionRecipe(*Instr, make_range(Operands.begin(), Operands.end()));
8657
+ }
8658
+ return nullptr;
8659
+ }
8660
+
8542
8661
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
8543
8662
ElementCount MaxVF) {
8544
8663
assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -8746,6 +8865,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
8746
8865
VPBB->appendRecipe(Recipe);
8747
8866
}
8748
8867
8868
+ for(auto &Recipe : *VPBB)
8869
+ Recipe.postInsertionOp();
8870
+
8749
8871
VPBlockUtils::insertBlockAfter(new VPBasicBlock(), VPBB);
8750
8872
VPBB = cast<VPBasicBlock>(VPBB->getSingleSuccessor());
8751
8873
}
0 commit comments