@@ -2203,6 +2203,92 @@ static bool useActiveLaneMaskForControlFlow(TailFoldingStyle Style) {
2203
2203
Style == TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck;
2204
2204
}
2205
2205
2206
+ static void getPartialReductionInstrChain (Instruction *Instr, SmallVector<Value*, 4 > &Chain) {
2207
+ Instruction *Mul = cast<Instruction>(Instr->getOperand (0 ));
2208
+ Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand (0 ));
2209
+ Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand (1 ));
2210
+
2211
+ Chain.push_back (Mul);
2212
+ Chain.push_back (Ext0);
2213
+ Chain.push_back (Ext1);
2214
+ Chain.push_back (Instr->getOperand (1 ));
2215
+ }
2216
+
2217
+
2218
+ // / @param Instr The root instruction to scan
2219
+ static bool isInstrPartialReduction (Instruction *Instr) {
2220
+ Value *ExpectedPhi;
2221
+ Value *A, *B;
2222
+ Value *InductionA, *InductionB;
2223
+
2224
+ using namespace llvm ::PatternMatch;
2225
+ auto Pattern = m_Add (
2226
+ m_OneUse (m_Mul (
2227
+ m_OneUse (m_ZExt (
2228
+ m_OneUse (m_Load (
2229
+ m_GEP (
2230
+ m_Value (A),
2231
+ m_Value (InductionA)))))),
2232
+ m_OneUse (m_ZExt (
2233
+ m_OneUse (m_Load (
2234
+ m_GEP (
2235
+ m_Value (B),
2236
+ m_Value (InductionB))))))
2237
+ )), m_Value (ExpectedPhi));
2238
+
2239
+ bool Matches = match (Instr, Pattern);
2240
+
2241
+ if (!Matches)
2242
+ return false ;
2243
+
2244
+ // Check that the two induction variable uses are to the same induction variable
2245
+ if (InductionA != InductionB) {
2246
+ LLVM_DEBUG (dbgs () << " Loop uses different induction variables for each input variable, cannot create a partial reduction.\n " );
2247
+ return false ;
2248
+ }
2249
+
2250
+ Instruction *Mul = cast<Instruction>(Instr->getOperand (0 ));
2251
+ Instruction *Ext0 = cast<ZExtInst>(Mul->getOperand (0 ));
2252
+ Instruction *Ext1 = cast<ZExtInst>(Mul->getOperand (1 ));
2253
+
2254
+ // Check that the extends extend to i32
2255
+ if (!Ext0->getType ()->isIntegerTy (32 ) || !Ext1->getType ()->isIntegerTy (32 )) {
2256
+ LLVM_DEBUG (dbgs () << " Extends don't extend to the correct width, cannot create a partial reduction.\n " );
2257
+ return false ;
2258
+ }
2259
+
2260
+ // Check that the loads are loading i8
2261
+ LoadInst *Load0 = cast<LoadInst>(Ext0->getOperand (0 ));
2262
+ LoadInst *Load1 = cast<LoadInst>(Ext1->getOperand (0 ));
2263
+ if (!Load0->getType ()->isIntegerTy (8 ) || !Load1->getType ()->isIntegerTy (8 )) {
2264
+ LLVM_DEBUG (dbgs () << " Loads don't load the correct width, cannot create a partial reduction\n " );
2265
+ return false ;
2266
+ }
2267
+
2268
+ // Check that the add feeds into ExpectedPhi
2269
+ PHINode *PhiNode = dyn_cast<PHINode>(ExpectedPhi);
2270
+ if (!PhiNode) {
2271
+ LLVM_DEBUG (dbgs () << " Expected Phi node was not a phi, cannot create a partial reduction.\n " );
2272
+ return false ;
2273
+ }
2274
+
2275
+ // Check that the first phi value is a zero initializer
2276
+ ConstantInt *ZeroInit = dyn_cast<ConstantInt>(PhiNode->getIncomingValue (0 ));
2277
+ if (!ZeroInit || !ZeroInit->isZero ()) {
2278
+ LLVM_DEBUG (dbgs () << " First PHI value is not a constant zero, cannot create a partial reduction.\n " );
2279
+ return false ;
2280
+ }
2281
+
2282
+ // Check that the second phi value is the instruction we're looking at
2283
+ Instruction *MaybeAdd = dyn_cast<Instruction>(PhiNode->getIncomingValue (1 ));
2284
+ if (!MaybeAdd || MaybeAdd != Instr) {
2285
+ LLVM_DEBUG (dbgs () << " Second PHI value is not the root add, cannot create a partial reduction.\n " );
2286
+ return false ;
2287
+ }
2288
+
2289
+ return true ;
2290
+ }
2291
+
2206
2292
// Return true if \p OuterLp is an outer loop annotated with hints for explicit
2207
2293
// vectorization. The loop needs to be annotated with #pragma omp simd
2208
2294
// simdlen(#) or #pragma clang vectorize(enable) vectorize_width(#). If the
@@ -5084,6 +5170,13 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
5084
5170
return false ;
5085
5171
}
5086
5172
5173
+ // Prevent epilogue vectorization if a partial reduction is involved
5174
+ // TODO Is there a cleaner way to check this?
5175
+ if (any_of (Legal->getReductionVars (), [&](const std::pair<PHINode *, RecurrenceDescriptor> &Reduction) {
5176
+ return isInstrPartialReduction (Reduction.second .getLoopExitInstr ());
5177
+ }))
5178
+ return false ;
5179
+
5087
5180
// Epilogue vectorization code has not been audited to ensure it handles
5088
5181
// non-latch exits properly. It may be fine, but it needs to be audited and
5089
5182
// tested.
@@ -7182,6 +7275,17 @@ void LoopVectorizationCostModel::collectValuesToIgnore() {
7182
7275
const SmallVectorImpl<Instruction *> &Casts = IndDes.getCastInsts ();
7183
7276
VecValuesToIgnore.insert (Casts.begin (), Casts.end ());
7184
7277
}
7278
+
7279
+ // Ignore any values that we know will be flattened
7280
+ for (auto Reduction : this ->Legal ->getReductionVars ()) {
7281
+ auto &Recurrence = Reduction.second ;
7282
+ if (isInstrPartialReduction (Recurrence.getLoopExitInstr ())) {
7283
+ SmallVector<Value*, 4 > PartialReductionValues;
7284
+ getPartialReductionInstrChain (Recurrence.getLoopExitInstr (), PartialReductionValues);
7285
+ ValuesToIgnore.insert (PartialReductionValues.begin (), PartialReductionValues.end ());
7286
+ VecValuesToIgnore.insert (PartialReductionValues.begin (), PartialReductionValues.end ());
7287
+ }
7288
+ }
7185
7289
}
7186
7290
7187
7291
void LoopVectorizationCostModel::collectInLoopReductions () {
@@ -8536,9 +8640,24 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
8536
8640
*CI);
8537
8641
}
8538
8642
8643
+ if (auto *PartialReduce = tryToCreatePartialReduction (Range, Instr, Operands))
8644
+ return PartialReduce;
8645
+
8539
8646
return tryToWiden (Instr, Operands, VPBB);
8540
8647
}
8541
8648
8649
+ VPRecipeBase *VPRecipeBuilder::tryToCreatePartialReduction (
8650
+ VFRange &Range, Instruction *Instr, ArrayRef<VPValue *> Operands) {
8651
+
8652
+ if (isInstrPartialReduction (Instr)) {
8653
+ auto EC = ElementCount::getScalable (16 );
8654
+ if (std::find (Range.begin (), Range.end (), EC) == Range.end ())
8655
+ return nullptr ;
8656
+ return new VPPartialReductionRecipe (*Instr, make_range (Operands.begin (), Operands.end ()));
8657
+ }
8658
+ return nullptr ;
8659
+ }
8660
+
8542
8661
void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
8543
8662
ElementCount MaxVF) {
8544
8663
assert (OrigLoop->isInnermost () && " Inner loop expected." );
@@ -8746,6 +8865,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
8746
8865
VPBB->appendRecipe (Recipe);
8747
8866
}
8748
8867
8868
+ for (auto &Recipe : *VPBB)
8869
+ Recipe.postInsertionOp ();
8870
+
8749
8871
VPBlockUtils::insertBlockAfter (new VPBasicBlock (), VPBB);
8750
8872
VPBB = cast<VPBasicBlock>(VPBB->getSingleSuccessor ());
8751
8873
}
0 commit comments