Skip to content

[ADT] Add implementations for avgFloor and avgCeil to APInt #84431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions llvm/include/llvm/ADT/APInt.h
Original file line number Diff line number Diff line change
Expand Up @@ -2193,6 +2193,18 @@ inline const APInt absdiff(const APInt &A, const APInt &B) {
return A.uge(B) ? (A - B) : (B - A);
}

/// Compute the floor of the signed average of C1 and C2
APInt avgFloorS(const APInt &C1, const APInt &C2);

/// Compute the floor of the unsigned average of C1 and C2
APInt avgFloorU(const APInt &C1, const APInt &C2);

/// Compute the ceil of the signed average of C1 and C2
APInt avgCeilS(const APInt &C1, const APInt &C2);

/// Compute the ceil of the unsigned average of C1 and C2
APInt avgCeilU(const APInt &C1, const APInt &C2);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does anyone have any preference for camelcase vs lowercase?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope :)


/// Compute GCD of two unsigned APInt values.
///
/// This function returns the greatest common divisor of the two APInt values
Expand Down
32 changes: 8 additions & 24 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6021,30 +6021,14 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
APInt C2Ext = C2.zext(FullWidth);
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
}
case ISD::AVGFLOORS: {
unsigned FullWidth = C1.getBitWidth() + 1;
APInt C1Ext = C1.sext(FullWidth);
APInt C2Ext = C2.sext(FullWidth);
return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
}
case ISD::AVGFLOORU: {
unsigned FullWidth = C1.getBitWidth() + 1;
APInt C1Ext = C1.zext(FullWidth);
APInt C2Ext = C2.zext(FullWidth);
return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
}
case ISD::AVGCEILS: {
unsigned FullWidth = C1.getBitWidth() + 1;
APInt C1Ext = C1.sext(FullWidth);
APInt C2Ext = C2.sext(FullWidth);
return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
}
case ISD::AVGCEILU: {
unsigned FullWidth = C1.getBitWidth() + 1;
APInt C1Ext = C1.zext(FullWidth);
APInt C2Ext = C2.zext(FullWidth);
return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
}
case ISD::AVGFLOORS:
return APIntOps::avgFloorS(C1, C2);
case ISD::AVGFLOORU:
return APIntOps::avgFloorU(C1, C2);
case ISD::AVGCEILS:
return APIntOps::avgCeilS(C1, C2);
case ISD::AVGCEILU:
return APIntOps::avgCeilU(C1, C2);
case ISD::ABDS:
return APIntOps::smax(C1, C2) - APIntOps::smin(C1, C2);
case ISD::ABDU:
Expand Down
36 changes: 36 additions & 0 deletions llvm/lib/Support/APInt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3094,3 +3094,39 @@ void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src,
memcpy(Dst + sizeof(uint64_t) - LoadBytes, Src, LoadBytes);
}
}

APInt APIntOps::avgFloorS(const APInt &C1, const APInt &C2) {
// Return floor((C1 + C2)/2)
assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
unsigned FullWidth = C1.getBitWidth() + 1;
APInt C1Ext = C1.sext(FullWidth);
APInt C2Ext = C2.sext(FullWidth);
return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
}

APInt APIntOps::avgFloorU(const APInt &C1, const APInt &C2) {
// Return floor((C1 + C2)/2)
assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
unsigned FullWidth = C1.getBitWidth() + 1;
APInt C1Ext = C1.zext(FullWidth);
APInt C2Ext = C2.zext(FullWidth);
return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
}

APInt APIntOps::avgCeilS(const APInt &C1, const APInt &C2) {
// Return ceil((C1 + C2)/2)
assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
unsigned FullWidth = C1.getBitWidth() + 1;
APInt C1Ext = C1.sext(FullWidth);
APInt C2Ext = C2.sext(FullWidth);
return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
}

APInt APIntOps::avgCeilU(const APInt &C1, const APInt &C2) {
// Return ceil((C1 + C2)/2)
assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
unsigned FullWidth = C1.getBitWidth() + 1;
APInt C1Ext = C1.zext(FullWidth);
APInt C2Ext = C2.zext(FullWidth);
return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
}
86 changes: 86 additions & 0 deletions llvm/unittests/ADT/APIntTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "llvm/Support/Alignment.h"
#include "gtest/gtest.h"
#include <array>
#include <climits>
#include <optional>

using namespace llvm;
Expand Down Expand Up @@ -2877,6 +2878,91 @@ TEST(APIntTest, RoundingSDiv) {
}
}

