diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 2cac5557daeee..dbb39c8e14fbc 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2443,15 +2443,23 @@ static VPRecipeBase *optimizeMaskToEVL(VPValue *HeaderMask,
                                        VPRecipeBase &CurRecipe,
                                        VPTypeAnalysis &TypeInfo,
                                        VPValue &AllOneMask, VPValue &EVL) {
-  // FIXME: Don't transform recipes to EVL recipes if they're not masked by the
-  // header mask.
+  // Derive a new mask by removing the header mask. Return nullptr if a new mask
+  // cannot be derived because the original mask does not contain the header
+  // mask.
   auto GetNewMask = [&](VPValue *OrigMask) -> VPValue * {
     assert(OrigMask && "Unmasked recipe when folding tail");
     // HeaderMask will be handled using EVL.
+    if (HeaderMask == OrigMask)
+      return &AllOneMask;
     VPValue *Mask;
     if (match(OrigMask, m_LogicalAnd(m_Specific(HeaderMask), m_VPValue(Mask))))
       return Mask;
-    return HeaderMask == OrigMask ? nullptr : OrigMask;
+    return nullptr;
+  };
+
+  // TODO: Can be simplified by SimplifyRecipes.
+  auto OptimizeAllTrueMask = [](VPValue *Mask) -> VPValue * {
+    return match(Mask, m_True()) ? nullptr : Mask;
   };
 
   /// Adjust any end pointers so that they point to the end of EVL lanes not VF.
@@ -2474,37 +2482,51 @@ static VPRecipeBase *optimizeMaskToEVL(VPValue *HeaderMask,
   };
 
   return TypeSwitch<VPRecipeBase *, VPRecipeBase *>(&CurRecipe)
-      .Case<VPWidenLoadRecipe>([&](VPWidenLoadRecipe *L) {
+      .Case<VPWidenLoadRecipe>([&](VPWidenLoadRecipe *L) -> VPRecipeBase * {
         VPValue *NewMask = GetNewMask(L->getMask());
+        if (!NewMask)
+          return nullptr;
         VPValue *NewAddr = GetNewAddr(L->getAddr());
-        return new VPWidenLoadEVLRecipe(*L, NewAddr, EVL, NewMask);
+        return new VPWidenLoadEVLRecipe(*L, NewAddr, EVL,
+                                        OptimizeAllTrueMask(NewMask));
       })
-      .Case<VPWidenStoreRecipe>([&](VPWidenStoreRecipe *S) {
+      .Case<VPWidenStoreRecipe>([&](VPWidenStoreRecipe *S) -> VPRecipeBase * {
         VPValue *NewMask = GetNewMask(S->getMask());
+        if (!NewMask)
+          return nullptr;
         VPValue *NewAddr = GetNewAddr(S->getAddr());
-        return new VPWidenStoreEVLRecipe(*S, NewAddr, EVL, NewMask);
+        return new VPWidenStoreEVLRecipe(*S, NewAddr, EVL,
+                                         OptimizeAllTrueMask(NewMask));
       })
-      .Case<VPInterleaveRecipe>([&](VPInterleaveRecipe *IR) {
+      .Case<VPInterleaveRecipe>([&](VPInterleaveRecipe *IR) -> VPRecipeBase * {
         VPValue *NewMask = GetNewMask(IR->getMask());
-        return new VPInterleaveEVLRecipe(*IR, EVL, NewMask);
+        if (!NewMask)
+          return nullptr;
+        return new VPInterleaveEVLRecipe(*IR, EVL,
+                                         OptimizeAllTrueMask(NewMask));
       })
-      .Case<VPReductionRecipe>([&](VPReductionRecipe *Red) {
+      .Case<VPReductionRecipe>([&](VPReductionRecipe *Red) -> VPRecipeBase * {
         VPValue *NewMask = GetNewMask(Red->getCondOp());
-        return new VPReductionEVLRecipe(*Red, EVL, NewMask);
+        if (!NewMask)
+          return nullptr;
+        return new VPReductionEVLRecipe(*Red, EVL,
+                                        OptimizeAllTrueMask(NewMask));
       })
       .Case<VPInstruction>([&](VPInstruction *VPI) -> VPRecipeBase * {
-        VPValue *LHS, *RHS;
+        VPValue *Cond, *LHS, *RHS;
         // Transform select with a header mask condition
-        //   select(header_mask, LHS, RHS)
+        //   select(mask_w/_header_mask, LHS, RHS)
         // into vector predication merge.
-        //   vp.merge(all-true, LHS, RHS, EVL)
-        if (!match(VPI, m_Select(m_Specific(HeaderMask), m_VPValue(LHS),
-                                 m_VPValue(RHS))))
+        //   vp.merge(mask_w/o_header_mask, LHS, RHS, EVL)
+        if (!match(VPI,
+                   m_Select(m_VPValue(Cond), m_VPValue(LHS), m_VPValue(RHS))))
+          return nullptr;
+
+        VPValue *NewMask = GetNewMask(Cond);
+        if (!NewMask)
           return nullptr;
-        // Use all true as the condition because this transformation is
-        // limited to selects whose condition is a header mask.
         return new VPWidenIntrinsicRecipe(
-            Intrinsic::vp_merge, {&AllOneMask, LHS, RHS, &EVL},
+            Intrinsic::vp_merge, {NewMask, LHS, RHS, &EVL},
             TypeInfo.inferScalarType(LHS), VPI->getDebugLoc());
       })
       .Default([&](VPRecipeBase *R) { return nullptr; });
diff --git a/llvm/test/Transforms/LoopVectorize/RISCV/divrem.ll b/llvm/test/Transforms/LoopVectorize/RISCV/divrem.ll
index c35358b3eed0f..4e8577564246e 100644
--- a/llvm/test/Transforms/LoopVectorize/RISCV/divrem.ll
+++ b/llvm/test/Transforms/LoopVectorize/RISCV/divrem.ll
@@ -364,14 +364,9 @@ define void @predicated_udiv(ptr noalias nocapture %a, i64 %v, i64 %n) {
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[AVL:%.*]] = phi i64 [ 1024, [[VECTOR_PH]] ], [ [[AVL_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP12:%.*]] = call i32 @llvm.experimental.get.vector.length.i64(i64 [[AVL]], i32 2, i1 true)
-; CHECK-NEXT:    [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <vscale x 2 x i32> poison, i32 [[TMP12]], i64 0
-; CHECK-NEXT:    [[BROADCAST_SPLAT2:%.*]] = shufflevector <vscale x 2 x i32> [[BROADCAST_SPLATINSERT1]], <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP7:%.*]] = call <vscale x 2 x i32> @llvm.stepvector.nxv2i32()
-; CHECK-NEXT:    [[TMP15:%.*]] = icmp ult <vscale x 2 x i32> [[TMP7]], [[BROADCAST_SPLAT2]]
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[A:%.*]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = call <vscale x 2 x i64> @llvm.vp.load.nxv2i64.p0(ptr align 8 [[TMP8]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP12]])
-; CHECK-NEXT:    [[TMP16:%.*]] = select <vscale x 2 x i1> [[TMP15]], <vscale x 2 x i1> [[TMP6]], <vscale x 2 x i1> zeroinitializer
-; CHECK-NEXT:    [[TMP10:%.*]] = select <vscale x 2 x i1> [[TMP16]], <vscale x 2 x i64> [[BROADCAST_SPLAT]], <vscale x 2 x i64> splat (i64 1)
+; CHECK-NEXT:    [[TMP10:%.*]] = call <vscale x 2 x i64> @llvm.vp.merge.nxv2i64(<vscale x 2 x i1> [[TMP6]], <vscale x 2 x i64> [[BROADCAST_SPLAT]], <vscale x 2 x i64> splat (i64 1), i32 [[TMP12]])
 ; CHECK-NEXT:    [[TMP11:%.*]] = udiv <vscale x 2 x i64> [[WIDE_LOAD]], [[TMP10]]
 ; CHECK-NEXT:    [[PREDPHI:%.*]] = select <vscale x 2 x i1> [[TMP6]], <vscale x 2 x i64> [[TMP11]], <vscale x 2 x i64> [[WIDE_LOAD]]
 ; CHECK-NEXT:    call void @llvm.vp.store.nxv2i64.p0(<vscale x 2 x i64> [[PREDPHI]], ptr align 8 [[TMP8]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP12]])
@@ -479,14 +474,9 @@ define void @predicated_sdiv(ptr noalias nocapture %a, i64 %v, i64 %n) {
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[AVL:%.*]] = phi i64 [ 1024, [[VECTOR_PH]] ], [ [[AVL_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP12:%.*]] = call i32 @llvm.experimental.get.vector.length.i64(i64 [[AVL]], i32 2, i1 true)
-; CHECK-NEXT:    [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <vscale x 2 x i32> poison, i32 [[TMP12]], i64 0
-; CHECK-NEXT:    [[BROADCAST_SPLAT2:%.*]] = shufflevector <vscale x 2 x i32> [[BROADCAST_SPLATINSERT1]], <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP7:%.*]] = call <vscale x 2 x i32> @llvm.stepvector.nxv2i32()
-; CHECK-NEXT:    [[TMP15:%.*]] = icmp ult <vscale x 2 x i32> [[TMP7]], [[BROADCAST_SPLAT2]]
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[A:%.*]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = call <vscale x 2 x i64> @llvm.vp.load.nxv2i64.p0(ptr align 8 [[TMP8]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP12]])
-; CHECK-NEXT:    [[TMP16:%.*]] = select <vscale x 2 x i1> [[TMP15]], <vscale x 2 x i1> [[TMP6]], <vscale x 2 x i1> zeroinitializer
-; CHECK-NEXT:    [[TMP10:%.*]] = select <vscale x 2 x i1> [[TMP16]], <vscale x 2 x i64> [[BROADCAST_SPLAT]], <vscale x 2 x i64> splat (i64 1)
+; CHECK-NEXT:    [[TMP10:%.*]] = call <vscale x 2 x i64> @llvm.vp.merge.nxv2i64(<vscale x 2 x i1> [[TMP6]], <vscale x 2 x i64> [[BROADCAST_SPLAT]], <vscale x 2 x i64> splat (i64 1), i32 [[TMP12]])
 ; CHECK-NEXT:    [[TMP11:%.*]] = sdiv <vscale x 2 x i64> [[WIDE_LOAD]], [[TMP10]]
 ; CHECK-NEXT:    [[PREDPHI:%.*]] = select <vscale x 2 x i1> [[TMP6]], <vscale x 2 x i64> [[TMP11]], <vscale x 2 x i64> [[WIDE_LOAD]]
 ; CHECK-NEXT:    call void @llvm.vp.store.nxv2i64.p0(<vscale x 2 x i64> [[PREDPHI]], ptr align 8 [[TMP8]], <vscale x 2 x i1> splat (i1 true), i32 [[TMP12]])
@@ -799,15 +789,10 @@ define void @predicated_sdiv_by_minus_one(ptr noalias nocapture %a, i64 %n) {
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[AVL:%.*]] = phi i64 [ 1024, [[VECTOR_PH]] ], [ [[AVL_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP12:%.*]] = call i32 @llvm.experimental.get.vector.length.i64(i64 [[AVL]], i32 16, i1 true)
-; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 16 x i32> poison, i32 [[TMP12]], i64 0
-; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 16 x i32> [[BROADCAST_SPLATINSERT]], <vscale x 16 x i32> poison, <vscale x 16 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP6:%.*]] = call <vscale x 16 x i32> @llvm.stepvector.nxv16i32()
-; CHECK-NEXT:    [[TMP15:%.*]] = icmp ult <vscale x 16 x i32> [[TMP6]], [[BROADCAST_SPLAT]]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i8, ptr [[A:%.*]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = call <vscale x 16 x i8> @llvm.vp.load.nxv16i8.p0(ptr align 1 [[TMP7]], <vscale x 16 x i1> splat (i1 true), i32 [[TMP12]])
 ; CHECK-NEXT:    [[TMP9:%.*]] = icmp ne <vscale x 16 x i8> [[WIDE_LOAD]], splat (i8 -128)
-; CHECK-NEXT:    [[TMP16:%.*]] = select <vscale x 16 x i1> [[TMP15]], <vscale x 16 x i1> [[TMP9]], <vscale x 16 x i1> zeroinitializer
-; CHECK-NEXT:    [[TMP10:%.*]] = select <vscale x 16 x i1> [[TMP16]], <vscale x 16 x i8> splat (i8 -1), <vscale x 16 x i8> splat (i8 1)
+; CHECK-NEXT:    [[TMP10:%.*]] = call <vscale x 16 x i8> @llvm.vp.merge.nxv16i8(<vscale x 16 x i1> [[TMP9]], <vscale x 16 x i8> splat (i8 -1), <vscale x 16 x i8> splat (i8 1), i32 [[TMP12]])
 ; CHECK-NEXT:    [[TMP11:%.*]] = sdiv <vscale x 16 x i8> [[WIDE_LOAD]], [[TMP10]]
 ; CHECK-NEXT:    [[PREDPHI:%.*]] = select <vscale x 16 x i1> [[TMP9]], <vscale x 16 x i8> [[TMP11]], <vscale x 16 x i8> [[WIDE_LOAD]]
 ; CHECK-NEXT:    call void @llvm.vp.store.nxv16i8.p0(<vscale x 16 x i8> [[PREDPHI]], ptr align 1 [[TMP7]], <vscale x 16 x i1> splat (i1 true), i32 [[TMP12]])
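
Illustrative sketch (not part of the patch): the effect on the predicated-divisor pattern from predicated_udiv above, written as plain LLVM IR with made-up value names (%avl, %c, %v.splat are assumptions standing in for the AVL phi, the branch condition mask, and the broadcast divisor). Once the header-mask portion of the select condition is stripped, the select on the remaining mask becomes a vp.merge, and the explicit step-vector/compare materialization of the header mask disappears because the EVL operand already limits the active lanes.

  ; Before: the header mask is materialized and folded into the select condition.
  %evl        = call i32 @llvm.experimental.get.vector.length.i64(i64 %avl, i32 2, i1 true)
  %evl.ins    = insertelement <vscale x 2 x i32> poison, i32 %evl, i64 0
  %evl.splat  = shufflevector <vscale x 2 x i32> %evl.ins, <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer
  %step       = call <vscale x 2 x i32> @llvm.stepvector.nxv2i32()
  %headermask = icmp ult <vscale x 2 x i32> %step, %evl.splat
  %cond       = select <vscale x 2 x i1> %headermask, <vscale x 2 x i1> %c, <vscale x 2 x i1> zeroinitializer
  %divisor    = select <vscale x 2 x i1> %cond, <vscale x 2 x i64> %v.splat, <vscale x 2 x i64> splat (i64 1)

  ; After: the header mask is implied by the EVL operand of vp.merge.
  %evl        = call i32 @llvm.experimental.get.vector.length.i64(i64 %avl, i32 2, i1 true)
  %divisor    = call <vscale x 2 x i64> @llvm.vp.merge.nxv2i64(<vscale x 2 x i1> %c, <vscale x 2 x i64> %v.splat, <vscale x 2 x i64> splat (i64 1), i32 %evl)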