Skip to content

[ADT] Add signed and unsigned mulh to APInt #84719

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions llvm/include/llvm/ADT/APInt.h
Original file line number Diff line number Diff line change
Expand Up @@ -2216,6 +2216,14 @@ APInt avgCeilS(const APInt &C1, const APInt &C2);
/// Compute the ceil of the unsigned average of C1 and C2
APInt avgCeilU(const APInt &C1, const APInt &C2);

/// Signed "multiply high".
/// Performs (2*N)-bit multiplication on sign-extended operands, where N is
/// the bit width of the operands.
/// Returns the high N bits of the multiplication result.
/// C1 and C2 must have the same bit width (asserted by the implementation).
APInt mulhs(const APInt &C1, const APInt &C2);

/// Unsigned "multiply high".
/// Performs (2*N)-bit multiplication on zero-extended operands, where N is
/// the bit width of the operands.
/// Returns the high N bits of the multiplication result.
/// C1 and C2 must have the same bit width (asserted by the implementation).
APInt mulhu(const APInt &C1, const APInt &C2);

/// Compute GCD of two unsigned APInt values.
///
/// This function returns the greatest common divisor of the two APInt values
Expand Down
17 changes: 4 additions & 13 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6073,18 +6073,6 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
if (!C2.getBoolValue())
break;
return C1.srem(C2);
case ISD::MULHS: {
unsigned FullWidth = C1.getBitWidth() * 2;
APInt C1Ext = C1.sext(FullWidth);
APInt C2Ext = C2.sext(FullWidth);
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
}
case ISD::MULHU: {
unsigned FullWidth = C1.getBitWidth() * 2;
APInt C1Ext = C1.zext(FullWidth);
APInt C2Ext = C2.zext(FullWidth);
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
}
case ISD::AVGFLOORS:
return APIntOps::avgFloorS(C1, C2);
case ISD::AVGFLOORU:
Expand All @@ -6097,10 +6085,13 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
return APIntOps::abds(C1, C2);
case ISD::ABDU:
return APIntOps::abdu(C1, C2);
case ISD::MULHS:
return APIntOps::mulhs(C1, C2);
case ISD::MULHU:
return APIntOps::mulhu(C1, C2);
}
return std::nullopt;
}

// Handle constant folding with UNDEF.
// TODO: Handle more cases.
static std::optional<APInt> FoldValueWithUndef(unsigned Opcode, const APInt &C1,
Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Support/APInt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3121,3 +3121,19 @@ APInt APIntOps::avgCeilU(const APInt &C1, const APInt &C2) {
// Return ceil((C1 + C2) / 2)
return (C1 | C2) - (C1 ^ C2).lshr(1);
}

/// Signed multiply-high: the upper half of the double-width product of the
/// sign-extended operands. Both operands must share one bit width.
APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) {
  const unsigned Width = C1.getBitWidth();
  assert(Width == C2.getBitWidth() && "Unequal bitwidths");

  // Widen both operands to twice the width so the complete product is
  // representable, then keep only its upper `Width` bits.
  const APInt Product = C1.sext(2 * Width) * C2.sext(2 * Width);
  return Product.extractBits(Width, Width);
}

/// Unsigned multiply-high: the upper half of the double-width product of the
/// zero-extended operands. Both operands must share one bit width.
APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
  const unsigned Width = C1.getBitWidth();
  assert(Width == C2.getBitWidth() && "Unequal bitwidths");

  // Widen both operands to twice the width so the complete product is
  // representable, then keep only its upper `Width` bits.
  const APInt Product = C1.zext(2 * Width) * C2.zext(2 * Width);
  return Product.extractBits(Width, Width);
}
52 changes: 52 additions & 0 deletions llvm/unittests/ADT/APIntTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2841,6 +2841,58 @@ TEST(APIntTest, multiply) {
EXPECT_EQ(64U, i96.countr_zero());
}

