diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h index 1fc3c7b2236a1..3a58fff3be1d4 100644 --- a/llvm/include/llvm/ADT/APInt.h +++ b/llvm/include/llvm/ADT/APInt.h @@ -2193,6 +2193,18 @@ inline const APInt absdiff(const APInt &A, const APInt &B) { return A.uge(B) ? (A - B) : (B - A); } +/// Compute the floor of the signed average of C1 and C2 +APInt avgFloorS(const APInt &C1, const APInt &C2); + +/// Compute the floor of the unsigned average of C1 and C2 +APInt avgFloorU(const APInt &C1, const APInt &C2); + +/// Compute the ceil of the signed average of C1 and C2 +APInt avgCeilS(const APInt &C1, const APInt &C2); + +/// Compute the ceil of the unsigned average of C1 and C2 +APInt avgCeilU(const APInt &C1, const APInt &C2); + /// Compute GCD of two unsigned APInt values. /// /// This function returns the greatest common divisor of the two APInt values diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index f7ace79e8c51d..a844a00be1629 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -6021,30 +6021,14 @@ static std::optional FoldValue(unsigned Opcode, const APInt &C1, APInt C2Ext = C2.zext(FullWidth); return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth()); } - case ISD::AVGFLOORS: { - unsigned FullWidth = C1.getBitWidth() + 1; - APInt C1Ext = C1.sext(FullWidth); - APInt C2Ext = C2.sext(FullWidth); - return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1); - } - case ISD::AVGFLOORU: { - unsigned FullWidth = C1.getBitWidth() + 1; - APInt C1Ext = C1.zext(FullWidth); - APInt C2Ext = C2.zext(FullWidth); - return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1); - } - case ISD::AVGCEILS: { - unsigned FullWidth = C1.getBitWidth() + 1; - APInt C1Ext = C1.sext(FullWidth); - APInt C2Ext = C2.sext(FullWidth); - return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1); - } - case ISD::AVGCEILU: { - unsigned FullWidth = C1.getBitWidth() + 1; - APInt C1Ext = C1.zext(FullWidth); - APInt C2Ext = C2.zext(FullWidth); - return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1); - } + case ISD::AVGFLOORS: + return APIntOps::avgFloorS(C1, C2); + case ISD::AVGFLOORU: + return APIntOps::avgFloorU(C1, C2); + case ISD::AVGCEILS: + return APIntOps::avgCeilS(C1, C2); + case ISD::AVGCEILU: + return APIntOps::avgCeilU(C1, C2); case ISD::ABDS: return APIntOps::smax(C1, C2) - APIntOps::smin(C1, C2); case ISD::ABDU: diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp index e686b97652330..7053f3b87682f 100644 --- a/llvm/lib/Support/APInt.cpp +++ b/llvm/lib/Support/APInt.cpp @@ -3094,3 +3094,39 @@ void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src, memcpy(Dst + sizeof(uint64_t) - LoadBytes, Src, LoadBytes); } } + +APInt APIntOps::avgFloorS(const APInt &C1, const APInt &C2) { + // Return floor((C1 + C2)/2) + assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths"); + unsigned FullWidth = C1.getBitWidth() + 1; + APInt C1Ext = C1.sext(FullWidth); + APInt C2Ext = C2.sext(FullWidth); + return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1); +} + +APInt APIntOps::avgFloorU(const APInt &C1, const APInt &C2) { + // Return floor((C1 + C2)/2) + assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths"); + unsigned FullWidth = C1.getBitWidth() + 1; + APInt C1Ext = C1.zext(FullWidth); + APInt C2Ext = C2.zext(FullWidth); + return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1); +} + +APInt APIntOps::avgCeilS(const APInt &C1, const APInt &C2) { + // Return ceil((C1 + C2)/2) + assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths"); + unsigned FullWidth = C1.getBitWidth() + 1; + APInt C1Ext = C1.sext(FullWidth); + APInt C2Ext = C2.sext(FullWidth); + return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1); +} + +APInt APIntOps::avgCeilU(const APInt &C1, const APInt &C2) { + // Return ceil((C1 + C2)/2) + assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths"); + unsigned FullWidth = C1.getBitWidth() + 1; + APInt C1Ext = C1.zext(FullWidth); + APInt C2Ext = C2.zext(FullWidth); + return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1); +} diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp index 24324822356bf..3e49b5876e0b8 100644 --- a/llvm/unittests/ADT/APIntTest.cpp +++ b/llvm/unittests/ADT/APIntTest.cpp @@ -14,6 +14,7 @@ #include "llvm/Support/Alignment.h" #include "gtest/gtest.h" #include +#include #include using namespace llvm; @@ -2877,6 +2878,91 @@ TEST(APIntTest, RoundingSDiv) { } } +TEST(APIntTest, Average) { + APInt A0(32, 0); + APInt A2(32, 2); + APInt A100(32, 100); + APInt A101(32, 101); + APInt A200(32, 200, false); + APInt ApUMax(32, UINT_MAX, false); + + EXPECT_EQ(APInt(32, 150), APIntOps::avgFloorU(A100, A200)); + EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A200, A2, APInt::Rounding::DOWN), + APIntOps::avgFloorU(A100, A200)); + EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A200, A2, APInt::Rounding::UP), + APIntOps::avgCeilU(A100, A200)); + EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A101, A2, APInt::Rounding::DOWN), + APIntOps::avgFloorU(A100, A101)); + EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A101, A2, APInt::Rounding::UP), + APIntOps::avgCeilU(A100, A101)); + EXPECT_EQ(A0, APIntOps::avgFloorU(A0, A0)); + EXPECT_EQ(A0, APIntOps::avgCeilU(A0, A0)); + EXPECT_EQ(ApUMax, APIntOps::avgFloorU(ApUMax, ApUMax)); + EXPECT_EQ(ApUMax, APIntOps::avgCeilU(ApUMax, ApUMax)); + EXPECT_EQ(APIntOps::RoundingUDiv(ApUMax, A2, APInt::Rounding::DOWN), + APIntOps::avgFloorU(A0, ApUMax)); + EXPECT_EQ(APIntOps::RoundingUDiv(ApUMax, A2, APInt::Rounding::UP), + APIntOps::avgCeilU(A0, ApUMax)); + + APInt Ap100(32, +100); + APInt Ap101(32, +101); + APInt Ap200(32, +200); + APInt Am1(32, -1); + APInt Am100(32, -100); + APInt Am101(32, -101); + APInt Am200(32, -200); + APInt AmSMin(32, INT_MIN); + APInt ApSMax(32, INT_MAX); + + EXPECT_EQ(APInt(32, +150), APIntOps::avgFloorS(Ap100, Ap200)); + EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap200, A2, APInt::Rounding::DOWN), + APIntOps::avgFloorS(Ap100, Ap200)); + EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap200, A2, APInt::Rounding::UP), + APIntOps::avgCeilS(Ap100, Ap200)); + + EXPECT_EQ(APInt(32, -150), APIntOps::avgFloorS(Am100, Am200)); + EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am200, A2, APInt::Rounding::DOWN), + APIntOps::avgFloorS(Am100, Am200)); + EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am200, A2, APInt::Rounding::UP), + APIntOps::avgCeilS(Am100, Am200)); + + EXPECT_EQ(APInt(32, +100), APIntOps::avgFloorS(Ap100, Ap101)); + EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap101, A2, APInt::Rounding::DOWN), + APIntOps::avgFloorS(Ap100, Ap101)); + EXPECT_EQ(APInt(32, +101), APIntOps::avgCeilS(Ap100, Ap101)); + EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap101, A2, APInt::Rounding::UP), + APIntOps::avgCeilS(Ap100, Ap101)); + + EXPECT_EQ(APInt(32, -101), APIntOps::avgFloorS(Am100, Am101)); + EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am101, A2, APInt::Rounding::DOWN), + APIntOps::avgFloorS(Am100, Am101)); + EXPECT_EQ(APInt(32, -100), APIntOps::avgCeilS(Am100, Am101)); + EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am101, A2, APInt::Rounding::UP), + APIntOps::avgCeilS(Am100, Am101)); + + EXPECT_EQ(AmSMin, APIntOps::avgFloorS(AmSMin, AmSMin)); + EXPECT_EQ(AmSMin, APIntOps::avgCeilS(AmSMin, AmSMin)); + + EXPECT_EQ(APIntOps::RoundingSDiv(AmSMin, A2, APInt::Rounding::DOWN), + APIntOps::avgFloorS(A0, AmSMin)); + EXPECT_EQ(APIntOps::RoundingSDiv(AmSMin, A2, APInt::Rounding::UP), + APIntOps::avgCeilS(A0, AmSMin)); + + EXPECT_EQ(A0, APIntOps::avgFloorS(A0, A0)); + EXPECT_EQ(A0, APIntOps::avgCeilS(A0, A0)); + + EXPECT_EQ(Am1, APIntOps::avgFloorS(AmSMin, ApSMax)); + EXPECT_EQ(A0, APIntOps::avgCeilS(AmSMin, ApSMax)); + + EXPECT_EQ(APIntOps::RoundingSDiv(ApSMax, A2, APInt::Rounding::DOWN), + APIntOps::avgFloorS(A0, ApSMax)); + EXPECT_EQ(APIntOps::RoundingSDiv(ApSMax, A2, APInt::Rounding::UP), + APIntOps::avgCeilS(A0, ApSMax)); + + EXPECT_EQ(ApSMax, APIntOps::avgFloorS(ApSMax, ApSMax)); + EXPECT_EQ(ApSMax, APIntOps::avgCeilS(ApSMax, ApSMax)); +} + TEST(APIntTest, umul_ov) { const std::pair Overflows[] = { {0x8000000000000000, 2},