-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[ADT] Add signed and unsigned mulh to APInt #84719
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-llvm-adt Author: Atousa Duprat (Atousa) ChangesThis addresses issue #84207 Full diff: https://github.com/llvm/llvm-project/pull/84719.diff 4 Files Affected:
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 1fc3c7b2236a17..e3ce01af20e7c6 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -2193,6 +2193,18 @@ inline const APInt absdiff(const APInt &A, const APInt &B) {
return A.uge(B) ? (A - B) : (B - A);
}
+/// Return the high bits of the signed multiplication of C1 and C2
+APInt mulHiS(const APInt &C1, const APInt &C2);
+
+/// Return the high bits of the unsigned multiplication of C1 and C2
+APInt mulHiU(const APInt &C1, const APInt &C2);
+
+/// Return the low bits of the signed multiplication of C1 and C2
+APInt mulLoS(const APInt &C1, const APInt &C2);
+
+/// Return the low bits of the unsigned multiplication of C1 and C2
+APInt mulLoU(const APInt &C1, const APInt &C2);
+
/// Compute GCD of two unsigned APInt values.
///
/// This function returns the greatest common divisor of the two APInt values
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index f7ace79e8c51d4..53697e1ffb5b71 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6009,18 +6009,10 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
if (!C2.getBoolValue())
break;
return C1.srem(C2);
- case ISD::MULHS: {
- unsigned FullWidth = C1.getBitWidth() * 2;
- APInt C1Ext = C1.sext(FullWidth);
- APInt C2Ext = C2.sext(FullWidth);
- return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
- }
- case ISD::MULHU: {
- unsigned FullWidth = C1.getBitWidth() * 2;
- APInt C1Ext = C1.zext(FullWidth);
- APInt C2Ext = C2.zext(FullWidth);
- return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
- }
+ case ISD::MULHS:
+ return APIntOps::mulHiS(C1, C2);
+ case ISD::MULHU:
+ return APIntOps::mulHiU(C1, C2);
case ISD::AVGFLOORS: {
unsigned FullWidth = C1.getBitWidth() + 1;
APInt C1Ext = C1.sext(FullWidth);
@@ -6706,8 +6698,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
break;
case ISD::UDIV:
case ISD::UREM:
- case ISD::MULHU:
case ISD::MULHS:
+ case ISD::MULHU:
case ISD::SDIV:
case ISD::SREM:
case ISD::SADDSAT:
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index e686b976523302..46c469cafa88cc 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -3094,3 +3094,31 @@ void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src,
memcpy(Dst + sizeof(uint64_t) - LoadBytes, Src, LoadBytes);
}
}
+
+APInt APIntOps::mulHiS(const APInt &C1, const APInt &C2) {
+ unsigned FullWidth = C1.getBitWidth() * 2;
+ APInt C1Ext = C1.sext(FullWidth);
+ APInt C2Ext = C2.sext(FullWidth);
+ return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+}
+
+APInt APIntOps::mulHiU(const APInt &C1, const APInt &C2) {
+ unsigned FullWidth = C1.getBitWidth() * 2;
+ APInt C1Ext = C1.zext(FullWidth);
+ APInt C2Ext = C2.zext(FullWidth);
+ return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+}
+
+APInt APIntOps::mulLoS(const APInt &C1, const APInt &C2) {
+ unsigned FullWidth = C1.getBitWidth() * 2;
+ APInt C1Ext = C1.sext(FullWidth);
+ APInt C2Ext = C2.sext(FullWidth);
+ return (C1Ext * C2Ext).trunc(C1.getBitWidth());
+}
+
+APInt APIntOps::mulLoU(const APInt &C1, const APInt &C2) {
+ unsigned FullWidth = C1.getBitWidth() * 2;
+ APInt C1Ext = C1.zext(FullWidth);
+ APInt C2Ext = C2.zext(FullWidth);
+ return (C1Ext * C2Ext).trunc(C1.getBitWidth());
+}
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 24324822356bf6..1597ac6f331d47 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -10,6 +10,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Alignment.h"
#include "gtest/gtest.h"
@@ -2805,6 +2806,89 @@ TEST(APIntTest, multiply) {
EXPECT_EQ(64U, i96.countr_zero());
}
+TEST(APIntOpsTest, MulHiLo) {
+
+ // Unsigned
+
+ // 32 bits
+ APInt i32a(32, 0x0001'E235);
+ APInt i32b(32, 0xF623'55AD);
+ EXPECT_EQ(0x0001'CFA1, APIntOps::mulHiU(i32a, i32b));
+ EXPECT_EQ(0x7CA0'76D1, APIntOps::mulLoU(i32a, i32b));
+
+ // 64 bits
+ APInt i64a(64, 0x1234'5678'90AB'CDEF);
+ APInt i64b(64, 0xFEDC'BA09'8765'4321);
+ EXPECT_EQ(0x121F'A000'A372'3A57, APIntOps::mulHiU(i64a, i64b));
+ EXPECT_EQ(0xC24A'442F'E556'18CF, APIntOps::mulLoU(i64a, i64b));
+
+ // 128 bits
+ APInt i128a(128, "1234567890ABCDEF1234567890ABCDEF", 16);
+ APInt i128b(128, "FEDCBA0987654321FEDCBA0987654321", 16);
+ APInt i128ResHi = APIntOps::mulHiU(i128a, i128b);
+ std::string strResHi = toString(i128ResHi, 16, false, true, true, true);
+ EXPECT_STREQ("0x121F'A000'A372'3A57'E689'8431'2C3A'8D7E", strResHi.c_str());
+ APInt i128ResLo = APIntOps::mulLoU(i128a, i128b);
+ std::string strResLo = toString(i128ResLo, 16, false, true, true, true);
+ EXPECT_STREQ("0x96B4'2860'6E1E'6BF5'C24A'442F'E556'18CF", strResLo.c_str());
+
+ // Signed
+
+ // 32 bits
+ APInt i32c(32, 0x1234'5678); // +ve
+ APInt i32d(32, 0x10AB'CDEF); // +ve
+ APInt i32e(32, 0xFEDC'BA09); // -ve
+
+ EXPECT_EQ(0x012F'7D02, APIntOps::mulHiS(i32c, i32d));
+ EXPECT_EQ(0x2A42'D208, APIntOps::mulLoS(i32c, i32d));
+
+ EXPECT_EQ(0xFFEB'4988, APIntOps::mulHiS(i32c, i32e));
+ EXPECT_EQ(0x09CA'3A38, APIntOps::mulLoS(i32c, i32e));
+
+ EXPECT_EQ(0x0001'4B68, APIntOps::mulHiS(i32e, i32e));
+ EXPECT_EQ(0x22A9'1451, APIntOps::mulLoS(i32e, i32e));
+
+ // 64 bits
+ APInt i64c(64, 0x1234'5678'90AB'CDEF); // +ve
+ APInt i64d(64, 0x1234'5678'90FE'DCBA); // +ve
+ APInt i64e(64, 0xFEDC'BA09'8765'4321); // -ve
+
+ EXPECT_EQ(0x014B'66DC'328E'10C1, APIntOps::mulHiS(i64c, i64d));
+ EXPECT_EQ(0xFB99'7041'84EF'03A6, APIntOps::mulLoS(i64c, i64d));
+
+ EXPECT_EQ(0xFFEB'4988'12C6'6C68, APIntOps::mulHiS(i64c, i64e));
+ EXPECT_EQ(0xC24A'442F'E556'18CF, APIntOps::mulLoS(i64c, i64e));
+
+ EXPECT_EQ(0x0001'4B68'2174'FA18, APIntOps::mulHiS(i64e, i64e));
+ EXPECT_EQ(0xCEFE'A12C'D7A4'4A41, APIntOps::mulLoS(i64e, i64e));
+
+ // 128 bits
+ APInt i128c(128, "1234567890ABCDEF1234567890ABCDEF", 16); // +ve
+ APInt i128d(128, "1234567890FEDCBA1234567890FEDCBA", 16); // +ve
+ APInt i128e(128, "FEDCBA0987654321FEDCBA0987654321", 16); // -ve
+
+ i128ResHi = APIntOps::mulHiS(i128c, i128d);
+ strResHi = toString(i128ResHi, 16, false, true, true, true);
+ EXPECT_STREQ("0x14B'66DC'328E'10C1'FE30'3DF9'EA0B'2529", strResHi.c_str());
+ i128ResLo = APIntOps::mulLoS(i128c, i128d);
+ strResLo = toString(i128ResLo, 16, false, true, true, true);
+ EXPECT_STREQ("0xF87E'475F'3C6C'180D'FB99'7041'84EF'03A6", strResLo.c_str());
+
+ i128ResHi = APIntOps::mulHiS(i128c, i128e);
+ strResHi = toString(i128ResHi, 16, false, true, true, true);
+ EXPECT_STREQ("0xFFEB'4988'12C6'6C68'D455'2DB8'9B8E'BF8F", strResHi.c_str());
+ i128ResLo = APIntOps::mulLoS(i128c, i128e);
+ strResLo = toString(i128ResLo, 16, false, true, true, true);
+ EXPECT_STREQ("0x96B4'2860'6E1E'6BF5'C24A'442F'E556'18CF", strResLo.c_str());
+
+ i128ResHi = APIntOps::mulHiS(i128e, i128e);
+ strResHi = toString(i128ResHi, 16, false, true, true, true);
+ EXPECT_STREQ("0x1'4B68'2174'FA18'CCBA'AC10'2958'C4B5", strResHi.c_str());
+ i128ResLo = APIntOps::mulLoS(i128e, i128e);
+ strResLo = toString(i128ResLo, 16, false, true, true, true);
+ EXPECT_STREQ("0x9BB8'01D4'DF88'14DC'CEFE'A12C'D7A4'4A41", strResLo.c_str());
+}
+
TEST(APIntTest, RoundingUDiv) {
for (uint64_t Ai = 1; Ai <= 255; Ai++) {
APInt A(8, Ai);
|
llvm/lib/Support/APInt.cpp
Outdated
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth()); | ||
} | ||
|
||
APInt APIntOps::mulLoS(const APInt &C1, const APInt &C2) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't this equivalent to mul
? The lower bits of the result don't care about the extension at all.
llvm/unittests/ADT/APIntTest.cpp
Outdated
APInt i128b(128, "FEDCBA0987654321FEDCBA0987654321", 16); | ||
APInt i128ResHi = APIntOps::mulHiU(i128a, i128b); | ||
std::string strResHi = toString(i128ResHi, 16, false, true, true, true); | ||
EXPECT_STREQ("0x121F'A000'A372'3A57'E689'8431'2C3A'8D7E", strResHi.c_str()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not regular EXPECT_EQ and avoid the .c_str?
llvm/unittests/ADT/APIntTest.cpp
Outdated
APInt i128b(128, "FEDCBA0987654321FEDCBA0987654321", 16); | ||
APInt i128ResHi = APIntOps::mulHiU(i128a, i128b); | ||
std::string strResHi = toString(i128ResHi, 16, false, true, true, true); | ||
EXPECT_STREQ("0x121F'A000'A372'3A57'E689'8431'2C3A'8D7E", strResHi.c_str()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we comparing strings? Why not create a 128-bit APInt for the expected result and compare that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove the "lo" functions as they are identical to operator*
. For the "hi" functions I'd prefer the names mulhu and mulhs, since these names are already used elsewhere in LLVM (in KnownBits and in SelectionDAG).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please can you update the KnownBitsTest mulhu/mulhs exhaustive tests to use these for the APInt references
llvm/include/llvm/ADT/APInt.h
Outdated
APInt mulHiS(const APInt &C1, const APInt &C2); | ||
|
||
/// Return the high bits of the unsigned multiplication of C1 and C2 | ||
APInt mulHiU(const APInt &C1, const APInt &C2); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use the mulhs/mulhu names - as thats the names we already use elsewhere (KnownBIts, ISD etc.)
llvm/include/llvm/ADT/APInt.h
Outdated
APInt mulLoS(const APInt &C1, const APInt &C2); | ||
|
||
/// Return the low bits of the unsigned multiplication of C1 and C2 | ||
APInt mulLoU(const APInt &C1, const APInt &C2); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are just mul - which we already have - remove them
case ISD::MULHS: | ||
case ISD::MULHU: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Superfluous - remove it
llvm/include/llvm/ADT/APInt.h
Outdated
@@ -2193,6 +2193,12 @@ inline const APInt absdiff(const APInt &A, const APInt &B) { | |||
return A.uge(B) ? (A - B) : (B - A); | |||
} | |||
|
|||
/// Return the high bits of the signed multiplication of C1 and C2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Uber-nit: comments should be full sentences, per coding standards.
/// Return the high bits of the signed multiplication of C1 and C2 | |
/// Return the high bits of the signed multiplication of C1 and C2. |
llvm/unittests/ADT/APIntTest.cpp
Outdated
@@ -10,6 +10,7 @@ | |||
#include "llvm/ADT/ArrayRef.h" | |||
#include "llvm/ADT/DenseMap.h" | |||
#include "llvm/ADT/SmallString.h" | |||
#include "llvm/ADT/StringExtras.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this still needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks - LGTM overall, just a couple of nits inline.
llvm/lib/Support/APInt.cpp
Outdated
@@ -3094,3 +3094,19 @@ void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src, | |||
memcpy(Dst + sizeof(uint64_t) - LoadBytes, Src, LoadBytes); | |||
} | |||
} | |||
|
|||
APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) { | |||
assert(C1.getBitWidth() >= C2.getBitWidth()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(style) Assertions should include a message string as well:
assert(C1.getBitWidth() >= C2.getBitWidth() && "BitWidth mismatch");
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this >=
instead of ==
? I really don't like that! Perhaps there is a bug somewhere else that should be fixed first?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The DivisionByConstantTest code allows different bitwidths - which isn't great - we should probably drop the DivisionByConstantTest diff and assert for ==
- we could then fix DivisionByConstantTest in a future patch?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I'd prefer that. If the bit widths don't match then it is no longer clear (to me) what these functions are supposed to do.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please raise a ticket for the DivisionByConstantTest. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you also address things I raised on the other PR for mulh?
#84609 (review) and the other comments that follow
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One final minor
llvm/lib/Support/APInt.cpp
Outdated
@@ -3094,3 +3094,19 @@ void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src, | |||
memcpy(Dst + sizeof(uint64_t) - LoadBytes, Src, LoadBytes); | |||
} | |||
} | |||
|
|||
APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) { | |||
assert(C1.getBitWidth() == C2.getBitWidth()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(style) Assertion message: assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
llvm/lib/Support/APInt.cpp
Outdated
} | ||
|
||
APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) { | ||
assert(C1.getBitWidth() == C2.getBitWidth()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(style) Assertion message: assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths");
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The PR subject line doesn't match the function names anymore. Could you update it?
Thanks for all the fixes so far. In general looks good, my only remaining comments are related to the documentation.
llvm/include/llvm/ADT/APInt.h
Outdated
/// Return the high bits of the signed multiplication of C1 and C2. | ||
APInt mulhs(const APInt &C1, const APInt &C2); | ||
|
||
/// Return the high bits of the unsigned multiplication of C1 and C2. | ||
APInt mulhu(const APInt &C1, const APInt &C2); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this could use some more explanation about what this actually calculates. See #84609 (comment) and #84609 (comment).
Also, I think we could make the first line of description a little bit more precise, e.g.:
/// Return the high bits of the signed multiplication of C1 and C2. | |
APInt mulhs(const APInt &C1, const APInt &C2); | |
/// Return the high bits of the unsigned multiplication of C1 and C2. | |
APInt mulhu(const APInt &C1, const APInt &C2); | |
/// Performs (2*N)-bit multiplication on sign-extended operands. | |
/// Returns the high N bits of the multiplication result. | |
APInt mulhs(const APInt &C1, const APInt &C2); | |
/// Performs (2*N)-bit multiplication on zero-extended operands. | |
/// Returns the high N bits of the multiplication result. | |
APInt mulhu(const APInt &C1, const APInt &C2); |
See https://mlir.llvm.org/docs/Dialects/ArithOps/#arithmulsi_extended-arithmulsiextendedop
97544fd
to
a062515
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM - cheers!
@kuhar Are you happy with the mlir changes? |
@@ -434,9 +434,7 @@ arith::MulSIExtendedOp::fold(FoldAdaptor adaptor, | |||
// Invoke the constant fold helper again to calculate the 'high' result. | |||
Attribute highAttr = constFoldBinaryOp<IntegerAttr>( | |||
adaptor.getOperands(), [](const APInt &a, const APInt &b) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe this can just be ...(adaptor.getOperands(), llvm::APIntOps::mulhs);
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe this can just be
...(adaptor.getOperands(), llvm::APIntOps::mulhs);
?
Not sure, I think it is clear as it is.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd recommend any further refactor changes like that as followups for somebody (not necessarily @Atousa) - to ensure the changes in this PR are just the move to the APIntOps::mulh* helpers.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for the changes.
Note that this could also be implement without btiwidth extension, similar to #85212. I don't know if the performance of the implementation matters much here, these functions are not on any fast path on the MLIR side. If you need a faster implementation, I wrote one for WebGPU:
Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter, |
3a2283a
to
cf7985f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please can you rebase now that the #84431 has landed
ping @Atousa |
I missed that. I thought it has been merged. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Atousa Please can you add back the KnownBitsTest.cpp change?
This addresses issue llvm#84207
@RKSimon I don't know why Linux build failed ? |
I think it was just a ccache failure - I've merged the patch and will keep an eye on it for failures |
Fixes #84207