Skip to content

Commit ea764b7

Browse files
committed
[ADT] Add signed and unsigned mulHi and mulLo to APInt
This addresses issue #84207
1 parent d9c8550 commit ea764b7

File tree

6 files changed

+84
-35
lines changed

6 files changed

+84
-35
lines changed

llvm/include/llvm/ADT/APInt.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2193,6 +2193,12 @@ inline const APInt absdiff(const APInt &A, const APInt &B) {
21932193
return A.uge(B) ? (A - B) : (B - A);
21942194
}
21952195

2196+
/// Return the high bits of the signed multiplication of C1 and C2
2197+
APInt mulhs(const APInt &C1, const APInt &C2);
2198+
2199+
/// Return the high bits of the unsigned multiplication of C1 and C2
2200+
APInt mulhu(const APInt &C1, const APInt &C2);
2201+
21962202
/// Compute GCD of two unsigned APInt values.
21972203
///
21982204
/// This function returns the greatest common divisor of the two APInt values

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6009,18 +6009,10 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
60096009
if (!C2.getBoolValue())
60106010
break;
60116011
return C1.srem(C2);
6012-
case ISD::MULHS: {
6013-
unsigned FullWidth = C1.getBitWidth() * 2;
6014-
APInt C1Ext = C1.sext(FullWidth);
6015-
APInt C2Ext = C2.sext(FullWidth);
6016-
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
6017-
}
6018-
case ISD::MULHU: {
6019-
unsigned FullWidth = C1.getBitWidth() * 2;
6020-
APInt C1Ext = C1.zext(FullWidth);
6021-
APInt C2Ext = C2.zext(FullWidth);
6022-
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
6023-
}
6012+
case ISD::MULHS:
6013+
return APIntOps::mulhs(C1, C2);
6014+
case ISD::MULHU:
6015+
return APIntOps::mulhu(C1, C2);
60246016
case ISD::AVGFLOORS: {
60256017
unsigned FullWidth = C1.getBitWidth() + 1;
60266018
APInt C1Ext = C1.sext(FullWidth);

llvm/lib/Support/APInt.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3094,3 +3094,19 @@ void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src,
30943094
memcpy(Dst + sizeof(uint64_t) - LoadBytes, Src, LoadBytes);
30953095
}
30963096
}
3097+
3098+
APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) {
3099+
assert(C1.getBitWidth() >= C2.getBitWidth());
3100+
unsigned FullWidth = C1.getBitWidth() * 2;
3101+
APInt C1Ext = C1.sext(FullWidth);
3102+
APInt C2Ext = C2.sext(FullWidth);
3103+
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
3104+
}
3105+
3106+
APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
3107+
assert(C1.getBitWidth() >= C2.getBitWidth());
3108+
unsigned FullWidth = C1.getBitWidth() * 2;
3109+
APInt C1Ext = C1.zext(FullWidth);
3110+
APInt C2Ext = C2.zext(FullWidth);
3111+
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
3112+
}

llvm/unittests/ADT/APIntTest.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "llvm/ADT/ArrayRef.h"
1111
#include "llvm/ADT/DenseMap.h"
1212
#include "llvm/ADT/SmallString.h"
13+
#include "llvm/ADT/StringExtras.h"
1314
#include "llvm/ADT/Twine.h"
1415
#include "llvm/Support/Alignment.h"
1516
#include "gtest/gtest.h"
@@ -2805,6 +2806,58 @@ TEST(APIntTest, multiply) {
28052806
EXPECT_EQ(64U, i96.countr_zero());
28062807
}
28072808