// Checks APIntOps::mulhu and APIntOps::mulhs against precomputed high-half
// products for 32-, 64- and 128-bit operands.
TEST(APIntOpsTest, Mulh) {
  // ---- mulhu: unsigned multiply-high ----
  {
    // 32-bit operands.
    const APInt LHS32(32, 0x0001'E235);
    const APInt RHS32(32, 0xF623'55AD);
    EXPECT_EQ(0x0001'CFA1, APIntOps::mulhu(LHS32, RHS32));

    // 64-bit operands.
    const APInt LHS64(64, 0x1234'5678'90AB'CDEF);
    const APInt RHS64(64, 0xFEDC'BA09'8765'4321);
    EXPECT_EQ(0x121F'A000'A372'3A57, APIntOps::mulhu(LHS64, RHS64));

    // 128-bit operands (built from hex strings).
    const APInt LHS128(128, "1234567890ABCDEF1234567890ABCDEF", 16);
    const APInt RHS128(128, "FEDCBA0987654321FEDCBA0987654321", 16);
    const APInt High128 = APIntOps::mulhu(LHS128, RHS128);
    EXPECT_EQ(APInt(128, "121FA000A3723A57E68984312C3A8D7E", 16), High128);
  }

  // ---- mulhs: signed multiply-high ----
  {
    // 32-bit operands: two positive values and one with the sign bit set.
    const APInt PosA32(32, 0x1234'5678);
    const APInt PosB32(32, 0x10AB'CDEF);
    const APInt Neg32(32, 0xFEDC'BA09);

    EXPECT_EQ(0x012F'7D02, APIntOps::mulhs(PosA32, PosB32));
    EXPECT_EQ(0xFFEB'4988, APIntOps::mulhs(PosA32, Neg32));
    EXPECT_EQ(0x0001'4B68, APIntOps::mulhs(Neg32, Neg32));

    // 64-bit operands: two positive values and one with the sign bit set.
    const APInt PosA64(64, 0x1234'5678'90AB'CDEF);
    const APInt PosB64(64, 0x1234'5678'90FE'DCBA);
    const APInt Neg64(64, 0xFEDC'BA09'8765'4321);

    EXPECT_EQ(0x014B'66DC'328E'10C1, APIntOps::mulhs(PosA64, PosB64));
    EXPECT_EQ(0xFFEB'4988'12C6'6C68, APIntOps::mulhs(PosA64, Neg64));
    EXPECT_EQ(0x0001'4B68'2174'FA18, APIntOps::mulhs(Neg64, Neg64));

    // 128-bit operands: two positive values and one with the sign bit set.
    const APInt PosA128(128, "1234567890ABCDEF1234567890ABCDEF", 16);
    const APInt PosB128(128, "1234567890FEDCBA1234567890FEDCBA", 16);
    const APInt Neg128(128, "FEDCBA0987654321FEDCBA0987654321", 16);

    const APInt PosTimesPos = APIntOps::mulhs(PosA128, PosB128);
    EXPECT_EQ(APInt(128, "14B66DC328E10C1FE303DF9EA0B2529", 16), PosTimesPos);

    const APInt PosTimesNeg = APIntOps::mulhs(PosA128, Neg128);
    EXPECT_EQ(APInt(128, "FFEB498812C66C68D4552DB89B8EBF8F", 16), PosTimesNeg);
  }
}

TEST(APIntTest, RoundingUDiv) {
for (uint64_t Ai = 1; Ai <= 255; Ai++) {
APInt A(8, Ai);
Expand Down
10 changes: 2 additions & 8 deletions llvm/unittests/Support/KnownBitsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -553,17 +553,11 @@ TEST(KnownBitsTest, BinaryExhaustive) {
checkCorrectnessOnlyBinary);
testBinaryOpExhaustive(
KnownBits::mulhs,
[](const APInt &N1, const APInt &N2) {
unsigned Bits = N1.getBitWidth();
return (N1.sext(2 * Bits) * N2.sext(2 * Bits)).extractBits(Bits, Bits);
},
[](const APInt &N1, const APInt &N2) { return APIntOps::mulhs(N1, N2); },
checkCorrectnessOnlyBinary);
testBinaryOpExhaustive(
KnownBits::mulhu,
[](const APInt &N1, const APInt &N2) {
unsigned Bits = N1.getBitWidth();
return (N1.zext(2 * Bits) * N2.zext(2 * Bits)).extractBits(Bits, Bits);
},
[](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
checkCorrectnessOnlyBinary);
}

Expand Down
8 changes: 2 additions & 6 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,7 @@ arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
// Invoke the constant fold helper again to calculate the 'high' result.
Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [](const APInt &a, const APInt &b) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this can just be ...(adaptor.getOperands(), llvm::APIntOps::mulhs);?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> Maybe this can just be ...(adaptor.getOperands(), llvm::APIntOps::mulhs);?

Not sure, I think it is clear as it is.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd recommend any further refactor changes like that as followups for somebody (not necessarily @Atousa) - to ensure the changes in this PR are just the move to the APIntOps::mulh* helpers.

unsigned bitWidth = a.getBitWidth();
APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
return fullProduct.extractBits(bitWidth, bitWidth);
return llvm::APIntOps::mulhs(a, b);
});
assert(highAttr && "Unexpected constant-folding failure");

Expand Down Expand Up @@ -491,9 +489,7 @@ arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
// Invoke the constant fold helper again to calculate the 'high' result.
Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [](const APInt &a, const APInt &b) {
unsigned bitWidth = a.getBitWidth();
APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
return fullProduct.extractBits(bitWidth, bitWidth);
return llvm::APIntOps::mulhu(a, b);
});
assert(highAttr && "Unexpected constant-folding failure");

Expand Down
7 changes: 2 additions & 5 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,14 +250,11 @@ struct MulExtendedFold final : OpRewritePattern<MulOp> {

auto highBits = constFoldBinaryOp<IntegerAttr>(
{lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
unsigned bitWidth = a.getBitWidth();
APInt c;
if (IsSigned) {
c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
return llvm::APIntOps::mulhs(a, b);
} else {
c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
return llvm::APIntOps::mulhu(a, b);
}
return c.extractBits(bitWidth, bitWidth); // Extract high result
});

if (!highBits)
Expand Down