TEST(APIntTest, Average) {
APInt A0(32, 0);
APInt A2(32, 2);
APInt A100(32, 100);
APInt A101(32, 101);
APInt A200(32, 200, false);
APInt ApUMax(32, UINT_MAX, false);

EXPECT_EQ(APInt(32, 150), APIntOps::avgFloorU(A100, A200));
EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A200, A2, APInt::Rounding::DOWN),
APIntOps::avgFloorU(A100, A200));
EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A200, A2, APInt::Rounding::UP),
APIntOps::avgCeilU(A100, A200));
EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A101, A2, APInt::Rounding::DOWN),
APIntOps::avgFloorU(A100, A101));
EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A101, A2, APInt::Rounding::UP),
APIntOps::avgCeilU(A100, A101));
EXPECT_EQ(A0, APIntOps::avgFloorU(A0, A0));
EXPECT_EQ(A0, APIntOps::avgCeilU(A0, A0));
EXPECT_EQ(ApUMax, APIntOps::avgFloorU(ApUMax, ApUMax));
EXPECT_EQ(ApUMax, APIntOps::avgCeilU(ApUMax, ApUMax));
EXPECT_EQ(APIntOps::RoundingUDiv(ApUMax, A2, APInt::Rounding::DOWN),
APIntOps::avgFloorU(A0, ApUMax));
EXPECT_EQ(APIntOps::RoundingUDiv(ApUMax, A2, APInt::Rounding::UP),
APIntOps::avgCeilU(A0, ApUMax));

APInt Ap100(32, +100);
APInt Ap101(32, +101);
APInt Ap200(32, +200);
APInt Am1(32, -1);
APInt Am100(32, -100);
APInt Am101(32, -101);
APInt Am200(32, -200);
APInt AmSMin(32, INT_MIN);
APInt ApSMax(32, INT_MAX);

EXPECT_EQ(APInt(32, +150), APIntOps::avgFloorS(Ap100, Ap200));
EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap200, A2, APInt::Rounding::DOWN),
APIntOps::avgFloorS(Ap100, Ap200));
EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap200, A2, APInt::Rounding::UP),
APIntOps::avgCeilS(Ap100, Ap200));

EXPECT_EQ(APInt(32, -150), APIntOps::avgFloorS(Am100, Am200));
EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am200, A2, APInt::Rounding::DOWN),
APIntOps::avgFloorS(Am100, Am200));
EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am200, A2, APInt::Rounding::UP),
APIntOps::avgCeilS(Am100, Am200));

EXPECT_EQ(APInt(32, +100), APIntOps::avgFloorS(Ap100, Ap101));
EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap101, A2, APInt::Rounding::DOWN),
APIntOps::avgFloorS(Ap100, Ap101));
EXPECT_EQ(APInt(32, +101), APIntOps::avgCeilS(Ap100, Ap101));
EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap101, A2, APInt::Rounding::UP),
APIntOps::avgCeilS(Ap100, Ap101));

EXPECT_EQ(APInt(32, -101), APIntOps::avgFloorS(Am100, Am101));
EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am101, A2, APInt::Rounding::DOWN),
APIntOps::avgFloorS(Am100, Am101));
EXPECT_EQ(APInt(32, -100), APIntOps::avgCeilS(Am100, Am101));
EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am101, A2, APInt::Rounding::UP),
APIntOps::avgCeilS(Am100, Am101));

EXPECT_EQ(AmSMin, APIntOps::avgFloorS(AmSMin, AmSMin));
EXPECT_EQ(AmSMin, APIntOps::avgCeilS(AmSMin, AmSMin));

EXPECT_EQ(APIntOps::RoundingSDiv(AmSMin, A2, APInt::Rounding::DOWN),
APIntOps::avgFloorS(A0, AmSMin));
EXPECT_EQ(APIntOps::RoundingSDiv(AmSMin, A2, APInt::Rounding::UP),
APIntOps::avgCeilS(A0, AmSMin));

EXPECT_EQ(A0, APIntOps::avgFloorS(A0, A0));
EXPECT_EQ(A0, APIntOps::avgCeilS(A0, A0));

EXPECT_EQ(Am1, APIntOps::avgFloorS(AmSMin, ApSMax));
EXPECT_EQ(A0, APIntOps::avgCeilS(AmSMin, ApSMax));

EXPECT_EQ(APIntOps::RoundingSDiv(ApSMax, A2, APInt::Rounding::DOWN),
APIntOps::avgFloorS(A0, ApSMax));
EXPECT_EQ(APIntOps::RoundingSDiv(ApSMax, A2, APInt::Rounding::UP),
APIntOps::avgCeilS(A0, ApSMax));

EXPECT_EQ(ApSMax, APIntOps::avgFloorS(ApSMax, ApSMax));
EXPECT_EQ(ApSMax, APIntOps::avgCeilS(ApSMax, ApSMax));
}

TEST(APIntTest, umul_ov) {
const std::pair<uint64_t, uint64_t> Overflows[] = {
{0x8000000000000000, 2},
Expand Down