2809+
TEST(APIntOpsTest, Mulh) {
2810+
2811+
// Unsigned
2812+
2813+
// 32 bits
2814+
APInt i32a(32, 0x0001'E235);
2815+
APInt i32b(32, 0xF623'55AD);
2816+
EXPECT_EQ(0x0001'CFA1, APIntOps::mulhu(i32a, i32b));
2817+
2818+
// 64 bits
2819+
APInt i64a(64, 0x1234'5678'90AB'CDEF);
2820+
APInt i64b(64, 0xFEDC'BA09'8765'4321);
2821+
EXPECT_EQ(0x121F'A000'A372'3A57, APIntOps::mulhu(i64a, i64b));
2822+
2823+
// 128 bits
2824+
APInt i128a(128, "1234567890ABCDEF1234567890ABCDEF", 16);
2825+
APInt i128b(128, "FEDCBA0987654321FEDCBA0987654321", 16);
2826+
APInt i128Res = APIntOps::mulhu(i128a, i128b);
2827+
EXPECT_EQ(APInt(128, "121FA000A3723A57E68984312C3A8D7E", 16), i128Res);
2828+
2829+
// Signed
2830+
2831+
// 32 bits
2832+
APInt i32c(32, 0x1234'5678); // +ve
2833+
APInt i32d(32, 0x10AB'CDEF); // +ve
2834+
APInt i32e(32, 0xFEDC'BA09); // -ve
2835+
2836+
EXPECT_EQ(0x012F'7D02, APIntOps::mulhs(i32c, i32d));
2837+
EXPECT_EQ(0xFFEB'4988, APIntOps::mulhs(i32c, i32e));
2838+
EXPECT_EQ(0x0001'4B68, APIntOps::mulhs(i32e, i32e));
2839+
2840+
// 64 bits
2841+
APInt i64c(64, 0x1234'5678'90AB'CDEF); // +ve
2842+
APInt i64d(64, 0x1234'5678'90FE'DCBA); // +ve
2843+
APInt i64e(64, 0xFEDC'BA09'8765'4321); // -ve
2844+
2845+
EXPECT_EQ(0x014B'66DC'328E'10C1, APIntOps::mulhs(i64c, i64d));
2846+
EXPECT_EQ(0xFFEB'4988'12C6'6C68, APIntOps::mulhs(i64c, i64e));
2847+
EXPECT_EQ(0x0001'4B68'2174'FA18, APIntOps::mulhs(i64e, i64e));
2848+
2849+
// 128 bits
2850+
APInt i128c(128, "1234567890ABCDEF1234567890ABCDEF", 16); // +ve
2851+
APInt i128d(128, "1234567890FEDCBA1234567890FEDCBA", 16); // +ve
2852+
APInt i128e(128, "FEDCBA0987654321FEDCBA0987654321", 16); // -ve
2853+
2854+
i128Res = APIntOps::mulhs(i128c, i128d);
2855+
EXPECT_EQ(APInt(128, "14B66DC328E10C1FE303DF9EA0B2529", 16), i128Res);
2856+
2857+
i128Res = APIntOps::mulhs(i128c, i128e);
2858+
EXPECT_EQ(APInt(128, "FFEB498812C66C68D4552DB89B8EBF8F", 16), i128Res);
2859+
}
2860+
28082861
TEST(APIntTest, RoundingUDiv) {
28092862
for (uint64_t Ai = 1; Ai <= 255; Ai++) {
28102863
APInt A(8, Ai);

llvm/unittests/Support/DivisionByConstantTest.cpp

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,6 @@ template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) {
2121
} while (++N != 0);
2222
}
2323

24-
APInt MULHS(APInt X, APInt Y) {
25-
unsigned Bits = X.getBitWidth();
26-
unsigned WideBits = 2 * Bits;
27-
return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits);
28-
}
29-
3024
APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
3125
SignedDivisionByConstantInfo Magics) {
3226
unsigned Bits = Numerator.getBitWidth();
@@ -48,7 +42,7 @@ APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
4842
}
4943

5044
// Multiply the numerator by the magic value.
51-
APInt Q = MULHS(Numerator, Magics.Magic);
45+
APInt Q = APIntOps::mulhs(Numerator, Magics.Magic);
5246

5347
// (Optionally) Add/subtract the numerator using Factor.
5448
Factor = Numerator * Factor;
@@ -89,12 +83,6 @@ TEST(SignedDivisionByConstantTest, Test) {
8983
}
9084
}
9185

92-
APInt MULHU(APInt X, APInt Y) {
93-
unsigned Bits = X.getBitWidth();
94-
unsigned WideBits = 2 * Bits;
95-
return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits);
96-
}
97-
9886
APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
9987
bool LZOptimization,
10088
bool AllowEvenDivisorOptimization, bool ForceNPQ,
@@ -129,7 +117,7 @@ APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
129117
APInt Q = Numerator.lshr(PreShift);
130118

131119
// Multiply the numerator by the magic value.
132-
Q = MULHU(Q, Magics.Magic);
120+
Q = APIntOps::mulhu(Q, Magics.Magic);
133121

134122
if (UseNPQ || ForceNPQ) {
135123
APInt NPQ = Numerator - Q;
@@ -138,7 +126,7 @@ APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
138126
// MULHU to act as a SRL-by-1 for NPQ, else multiply by zero.
139127
APInt NPQ_Scalar = NPQ.lshr(1);
140128
(void)NPQ_Scalar;
141-
NPQ = MULHU(NPQ, NPQFactor);
129+
NPQ = APIntOps::mulhu(NPQ, NPQFactor);
142130
assert(!UseNPQ || NPQ == NPQ_Scalar);
143131

144132
Q = NPQ + Q;

llvm/unittests/Support/KnownBitsTest.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -537,19 +537,13 @@ TEST(KnownBitsTest, BinaryExhaustive) {
537537
[](const KnownBits &Known1, const KnownBits &Known2) {
538538
return KnownBits::mulhs(Known1, Known2);
539539
},
540-
[](const APInt &N1, const APInt &N2) {
541-
unsigned Bits = N1.getBitWidth();
542-
return (N1.sext(2 * Bits) * N2.sext(2 * Bits)).extractBits(Bits, Bits);
543-
},
540+
[](const APInt &N1, const APInt &N2) { return APIntOps::mulhs(N1, N2); },
544541
checkCorrectnessOnlyBinary);
545542
testBinaryOpExhaustive(
546543
[](const KnownBits &Known1, const KnownBits &Known2) {
547544
return KnownBits::mulhu(Known1, Known2);
548545
},
549-
[](const APInt &N1, const APInt &N2) {
550-
unsigned Bits = N1.getBitWidth();
551-
return (N1.zext(2 * Bits) * N2.zext(2 * Bits)).extractBits(Bits, Bits);
552-
},
546+
[](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
553547
checkCorrectnessOnlyBinary);
554548
}
555549

0 commit comments

Comments
 (0)