diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index 70f541d64b305..d17c64a778e86 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -278,6 +278,20 @@ class VPBuilder {
         new VPInstructionWithType(Opcode, Op, ResultTy, {}, DL));
   }
 
+  /// Cast scalar \p Op from \p SrcTy to \p ResultTy. Returns \p Op unchanged
+  /// when both types are identical; otherwise emits a scalar Trunc if
+  /// \p ResultTy is narrower than \p SrcTy, and a scalar ZExt if not.
+  VPValue *createScalarZExtOrTrunc(VPValue *Op, Type *ResultTy, Type *SrcTy,
+                                   DebugLoc DL) {
+    if (ResultTy == SrcTy)
+      return Op;
+    Instruction::CastOps CastOp =
+        ResultTy->getScalarSizeInBits() < SrcTy->getScalarSizeInBits()
+            ? Instruction::Trunc
+            : Instruction::ZExt;
+    return createScalarCast(CastOp, Op, ResultTy, DL);
+  }
+
   VPWidenCastRecipe *createWidenCast(Instruction::CastOps Opcode, VPValue *Op,
                                      Type *ResultTy) {
     return tryInsertInstruction(new VPWidenCastRecipe(Opcode, Op, ResultTy));
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 11f0f2a930329..351f6ccebaeeb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -797,15 +797,8 @@ static VPValue *optimizeEarlyExitInductionUser(VPlan &Plan,
   VPValue *FirstActiveLane =
       B.createNaryOp(VPInstruction::FirstActiveLane, Mask, DL);
   Type *FirstActiveLaneType = TypeInfo.inferScalarType(FirstActiveLane);
-  if (CanonicalIVType != FirstActiveLaneType) {
-    Instruction::CastOps CastOp =
-        CanonicalIVType->getScalarSizeInBits() <
-                FirstActiveLaneType->getScalarSizeInBits()
-            ? Instruction::Trunc
-            : Instruction::ZExt;
-    FirstActiveLane =
-        B.createScalarCast(CastOp, FirstActiveLane, CanonicalIVType, DL);
-  }
+  FirstActiveLane = B.createScalarZExtOrTrunc(FirstActiveLane, CanonicalIVType,
+                                              FirstActiveLaneType, DL);
   EndValue = B.createNaryOp(Instruction::Add, {EndValue, FirstActiveLane}, DL);
 
   // `getOptimizableIVOf()` always returns the pre-incremented IV, so if it
@@ -2182,13 +2175,10 @@ static void transformRecipestoEVLRecipes(VPlan &Plan, VPValue &EVL) {
     VPValue *MaxEVL = &Plan.getVF();
     // Emit VPScalarCastRecipe in preheader if VF is not a 32 bits integer.
     VPBuilder Builder(LoopRegion->getPreheaderVPBB());
-    if (unsigned VFSize =
-            TypeInfo.inferScalarType(MaxEVL)->getScalarSizeInBits();
-        VFSize != 32) {
-      MaxEVL = Builder.createScalarCast(
-          VFSize > 32 ? Instruction::Trunc : Instruction::ZExt, MaxEVL,
-          Type::getInt32Ty(Ctx), DebugLoc());
-    }
+    MaxEVL = Builder.createScalarZExtOrTrunc(MaxEVL, Type::getInt32Ty(Ctx),
+                                             TypeInfo.inferScalarType(MaxEVL),
+                                             DebugLoc());
+
     Builder.setInsertPoint(Header, Header->getFirstNonPhi());
     PrevEVL = Builder.createScalarPhi({MaxEVL, &EVL}, DebugLoc(), "prev.evl");
   }
@@ -2286,6 +2276,7 @@ bool VPlanTransforms::tryAddExplicitVectorLength(
     return false;
 
   auto *CanonicalIVPHI = Plan.getCanonicalIV();
+  auto *CanIVTy = CanonicalIVPHI->getScalarType();
   VPValue *StartV = CanonicalIVPHI->getStartValue();
 
   // Create the ExplicitVectorLengthPhi recipe in the main loop.
@@ -2297,8 +2288,8 @@ bool VPlanTransforms::tryAddExplicitVectorLength(
       Instruction::Sub, {Plan.getTripCount(), EVLPhi}, DebugLoc(), "avl");
   if (MaxSafeElements) {
     // Support for MaxSafeDist for correct loop emission.
-    VPValue *AVLSafe = Plan.getOrAddLiveIn(
-        ConstantInt::get(CanonicalIVPHI->getScalarType(), *MaxSafeElements));
+    VPValue *AVLSafe =
+        Plan.getOrAddLiveIn(ConstantInt::get(CanIVTy, *MaxSafeElements));
     VPValue *Cmp = Builder.createICmp(ICmpInst::ICMP_ULT, AVL, AVLSafe);
     AVL = Builder.createSelect(Cmp, AVL, AVLSafe, DebugLoc(), "safe_avl");
   }
@@ -2308,13 +2299,12 @@ bool VPlanTransforms::tryAddExplicitVectorLength(
   auto *CanonicalIVIncrement =
       cast<VPInstruction>(CanonicalIVPHI->getBackedgeValue());
   Builder.setInsertPoint(CanonicalIVIncrement);
-  VPSingleDefRecipe *OpVPEVL = VPEVL;
-  if (unsigned IVSize = CanonicalIVPHI->getScalarType()->getScalarSizeInBits();
-      IVSize != 32) {
-    OpVPEVL = Builder.createScalarCast(
-        IVSize < 32 ? Instruction::Trunc : Instruction::ZExt, OpVPEVL,
-        CanonicalIVPHI->getScalarType(), CanonicalIVIncrement->getDebugLoc());
-  }
+  VPValue *OpVPEVL = VPEVL;
+
+  auto *I32Ty = Type::getInt32Ty(CanIVTy->getContext());
+  OpVPEVL = Builder.createScalarZExtOrTrunc(
+      OpVPEVL, CanIVTy, I32Ty, CanonicalIVIncrement->getDebugLoc());
+
   auto *NextEVLIV = Builder.createOverflowingOp(
       Instruction::Add, {OpVPEVL, EVLPhi},
       {CanonicalIVIncrement->hasNoUnsignedWrap(),