Skip to content

[InstCombine][InstSimplify] Pass SimplifyQuery to computeKnownBits directly. NFC. #74246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 19 additions & 21 deletions llvm/lib/Analysis/InstructionSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ static Value *simplifySubInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
if (IsNUW)
return Constant::getNullValue(Op0->getType());

KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
KnownBits Known = computeKnownBits(Op1, /* Depth */ 0, Q);
if (Known.Zero.isMaxSignedValue()) {
// Op1 is either 0 or the minimum signed value. If the sub is NSW, then
// Op1 must be 0 because negating the minimum signed value is undefined.
Expand Down Expand Up @@ -1063,7 +1063,7 @@ static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q,
// ("computeConstantRangeIncludingKnownBits")?
const APInt *C;
if (match(Y, m_APInt(C)) &&
computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, Q.DT).getMaxValue().ult(*C))
computeKnownBits(X, /* Depth */ 0, Q).getMaxValue().ult(*C))
return true;

// Try again for any divisor:
Expand Down Expand Up @@ -1125,8 +1125,7 @@ static Value *simplifyDivRem(Instruction::BinaryOps Opcode, Value *Op0,
if (Op0 == Op1)
return IsDiv ? ConstantInt::get(Ty, 1) : Constant::getNullValue(Ty);


KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
KnownBits Known = computeKnownBits(Op1, /* Depth */ 0, Q);
// X / 0 -> poison
// X % 0 -> poison
// If the divisor is known to be zero, just return poison. This can happen in
Expand Down Expand Up @@ -1195,7 +1194,7 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
// less trailing zeros, then the result must be poison.
const APInt *DivC;
if (IsExact && match(Op1, m_APInt(DivC)) && DivC->countr_zero()) {
KnownBits KnownOp0 = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
KnownBits KnownOp0 = computeKnownBits(Op0, /* Depth */ 0, Q);
if (KnownOp0.countMaxTrailingZeros() < DivC->countr_zero())
return PoisonValue::get(Op0->getType());
}
Expand Down Expand Up @@ -1355,7 +1354,7 @@ static Value *simplifyShift(Instruction::BinaryOps Opcode, Value *Op0,

// If any bits in the shift amount make that value greater than or equal to
// the number of bits in the type, the shift is undefined.
KnownBits KnownAmt = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
KnownBits KnownAmt = computeKnownBits(Op1, /* Depth */ 0, Q);
if (KnownAmt.getMinValue().uge(KnownAmt.getBitWidth()))
return PoisonValue::get(Op0->getType());

Expand All @@ -1368,7 +1367,7 @@ static Value *simplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
// Check for nsw shl leading to a poison value.
if (IsNSW) {
assert(Opcode == Instruction::Shl && "Expected shl for nsw instruction");
KnownBits KnownVal = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
KnownBits KnownVal = computeKnownBits(Op0, /* Depth */ 0, Q);
KnownBits KnownShl = KnownBits::shl(KnownVal, KnownAmt);

if (KnownVal.Zero.isSignBitSet())
Expand Down Expand Up @@ -1404,8 +1403,7 @@ static Value *simplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
// The low bit cannot be shifted out of an exact shift if it is set.
// TODO: Generalize by counting trailing zeros (see fold for exact division).
if (IsExact) {
KnownBits Op0Known =
computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
KnownBits Op0Known = computeKnownBits(Op0, /* Depth */ 0, Q);
if (Op0Known.One[0])
return Op0;
}
Expand Down Expand Up @@ -1477,7 +1475,7 @@ static Value *simplifyLShrInst(Value *Op0, Value *Op1, bool IsExact,
if (Q.IIQ.UseInstrInfo && match(Op1, m_APInt(ShRAmt)) &&
match(Op0, m_c_Or(m_NUWShl(m_Value(X), m_APInt(ShLAmt)), m_Value(Y))) &&
*ShRAmt == *ShLAmt) {
const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
const KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
const unsigned EffWidthY = YKnown.countMaxActiveBits();
if (ShRAmt->uge(EffWidthY))
return X;
Expand Down Expand Up @@ -2105,7 +2103,7 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
match(Op0, m_Add(m_Value(Shift), m_AllOnes())) &&
isKnownToBeAPowerOfTwo(Shift, Q.DL, /*OrZero*/ false, 0, Q.AC, Q.CxtI,
Q.DT)) {
KnownBits Known = computeKnownBits(Shift, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
KnownBits Known = computeKnownBits(Shift, /* Depth */ 0, Q);
// Use getActiveBits() to make use of the additional power of two knowledge
if (PowerC->getActiveBits() >= Known.getMaxValue().getActiveBits())
return ConstantInt::getNullValue(Op1->getType());
Expand Down Expand Up @@ -2169,10 +2167,10 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
m_Value(Y)))) {
const unsigned Width = Op0->getType()->getScalarSizeInBits();
const unsigned ShftCnt = ShAmt->getLimitedValue(Width);
const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
const KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
const unsigned EffWidthY = YKnown.countMaxActiveBits();
if (EffWidthY <= ShftCnt) {
const KnownBits XKnown = computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
const KnownBits XKnown = computeKnownBits(X, /* Depth */ 0, Q);
const unsigned EffWidthX = XKnown.countMaxActiveBits();
const APInt EffBitsY = APInt::getLowBitsSet(Width, EffWidthY);
const APInt EffBitsX = APInt::getLowBitsSet(Width, EffWidthX) << ShftCnt;
Expand Down Expand Up @@ -2968,15 +2966,15 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
return getTrue(ITy);
break;
case ICmpInst::ICMP_SLT: {
KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
if (LHSKnown.isNegative())
return getTrue(ITy);
if (LHSKnown.isNonNegative())
return getFalse(ITy);
break;
}
case ICmpInst::ICMP_SLE: {
KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
if (LHSKnown.isNegative())
return getTrue(ITy);
if (LHSKnown.isNonNegative() &&
Expand All @@ -2985,15 +2983,15 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
break;
}
case ICmpInst::ICMP_SGE: {
KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
if (LHSKnown.isNegative())
return getFalse(ITy);
if (LHSKnown.isNonNegative())
return getTrue(ITy);
break;
}
case ICmpInst::ICMP_SGT: {
KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
if (LHSKnown.isNegative())
return getFalse(ITy);
if (LHSKnown.isNonNegative() &&
Expand Down Expand Up @@ -3070,8 +3068,8 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
return getTrue(ITy);

if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) {
KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
KnownBits RHSKnown = computeKnownBits(RHS, /* Depth */ 0, Q);
KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
if (RHSKnown.isNonNegative() && YKnown.isNegative())
return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy);
if (RHSKnown.isNegative() || YKnown.isNonNegative())
Expand All @@ -3094,7 +3092,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
break;
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE: {
KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
KnownBits Known = computeKnownBits(RHS, /* Depth */ 0, Q);
if (!Known.isNonNegative())
break;
[[fallthrough]];
Expand All @@ -3105,7 +3103,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
return getFalse(ITy);
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE: {
KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
KnownBits Known = computeKnownBits(RHS, /* Depth */ 0, Q);
if (!Known.isNonNegative())
break;
[[fallthrough]];
Expand Down
6 changes: 2 additions & 4 deletions llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -962,15 +962,13 @@ static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
}

// Compute what we know about shift count.
KnownBits KnownCnt =
computeKnownBits(I.getOperand(1), Q.DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
KnownBits KnownCnt = computeKnownBits(I.getOperand(1), /* Depth */ 0, Q);
unsigned BitWidth = KnownCnt.getBitWidth();
// Since shift produces a poison value if RHS is equal to or larger than the
// bit width, we can safely assume that RHS is less than the bit width.
uint64_t MaxCnt = KnownCnt.getMaxValue().getLimitedValue(BitWidth - 1);

KnownBits KnownAmt =
computeKnownBits(I.getOperand(0), Q.DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
KnownBits KnownAmt = computeKnownBits(I.getOperand(0), /* Depth */ 0, Q);
bool Changed = false;

if (I.getOpcode() == Instruction::Shl) {
Expand Down