diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h index 9b7f405b62564..ba4a5f01036ca 100644 --- a/llvm/include/llvm/Support/KnownBits.h +++ b/llvm/include/llvm/Support/KnownBits.h @@ -354,6 +354,18 @@ struct KnownBits { /// Compute knownbits resulting from llvm.usub.sat(LHS, RHS) static KnownBits usub_sat(const KnownBits &LHS, const KnownBits &RHS); + /// Compute knownbits resulting from APIntOps::avgFloorS + static KnownBits avgFloorS(const KnownBits &LHS, const KnownBits &RHS); + + /// Compute knownbits resulting from APIntOps::avgFloorU + static KnownBits avgFloorU(const KnownBits &LHS, const KnownBits &RHS); + + /// Compute knownbits resulting from APIntOps::avgCeilS + static KnownBits avgCeilS(const KnownBits &LHS, const KnownBits &RHS); + + /// Compute knownbits resulting from APIntOps::avgCeilU + static KnownBits avgCeilU(const KnownBits &LHS, const KnownBits &RHS); + /// Compute known bits resulting from multiplying LHS and RHS. static KnownBits mul(const KnownBits &LHS, const KnownBits &RHS, bool NoUndefSelfMultiply = false); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 0a258350c68a5..55aea08e16914 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -3468,19 +3468,28 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts, Known = KnownBits::mulhs(Known, Known2); break; } - case ISD::AVGFLOORU: - case ISD::AVGCEILU: - case ISD::AVGFLOORS: + case ISD::AVGFLOORU: { + Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1); + Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1); + Known = KnownBits::avgFloorU(Known, Known2); + break; + } + case ISD::AVGCEILU: { + Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1); + Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1); + Known = KnownBits::avgCeilU(Known, Known2); + break; + } + case ISD::AVGFLOORS: { + Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1); + Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1); + Known = KnownBits::avgFloorS(Known, Known2); + break; + } case ISD::AVGCEILS: { - bool IsCeil = Opcode == ISD::AVGCEILU || Opcode == ISD::AVGCEILS; - bool IsSigned = Opcode == ISD::AVGFLOORS || Opcode == ISD::AVGCEILS; Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1); Known2 = computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1); - Known = IsSigned ? Known.sext(BitWidth + 1) : Known.zext(BitWidth + 1); - Known2 = IsSigned ? Known2.sext(BitWidth + 1) : Known2.zext(BitWidth + 1); - KnownBits Carry = KnownBits::makeConstant(APInt(1, IsCeil ? 1 : 0)); - Known = KnownBits::computeForAddCarry(Known, Known2, Carry); - Known = Known.extractBits(BitWidth, 1); + Known = KnownBits::avgCeilS(Known, Known2); break; } case ISD::SELECT: diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp index fe47884f3e55a..d6012a8eea8a6 100644 --- a/llvm/lib/Support/KnownBits.cpp +++ b/llvm/lib/Support/KnownBits.cpp @@ -774,6 +774,37 @@ KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) { return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS); } +static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil, + bool IsSigned) { + unsigned BitWidth = LHS.getBitWidth(); + LHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1); + RHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1); + KnownBits Carry = KnownBits::makeConstant(APInt(1, IsCeil ? 1 : 0)); + LHS = KnownBits::computeForAddCarry(LHS, RHS, Carry); + LHS = LHS.extractBits(BitWidth, 1); + return LHS; +} + +KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) { + return avgCompute(LHS, RHS, /* IsCeil */ false, + /* IsSigned */ true); +} + +KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) { + return avgCompute(LHS, RHS, /* IsCeil */ false, + /* IsSigned */ false); +} + +KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) { + return avgCompute(LHS, RHS, /* IsCeil */ true, + /* IsSigned */ true); +} + +KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) { + return avgCompute(LHS, RHS, /* IsCeil */ true, + /* IsSigned */ false); +} + KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS, bool NoUndefSelfMultiply) { unsigned BitWidth = LHS.getBitWidth(); diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp index d740707027166..824cf7501fd44 100644 --- a/llvm/unittests/Support/KnownBitsTest.cpp +++ b/llvm/unittests/Support/KnownBitsTest.cpp @@ -501,6 +501,18 @@ TEST(KnownBitsTest, BinaryExhaustive) { "mulhu", KnownBits::mulhu, [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); }, /*CheckOptimality=*/false); + + testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS, APIntOps::avgFloorS, + false); + + testBinaryOpExhaustive("avgFloorU", KnownBits::avgFloorU, APIntOps::avgFloorU, + false); + + testBinaryOpExhaustive("avgCeilU", KnownBits::avgCeilU, APIntOps::avgCeilU, + false); + + testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS, + false); } TEST(KnownBitsTest, UnaryExhaustive) {