-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[ADT] Add implementations for mulhs and mulhu to APInt #84609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-llvm-support Author: Shourya Goel (Sh0g0-1758) Changes: Fixes: #84207 Full diff: https://github.com/llvm/llvm-project/pull/84609.diff 5 Files Affected:
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 1fc3c7b2236a17..711eb3a9c3fc6d 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -2193,6 +2193,12 @@ inline const APInt absdiff(const APInt &A, const APInt &B) {
return A.uge(B) ? (A - B) : (B - A);
}
+/// Compute the higher order bits of unsigned multiplication of two APInts
+APInt mulhu(const APInt &C1, const APInt &C2);
+
+/// Compute the higher order bits of signed multiplication of two APInts
+APInt mulhs(const APInt &C1, const APInt &C2);
+
/// Compute GCD of two unsigned APInt values.
///
/// This function returns the greatest common divisor of the two APInt values
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 50f53bbb04b62d..4f533b7d055129 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6015,17 +6015,11 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
if (!C2.getBoolValue())
break;
return C1.srem(C2);
- case ISD::MULHS: {
- unsigned FullWidth = C1.getBitWidth() * 2;
- APInt C1Ext = C1.sext(FullWidth);
- APInt C2Ext = C2.sext(FullWidth);
- return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
- }
case ISD::MULHU: {
- unsigned FullWidth = C1.getBitWidth() * 2;
- APInt C1Ext = C1.zext(FullWidth);
- APInt C2Ext = C2.zext(FullWidth);
- return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+ return APIntOps::mulhu(C1, C2);
+ }
+ case ISD::MULHS: {
+ return APIntOps::mulhs(C1, C2);
}
case ISD::AVGFLOORS: {
unsigned FullWidth = C1.getBitWidth() + 1;
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index e686b976523302..9ce10ada67e9e1 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -3067,6 +3067,22 @@ void llvm::StoreIntToMemory(const APInt &IntVal, uint8_t *Dst,
}
}
+APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
+ // Return higher order bits for unsigned (C1 * C2)
+ unsigned FullWidth = C1.getBitWidth() * 2;
+ APInt C1Ext = C1.zext(FullWidth);
+ APInt C2Ext = C2.zext(FullWidth);
+ return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+}
+
+APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) {
+ // Return higher order bits for signed (C1 * C2)
+ unsigned FullWidth = C1.getBitWidth() * 2;
+ APInt C1Ext = C1.sext(FullWidth);
+ APInt C2Ext = C2.sext(FullWidth);
+ return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+}
+
/// LoadIntFromMemory - Loads the integer stored in the LoadBytes bytes starting
/// from Src into IntVal, which is assumed to be wide enough and to hold zero.
void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src,
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 24324822356bf6..9995aeaff92871 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -2805,6 +2805,26 @@ TEST(APIntTest, multiply) {
EXPECT_EQ(64U, i96.countr_zero());
}
+TEST(APIntTest, Hmultiply) {
+ APInt i1048576(32, 1048576);
+
+ EXPECT_EQ(APInt(32, 256), APIntOps::mulhu(i1048576, i1048576));
+
+ APInt i16777216(32, 16777216);
+ APInt i32768(32, 32768);
+
+ EXPECT_EQ(APInt(32, 128), APIntOps::mulhu(i16777216, i32768));
+ EXPECT_EQ(APInt(32, 128), APIntOps::mulhu(i32768, i16777216));
+
+ APInt i2097152(32, -2097152);
+ APInt i524288(32, 524288);
+
+ EXPECT_EQ(APInt(32, 1024), APIntOps::mulhs(i2097152, i2097152));
+
+ EXPECT_EQ(APInt(32, -256), APIntOps::mulhs(i2097152, i524288));
+ EXPECT_EQ(APInt(32, -256), APIntOps::mulhs(i524288, i2097152));
+}
+
TEST(APIntTest, RoundingUDiv) {
for (uint64_t Ai = 1; Ai <= 255; Ai++) {
APInt A(8, Ai);
diff --git a/llvm/unittests/Support/DivisionByConstantTest.cpp b/llvm/unittests/Support/DivisionByConstantTest.cpp
index 2b17f98bb75b2f..8e0c78fe85654a 100644
--- a/llvm/unittests/Support/DivisionByConstantTest.cpp
+++ b/llvm/unittests/Support/DivisionByConstantTest.cpp
@@ -21,12 +21,6 @@ template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) {
} while (++N != 0);
}
-APInt MULHS(APInt X, APInt Y) {
- unsigned Bits = X.getBitWidth();
- unsigned WideBits = 2 * Bits;
- return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits);
-}
-
APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
SignedDivisionByConstantInfo Magics) {
unsigned Bits = Numerator.getBitWidth();
@@ -48,7 +42,7 @@ APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
}
// Multiply the numerator by the magic value.
- APInt Q = MULHS(Numerator, Magics.Magic);
+ APInt Q = APIntOps::mulhs(Numerator, Magics.Magic);
// (Optionally) Add/subtract the numerator using Factor.
Factor = Numerator * Factor;
@@ -89,12 +83,6 @@ TEST(SignedDivisionByConstantTest, Test) {
}
}
-APInt MULHU(APInt X, APInt Y) {
- unsigned Bits = X.getBitWidth();
- unsigned WideBits = 2 * Bits;
- return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits);
-}
-
APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
bool LZOptimization,
bool AllowEvenDivisorOptimization, bool ForceNPQ,
@@ -129,16 +117,16 @@ APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
APInt Q = Numerator.lshr(PreShift);
// Multiply the numerator by the magic value.
- Q = MULHU(Q, Magics.Magic);
+ Q = APIntOps::mulhu(Q, Magics.Magic);
if (UseNPQ || ForceNPQ) {
APInt NPQ = Numerator - Q;
// For vectors we might have a mix of non-NPQ/NPQ paths, so use
- // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero.
+ // mulhu to act as a SRL-by-1 for NPQ, else multiply by zero.
APInt NPQ_Scalar = NPQ.lshr(1);
(void)NPQ_Scalar;
- NPQ = MULHU(NPQ, NPQFactor);
+ NPQ = APIntOps::mulhu(NPQ, NPQFactor);
assert(!UseNPQ || NPQ == NPQ_Scalar);
Q = NPQ + Q;
|
@llvm/pr-subscribers-llvm-adt Author: Shourya Goel (Sh0g0-1758) Changes: Fixes: #84207 Full diff: https://github.com/llvm/llvm-project/pull/84609.diff 5 Files Affected:
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 1fc3c7b2236a17..711eb3a9c3fc6d 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -2193,6 +2193,12 @@ inline const APInt absdiff(const APInt &A, const APInt &B) {
return A.uge(B) ? (A - B) : (B - A);
}
+/// Compute the higher order bits of unsigned multiplication of two APInts
+APInt mulhu(const APInt &C1, const APInt &C2);
+
+/// Compute the higher order bits of signed multiplication of two APInts
+APInt mulhs(const APInt &C1, const APInt &C2);
+
/// Compute GCD of two unsigned APInt values.
///
/// This function returns the greatest common divisor of the two APInt values
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 50f53bbb04b62d..4f533b7d055129 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6015,17 +6015,11 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
if (!C2.getBoolValue())
break;
return C1.srem(C2);
- case ISD::MULHS: {
- unsigned FullWidth = C1.getBitWidth() * 2;
- APInt C1Ext = C1.sext(FullWidth);
- APInt C2Ext = C2.sext(FullWidth);
- return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
- }
case ISD::MULHU: {
- unsigned FullWidth = C1.getBitWidth() * 2;
- APInt C1Ext = C1.zext(FullWidth);
- APInt C2Ext = C2.zext(FullWidth);
- return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+ return APIntOps::mulhu(C1, C2);
+ }
+ case ISD::MULHS: {
+ return APIntOps::mulhs(C1, C2);
}
case ISD::AVGFLOORS: {
unsigned FullWidth = C1.getBitWidth() + 1;
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index e686b976523302..9ce10ada67e9e1 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -3067,6 +3067,22 @@ void llvm::StoreIntToMemory(const APInt &IntVal, uint8_t *Dst,
}
}
+APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
+ // Return higher order bits for unsigned (C1 * C2)
+ unsigned FullWidth = C1.getBitWidth() * 2;
+ APInt C1Ext = C1.zext(FullWidth);
+ APInt C2Ext = C2.zext(FullWidth);
+ return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+}
+
+APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) {
+ // Return higher order bits for signed (C1 * C2)
+ unsigned FullWidth = C1.getBitWidth() * 2;
+ APInt C1Ext = C1.sext(FullWidth);
+ APInt C2Ext = C2.sext(FullWidth);
+ return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+}
+
/// LoadIntFromMemory - Loads the integer stored in the LoadBytes bytes starting
/// from Src into IntVal, which is assumed to be wide enough and to hold zero.
void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src,
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 24324822356bf6..9995aeaff92871 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -2805,6 +2805,26 @@ TEST(APIntTest, multiply) {
EXPECT_EQ(64U, i96.countr_zero());
}
+TEST(APIntTest, Hmultiply) {
+ APInt i1048576(32, 1048576);
+
+ EXPECT_EQ(APInt(32, 256), APIntOps::mulhu(i1048576, i1048576));
+
+ APInt i16777216(32, 16777216);
+ APInt i32768(32, 32768);
+
+ EXPECT_EQ(APInt(32, 128), APIntOps::mulhu(i16777216, i32768));
+ EXPECT_EQ(APInt(32, 128), APIntOps::mulhu(i32768, i16777216));
+
+ APInt i2097152(32, -2097152);
+ APInt i524288(32, 524288);
+
+ EXPECT_EQ(APInt(32, 1024), APIntOps::mulhs(i2097152, i2097152));
+
+ EXPECT_EQ(APInt(32, -256), APIntOps::mulhs(i2097152, i524288));
+ EXPECT_EQ(APInt(32, -256), APIntOps::mulhs(i524288, i2097152));
+}
+
TEST(APIntTest, RoundingUDiv) {
for (uint64_t Ai = 1; Ai <= 255; Ai++) {
APInt A(8, Ai);
diff --git a/llvm/unittests/Support/DivisionByConstantTest.cpp b/llvm/unittests/Support/DivisionByConstantTest.cpp
index 2b17f98bb75b2f..8e0c78fe85654a 100644
--- a/llvm/unittests/Support/DivisionByConstantTest.cpp
+++ b/llvm/unittests/Support/DivisionByConstantTest.cpp
@@ -21,12 +21,6 @@ template <typename Fn> static void EnumerateAPInts(unsigned Bits, Fn TestFn) {
} while (++N != 0);
}
-APInt MULHS(APInt X, APInt Y) {
- unsigned Bits = X.getBitWidth();
- unsigned WideBits = 2 * Bits;
- return (X.sext(WideBits) * Y.sext(WideBits)).lshr(Bits).trunc(Bits);
-}
-
APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
SignedDivisionByConstantInfo Magics) {
unsigned Bits = Numerator.getBitWidth();
@@ -48,7 +42,7 @@ APInt SignedDivideUsingMagic(APInt Numerator, APInt Divisor,
}
// Multiply the numerator by the magic value.
- APInt Q = MULHS(Numerator, Magics.Magic);
+ APInt Q = APIntOps::mulhs(Numerator, Magics.Magic);
// (Optionally) Add/subtract the numerator using Factor.
Factor = Numerator * Factor;
@@ -89,12 +83,6 @@ TEST(SignedDivisionByConstantTest, Test) {
}
}
-APInt MULHU(APInt X, APInt Y) {
- unsigned Bits = X.getBitWidth();
- unsigned WideBits = 2 * Bits;
- return (X.zext(WideBits) * Y.zext(WideBits)).lshr(Bits).trunc(Bits);
-}
-
APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
bool LZOptimization,
bool AllowEvenDivisorOptimization, bool ForceNPQ,
@@ -129,16 +117,16 @@ APInt UnsignedDivideUsingMagic(const APInt &Numerator, const APInt &Divisor,
APInt Q = Numerator.lshr(PreShift);
// Multiply the numerator by the magic value.
- Q = MULHU(Q, Magics.Magic);
+ Q = APIntOps::mulhu(Q, Magics.Magic);
if (UseNPQ || ForceNPQ) {
APInt NPQ = Numerator - Q;
// For vectors we might have a mix of non-NPQ/NPQ paths, so use
- // MULHU to act as a SRL-by-1 for NPQ, else multiply by zero.
+ // mulhu to act as a SRL-by-1 for NPQ, else multiply by zero.
APInt NPQ_Scalar = NPQ.lshr(1);
(void)NPQ_Scalar;
- NPQ = MULHU(NPQ, NPQFactor);
+ NPQ = APIntOps::mulhu(NPQ, NPQFactor);
assert(!UseNPQ || NPQ == NPQ_Scalar);
Q = NPQ + Q;
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
The issue also suggests updating KnownBitsTest.cpp.
@jayfoad, regarding changes in |
No, it's this expression in the mulhs test (and similar for mulhu) that can be replaced with a call to the new APInt method: |
Made the changes. I was under the impression that something similar to |
@jayfoad, some Flang tests are failing for the Windows build. I know this is a pretty common issue, though not related to the PR. If there are no further suggestions, I guess it's safe to land? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A couple of minor comments.
@@ -3067,6 +3067,22 @@ void llvm::StoreIntToMemory(const APInt &IntVal, uint8_t *Dst, | |||
} | |||
} | |||
|
|||
APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add same-bitwidth assertion: assert(C1.BitWidth == C2.BitWidth && "Bit widths must be the same");
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be C1.getBitWidth(). Updated the code with it.
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth()); | ||
} | ||
|
||
APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add same-bitwidth assertion: assert(C1.BitWidth == C2.BitWidth && "Bit widths must be the same");
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto.
@RKSimon, reverting the assertion statements because
Which means the number of bits of the two APInts can't be the same. Please correct me if I am wrong. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! We have a few more places in MLIR that could be switched to this:
llvm-project/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
Lines 251 to 264 in 110141b
auto highBits = constFoldBinaryOp<IntegerAttr>( | |
{lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) { | |
unsigned bitWidth = a.getBitWidth(); | |
APInt c; | |
if (IsSigned) { | |
c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2); | |
} else { | |
c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2); | |
} | |
return c.extractBits(bitWidth, bitWidth); // Extract high result | |
}); | |
if (!highBits) | |
return failure(); |
llvm-project/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Lines 430 to 445 in 110141b
// mulsi_extended(cst_a, cst_b) -> cst_low, cst_high | |
if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>( | |
adaptor.getOperands(), | |
[](const APInt &a, const APInt &b) { return a * b; })) { | |
// Invoke the constant fold helper again to calculate the 'high' result. | |
Attribute highAttr = constFoldBinaryOp<IntegerAttr>( | |
adaptor.getOperands(), [](const APInt &a, const APInt &b) { | |
unsigned bitWidth = a.getBitWidth(); | |
APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2); | |
return fullProduct.extractBits(bitWidth, bitWidth); | |
}); | |
assert(highAttr && "Unexpected constant-folding failure"); | |
results.push_back(lowAttr); | |
results.push_back(highAttr); | |
return success(); |
llvm-project/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Lines 487 to 502 in 110141b
// mului_extended(cst_a, cst_b) -> cst_low, cst_high | |
if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>( | |
adaptor.getOperands(), | |
[](const APInt &a, const APInt &b) { return a * b; })) { | |
// Invoke the constant fold helper again to calculate the 'high' result. | |
Attribute highAttr = constFoldBinaryOp<IntegerAttr>( | |
adaptor.getOperands(), [](const APInt &a, const APInt &b) { | |
unsigned bitWidth = a.getBitWidth(); | |
APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2); | |
return fullProduct.extractBits(bitWidth, bitWidth); | |
}); | |
assert(highAttr && "Unexpected constant-folding failure"); | |
results.push_back(lowAttr); | |
results.push_back(highAttr); | |
return success(); |
✅ With the latest revision this PR passed the C/C++ code formatter. |
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
@kuhar, I have implemented the MLIR changes and the tests are passing. For the comments on the signed variants, I tried to keep them concise: I wrote a draft, but it was taking up a lot of space, so I decided to keep it short. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for the cleanups
@RKSimon, could you drop a confirmation about assertion statements? I think apart from that, this looks good to land. |
@Sh0g0-1758 I thought the conclusion of the discussion of #84207 was that @Atousa was working on this? |
Hm.. I was under the impression that since the work here is done, she won't be working on this but from now on we keep a 7 day period before pinging on good first issues. I also don't see a PR from her yet. But to be honest, this is becoming way too complex, I really think if the work here is done, we should land it. |
and if it's not done yet, I suppose she can collaborate on this with me with the remaining work. I do think that the main goal when doing open-source contributions should be learning and collaboration and not focusing on just solving an issue. |
Closing this because of the conversation in the original issue thread. |
@Sh0g0-1758 I don't have the wider context but this PR looks like a nice code cleanup on its own. |
I see, but I suppose another dev is going to raise a similar PR soon. |
Fixes: #84207