diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp index 89cc7ea15ec1d..f7d8771648227 100644 --- a/llvm/lib/Analysis/LazyValueInfo.cpp +++ b/llvm/lib/Analysis/LazyValueInfo.cpp @@ -434,6 +434,28 @@ class LazyValueInfoImpl { void solve(); + // For the following methods, if UseBlockValue is true, the function may + // push additional values to the worklist and return nullopt. If + // UseBlockValue is false, it will never return nullopt. + + std::optional + getValueFromSimpleICmpCondition(CmpInst::Predicate Pred, Value *RHS, + const APInt &Offset, Instruction *CxtI, + bool UseBlockValue); + + std::optional + getValueFromICmpCondition(Value *Val, ICmpInst *ICI, bool isTrueDest, + bool UseBlockValue); + + std::optional + getValueFromCondition(Value *Val, Value *Cond, bool IsTrueDest, + bool UseBlockValue, unsigned Depth = 0); + + std::optional getEdgeValueLocal(Value *Val, + BasicBlock *BBFrom, + BasicBlock *BBTo, + bool UseBlockValue); + public: /// This is the query interface to determine the lattice value for the /// specified Value* at the context instruction (if specified) or at the @@ -755,14 +777,10 @@ LazyValueInfoImpl::solveBlockValuePHINode(PHINode *PN, BasicBlock *BB) { return Result; } -static ValueLatticeElement getValueFromCondition(Value *Val, Value *Cond, - bool isTrueDest = true, - unsigned Depth = 0); - // If we can determine a constraint on the value given conditions assumed by // the program, intersect those constraints with BBLV void LazyValueInfoImpl::intersectAssumeOrGuardBlockValueConstantRange( - Value *Val, ValueLatticeElement &BBLV, Instruction *BBI) { + Value *Val, ValueLatticeElement &BBLV, Instruction *BBI) { BBI = BBI ? BBI : dyn_cast(Val); if (!BBI) return; @@ -779,17 +797,21 @@ void LazyValueInfoImpl::intersectAssumeOrGuardBlockValueConstantRange( if (I->getParent() != BB || !isValidAssumeForContext(I, BBI)) continue; - BBLV = intersect(BBLV, getValueFromCondition(Val, I->getArgOperand(0))); + BBLV = intersect(BBLV, *getValueFromCondition(Val, I->getArgOperand(0), + /*IsTrueDest*/ true, + /*UseBlockValue*/ false)); } // If guards are not used in the module, don't spend time looking for them if (GuardDecl && !GuardDecl->use_empty() && BBI->getIterator() != BB->begin()) { - for (Instruction &I : make_range(std::next(BBI->getIterator().getReverse()), - BB->rend())) { + for (Instruction &I : + make_range(std::next(BBI->getIterator().getReverse()), BB->rend())) { Value *Cond = nullptr; if (match(&I, m_Intrinsic(m_Value(Cond)))) - BBLV = intersect(BBLV, getValueFromCondition(Val, Cond)); + BBLV = intersect(BBLV, + *getValueFromCondition(Val, Cond, /*IsTrueDest*/ true, + /*UseBlockValue*/ false)); } } @@ -886,10 +908,14 @@ LazyValueInfoImpl::solveBlockValueSelect(SelectInst *SI, BasicBlock *BB) { // If the value is undef, a different value may be chosen in // the select condition. if (isGuaranteedNotToBeUndef(Cond, AC)) { - TrueVal = intersect(TrueVal, - getValueFromCondition(SI->getTrueValue(), Cond, true)); - FalseVal = intersect( - FalseVal, getValueFromCondition(SI->getFalseValue(), Cond, false)); + TrueVal = + intersect(TrueVal, *getValueFromCondition(SI->getTrueValue(), Cond, + /*IsTrueDest*/ true, + /*UseBlockValue*/ false)); + FalseVal = + intersect(FalseVal, *getValueFromCondition(SI->getFalseValue(), Cond, + /*IsTrueDest*/ false, + /*UseBlockValue*/ false)); } ValueLatticeElement Result = TrueVal; @@ -1068,15 +1094,26 @@ static bool matchICmpOperand(APInt &Offset, Value *LHS, Value *Val, } /// Get value range for a "(Val + Offset) Pred RHS" condition. -static ValueLatticeElement getValueFromSimpleICmpCondition( - CmpInst::Predicate Pred, Value *RHS, const APInt &Offset) { +std::optional +LazyValueInfoImpl::getValueFromSimpleICmpCondition(CmpInst::Predicate Pred, + Value *RHS, + const APInt &Offset, + Instruction *CxtI, + bool UseBlockValue) { ConstantRange RHSRange(RHS->getType()->getIntegerBitWidth(), /*isFullSet=*/true); - if (ConstantInt *CI = dyn_cast(RHS)) + if (ConstantInt *CI = dyn_cast(RHS)) { RHSRange = ConstantRange(CI->getValue()); - else if (Instruction *I = dyn_cast(RHS)) + } else if (UseBlockValue) { + std::optional R = + getBlockValue(RHS, CxtI->getParent(), CxtI); + if (!R) + return std::nullopt; + RHSRange = toConstantRange(*R, RHS->getType()); + } else if (Instruction *I = dyn_cast(RHS)) { if (auto *Ranges = I->getMetadata(LLVMContext::MD_range)) RHSRange = getConstantRangeFromMetadata(*Ranges); + } ConstantRange TrueValues = ConstantRange::makeAllowedICmpRegion(Pred, RHSRange); @@ -1103,8 +1140,8 @@ getRangeViaSLT(CmpInst::Predicate Pred, APInt RHS, return std::nullopt; } -static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI, - bool isTrueDest) { +std::optional LazyValueInfoImpl::getValueFromICmpCondition( + Value *Val, ICmpInst *ICI, bool isTrueDest, bool UseBlockValue) { Value *LHS = ICI->getOperand(0); Value *RHS = ICI->getOperand(1); @@ -1128,11 +1165,13 @@ static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI, unsigned BitWidth = Ty->getScalarSizeInBits(); APInt Offset(BitWidth, 0); if (matchICmpOperand(Offset, LHS, Val, EdgePred)) - return getValueFromSimpleICmpCondition(EdgePred, RHS, Offset); + return getValueFromSimpleICmpCondition(EdgePred, RHS, Offset, ICI, + UseBlockValue); CmpInst::Predicate SwappedPred = CmpInst::getSwappedPredicate(EdgePred); if (matchICmpOperand(Offset, RHS, Val, SwappedPred)) - return getValueFromSimpleICmpCondition(SwappedPred, LHS, Offset); + return getValueFromSimpleICmpCondition(SwappedPred, LHS, Offset, ICI, + UseBlockValue); const APInt *Mask, *C; if (match(LHS, m_And(m_Specific(Val), m_APInt(Mask))) && @@ -1212,10 +1251,12 @@ static ValueLatticeElement getValueFromOverflowCondition( return ValueLatticeElement::getRange(NWR); } -static ValueLatticeElement getValueFromCondition( - Value *Val, Value *Cond, bool IsTrueDest, unsigned Depth) { +std::optional +LazyValueInfoImpl::getValueFromCondition(Value *Val, Value *Cond, + bool IsTrueDest, bool UseBlockValue, + unsigned Depth) { if (ICmpInst *ICI = dyn_cast(Cond)) - return getValueFromICmpCondition(Val, ICI, IsTrueDest); + return getValueFromICmpCondition(Val, ICI, IsTrueDest, UseBlockValue); if (auto *EVI = dyn_cast(Cond)) if (auto *WO = dyn_cast(EVI->getAggregateOperand())) @@ -1227,7 +1268,7 @@ static ValueLatticeElement getValueFromCondition( Value *N; if (match(Cond, m_Not(m_Value(N)))) - return getValueFromCondition(Val, N, !IsTrueDest, Depth); + return getValueFromCondition(Val, N, !IsTrueDest, UseBlockValue, Depth); Value *L, *R; bool IsAnd; @@ -1238,19 +1279,23 @@ static ValueLatticeElement getValueFromCondition( else return ValueLatticeElement::getOverdefined(); - ValueLatticeElement LV = getValueFromCondition(Val, L, IsTrueDest, Depth); - ValueLatticeElement RV = getValueFromCondition(Val, R, IsTrueDest, Depth); + std::optional LV = + getValueFromCondition(Val, L, IsTrueDest, UseBlockValue, Depth); + std::optional RV = + getValueFromCondition(Val, R, IsTrueDest, UseBlockValue, Depth); + if (!LV || !RV) + return std::nullopt; // if (L && R) -> intersect L and R // if (!(L || R)) -> intersect !L and !R // if (L || R) -> union L and R // if (!(L && R)) -> union !L and !R if (IsTrueDest ^ IsAnd) { - LV.mergeIn(RV); - return LV; + LV->mergeIn(*RV); + return *LV; } - return intersect(LV, RV); + return intersect(*LV, *RV); } // Return true if Usr has Op as an operand, otherwise false. @@ -1302,8 +1347,9 @@ static ValueLatticeElement constantFoldUser(User *Usr, Value *Op, } /// Compute the value of Val on the edge BBFrom -> BBTo. -static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, - BasicBlock *BBTo) { +std::optional +LazyValueInfoImpl::getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, + BasicBlock *BBTo, bool UseBlockValue) { // TODO: Handle more complex conditionals. If (v == 0 || v2 < 1) is false, we // know that v != 0. if (BranchInst *BI = dyn_cast(BBFrom->getTerminator())) { @@ -1324,13 +1370,16 @@ static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, // If the condition of the branch is an equality comparison, we may be // able to infer the value. - ValueLatticeElement Result = getValueFromCondition(Val, Condition, - isTrueDest); - if (!Result.isOverdefined()) + std::optional Result = + getValueFromCondition(Val, Condition, isTrueDest, UseBlockValue); + if (!Result) + return std::nullopt; + + if (!Result->isOverdefined()) return Result; if (User *Usr = dyn_cast(Val)) { - assert(Result.isOverdefined() && "Result isn't overdefined"); + assert(Result->isOverdefined() && "Result isn't overdefined"); // Check with isOperationFoldable() first to avoid linearly iterating // over the operands unnecessarily which can be expensive for // instructions with many operands. @@ -1356,8 +1405,8 @@ static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, // br i1 %Condition, label %then, label %else for (unsigned i = 0; i < Usr->getNumOperands(); ++i) { Value *Op = Usr->getOperand(i); - ValueLatticeElement OpLatticeVal = - getValueFromCondition(Op, Condition, isTrueDest); + ValueLatticeElement OpLatticeVal = *getValueFromCondition( + Op, Condition, isTrueDest, /*UseBlockValue*/ false); if (std::optional OpConst = OpLatticeVal.asConstantInteger()) { Result = constantFoldUser(Usr, Op, *OpConst, DL); @@ -1367,7 +1416,7 @@ static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, } } } - if (!Result.isOverdefined()) + if (!Result->isOverdefined()) return Result; } } @@ -1432,8 +1481,12 @@ LazyValueInfoImpl::getEdgeValue(Value *Val, BasicBlock *BBFrom, if (Constant *VC = dyn_cast(Val)) return ValueLatticeElement::get(VC); - ValueLatticeElement LocalResult = getEdgeValueLocal(Val, BBFrom, BBTo); - if (hasSingleValue(LocalResult)) + std::optional LocalResult = + getEdgeValueLocal(Val, BBFrom, BBTo, /*UseBlockValue*/ true); + if (!LocalResult) + return std::nullopt; + + if (hasSingleValue(*LocalResult)) // Can't get any more precise here return LocalResult; @@ -1453,7 +1506,7 @@ LazyValueInfoImpl::getEdgeValue(Value *Val, BasicBlock *BBFrom, // but then the result is not cached. intersectAssumeOrGuardBlockValueConstantRange(Val, InBlock, CxtI); - return intersect(LocalResult, InBlock); + return intersect(*LocalResult, InBlock); } ValueLatticeElement LazyValueInfoImpl::getValueInBlock(Value *V, BasicBlock *BB, @@ -1499,10 +1552,12 @@ getValueOnEdge(Value *V, BasicBlock *FromBB, BasicBlock *ToBB, std::optional Result = getEdgeValue(V, FromBB, ToBB, CxtI); - if (!Result) { + while (!Result) { + // As the worklist only explicitly tracks block values (but not edge values) + // we may have to call solve() multiple times, as the edge value calculation + // may request additional block values. solve(); Result = getEdgeValue(V, FromBB, ToBB, CxtI); - assert(Result && "More work to do after problem solved?"); } LLVM_DEBUG(dbgs() << " Result = " << *Result << "\n"); @@ -1528,13 +1583,17 @@ ValueLatticeElement LazyValueInfoImpl::getValueAtUse(const Use &U) { if (!isGuaranteedNotToBeUndef(SI->getCondition(), AC)) break; if (CurrU->getOperandNo() == 1) - CondVal = getValueFromCondition(V, SI->getCondition(), true); + CondVal = + *getValueFromCondition(V, SI->getCondition(), /*IsTrueDest*/ true, + /*UseBlockValue*/ false); else if (CurrU->getOperandNo() == 2) - CondVal = getValueFromCondition(V, SI->getCondition(), false); + CondVal = + *getValueFromCondition(V, SI->getCondition(), /*IsTrueDest*/ false, + /*UseBlockValue*/ false); } else if (auto *PHI = dyn_cast(CurrI)) { // TODO: Use non-local query? - CondVal = - getEdgeValueLocal(V, PHI->getIncomingBlock(*CurrU), PHI->getParent()); + CondVal = *getEdgeValueLocal(V, PHI->getIncomingBlock(*CurrU), + PHI->getParent(), /*UseBlockValue*/ false); } if (CondVal) VL = intersect(VL, *CondVal); diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll b/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll index d30b31d317a6d..252f6596cedc5 100644 --- a/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll +++ b/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll @@ -12,8 +12,7 @@ define void @test_icmp_from_implied_cond(i32 %a, i32 %b) { ; CHECK-NEXT: [[COND:%.*]] = icmp ult i32 [[B]], [[A]] ; CHECK-NEXT: br i1 [[COND]], label [[L2:%.*]], label [[END]] ; CHECK: l2: -; CHECK-NEXT: [[B_CMP1:%.*]] = icmp ult i32 [[B]], 32 -; CHECK-NEXT: call void @use(i1 [[B_CMP1]]) +; CHECK-NEXT: call void @use(i1 true) ; CHECK-NEXT: [[B_CMP2:%.*]] = icmp ult i32 [[B]], 31 ; CHECK-NEXT: call void @use(i1 [[B_CMP2]]) ; CHECK-NEXT: ret void @@ -47,7 +46,7 @@ define i64 @test_sext_from_implied_cond(i32 %a, i32 %b) { ; CHECK-NEXT: [[COND:%.*]] = icmp ult i32 [[B]], [[A]] ; CHECK-NEXT: br i1 [[COND]], label [[L2:%.*]], label [[END]] ; CHECK: l2: -; CHECK-NEXT: [[SEXT:%.*]] = sext i32 [[B]] to i64 +; CHECK-NEXT: [[SEXT:%.*]] = zext nneg i32 [[B]] to i64 ; CHECK-NEXT: ret i64 [[SEXT]] ; CHECK: end: ; CHECK-NEXT: ret i64 0 @@ -74,8 +73,7 @@ define void @test_icmp_from_implied_range(i16 %x, i32 %b) { ; CHECK-NEXT: [[COND:%.*]] = icmp ult i32 [[B]], [[A]] ; CHECK-NEXT: br i1 [[COND]], label [[L1:%.*]], label [[END:%.*]] ; CHECK: l1: -; CHECK-NEXT: [[B_CMP1:%.*]] = icmp ult i32 [[B]], 65535 -; CHECK-NEXT: call void @use(i1 [[B_CMP1]]) +; CHECK-NEXT: call void @use(i1 true) ; CHECK-NEXT: [[B_CMP2:%.*]] = icmp ult i32 [[B]], 65534 ; CHECK-NEXT: call void @use(i1 [[B_CMP2]]) ; CHECK-NEXT: ret void