From e9ad79cc833cf45f8ea910dbbc4de8eba43cbfab Mon Sep 17 00:00:00 2001 From: Atousa Duprat Date: Sun, 10 Mar 2024 21:53:04 -0700 Subject: [PATCH] [ADT] Add signed and unsigned mulh to APInt This addresses issue #84207 --- llvm/include/llvm/ADT/APInt.h | 8 +++ .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 17 ++---- llvm/lib/Support/APInt.cpp | 16 ++++++ llvm/unittests/ADT/APIntTest.cpp | 52 +++++++++++++++++++ llvm/unittests/Support/KnownBitsTest.cpp | 10 +--- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 8 +-- .../SPIRV/IR/SPIRVCanonicalization.cpp | 7 +-- 7 files changed, 86 insertions(+), 32 deletions(-) diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h index 1abea9eb24a3c..b9b39f3b9dfbc 100644 --- a/llvm/include/llvm/ADT/APInt.h +++ b/llvm/include/llvm/ADT/APInt.h @@ -2216,6 +2216,14 @@ APInt avgCeilS(const APInt &C1, const APInt &C2); /// Compute the ceil of the unsigned average of C1 and C2 APInt avgCeilU(const APInt &C1, const APInt &C2); +/// Performs (2*N)-bit multiplication on sign-extended operands. +/// Returns the high N bits of the multiplication result. +APInt mulhs(const APInt &C1, const APInt &C2); + +/// Performs (2*N)-bit multiplication on zero-extended operands. +/// Returns the high N bits of the multiplication result. +APInt mulhu(const APInt &C1, const APInt &C2); + /// Compute GCD of two unsigned APInt values. /// /// This function returns the greatest common divisor of the two APInt values diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index e8d1ac1d3a916..c52864e6dcdbd 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -6073,18 +6073,6 @@ static std::optional FoldValue(unsigned Opcode, const APInt &C1, if (!C2.getBoolValue()) break; return C1.srem(C2); - case ISD::MULHS: { - unsigned FullWidth = C1.getBitWidth() * 2; - APInt C1Ext = C1.sext(FullWidth); - APInt C2Ext = C2.sext(FullWidth); - return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth()); - } - case ISD::MULHU: { - unsigned FullWidth = C1.getBitWidth() * 2; - APInt C1Ext = C1.zext(FullWidth); - APInt C2Ext = C2.zext(FullWidth); - return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth()); - } case ISD::AVGFLOORS: return APIntOps::avgFloorS(C1, C2); case ISD::AVGFLOORU: @@ -6097,10 +6085,13 @@ static std::optional FoldValue(unsigned Opcode, const APInt &C1, return APIntOps::abds(C1, C2); case ISD::ABDU: return APIntOps::abdu(C1, C2); + case ISD::MULHS: + return APIntOps::mulhs(C1, C2); + case ISD::MULHU: + return APIntOps::mulhu(C1, C2); } return std::nullopt; } - // Handle constant folding with UNDEF. // TODO: Handle more cases. static std::optional FoldValueWithUndef(unsigned Opcode, const APInt &C1, diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp index 7a383c3236d9b..c20609748dc97 100644 --- a/llvm/lib/Support/APInt.cpp +++ b/llvm/lib/Support/APInt.cpp @@ -3121,3 +3121,19 @@ APInt APIntOps::avgCeilU(const APInt &C1, const APInt &C2) { // Return ceil((C1 + C2) / 2) return (C1 | C2) - (C1 ^ C2).lshr(1); } + +APInt APIntOps::mulhs(const APInt &C1, const APInt &C2) { + assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths"); + unsigned FullWidth = C1.getBitWidth() * 2; + APInt C1Ext = C1.sext(FullWidth); + APInt C2Ext = C2.sext(FullWidth); + return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth()); +} + +APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) { + assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths"); + unsigned FullWidth = C1.getBitWidth() * 2; + APInt C1Ext = C1.zext(FullWidth); + APInt C2Ext = C2.zext(FullWidth); + return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth()); +} diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp index 83cbb02e0f58b..d5ef63e38e279 100644 --- a/llvm/unittests/ADT/APIntTest.cpp +++ b/llvm/unittests/ADT/APIntTest.cpp @@ -2841,6 +2841,58 @@ TEST(APIntTest, multiply) { EXPECT_EQ(64U, i96.countr_zero()); } +TEST(APIntOpsTest, Mulh) { + + // Unsigned + + // 32 bits + APInt i32a(32, 0x0001'E235); + APInt i32b(32, 0xF623'55AD); + EXPECT_EQ(0x0001'CFA1, APIntOps::mulhu(i32a, i32b)); + + // 64 bits + APInt i64a(64, 0x1234'5678'90AB'CDEF); + APInt i64b(64, 0xFEDC'BA09'8765'4321); + EXPECT_EQ(0x121F'A000'A372'3A57, APIntOps::mulhu(i64a, i64b)); + + // 128 bits + APInt i128a(128, "1234567890ABCDEF1234567890ABCDEF", 16); + APInt i128b(128, "FEDCBA0987654321FEDCBA0987654321", 16); + APInt i128Res = APIntOps::mulhu(i128a, i128b); + EXPECT_EQ(APInt(128, "121FA000A3723A57E68984312C3A8D7E", 16), i128Res); + + // Signed + + // 32 bits + APInt i32c(32, 0x1234'5678); // +ve + APInt i32d(32, 0x10AB'CDEF); // +ve + APInt i32e(32, 0xFEDC'BA09); // -ve + + EXPECT_EQ(0x012F'7D02, APIntOps::mulhs(i32c, i32d)); + EXPECT_EQ(0xFFEB'4988, APIntOps::mulhs(i32c, i32e)); + EXPECT_EQ(0x0001'4B68, APIntOps::mulhs(i32e, i32e)); + + // 64 bits + APInt i64c(64, 0x1234'5678'90AB'CDEF); // +ve + APInt i64d(64, 0x1234'5678'90FE'DCBA); // +ve + APInt i64e(64, 0xFEDC'BA09'8765'4321); // -ve + + EXPECT_EQ(0x014B'66DC'328E'10C1, APIntOps::mulhs(i64c, i64d)); + EXPECT_EQ(0xFFEB'4988'12C6'6C68, APIntOps::mulhs(i64c, i64e)); + EXPECT_EQ(0x0001'4B68'2174'FA18, APIntOps::mulhs(i64e, i64e)); + + // 128 bits + APInt i128c(128, "1234567890ABCDEF1234567890ABCDEF", 16); // +ve + APInt i128d(128, "1234567890FEDCBA1234567890FEDCBA", 16); // +ve + APInt i128e(128, "FEDCBA0987654321FEDCBA0987654321", 16); // -ve + + i128Res = APIntOps::mulhs(i128c, i128d); + EXPECT_EQ(APInt(128, "14B66DC328E10C1FE303DF9EA0B2529", 16), i128Res); + + i128Res = APIntOps::mulhs(i128c, i128e); + EXPECT_EQ(APInt(128, "FFEB498812C66C68D4552DB89B8EBF8F", 16), i128Res); +} + TEST(APIntTest, RoundingUDiv) { for (uint64_t Ai = 1; Ai <= 255; Ai++) { APInt A(8, Ai); diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp index 48de0889fdf33..027d6379af26b 100644 --- a/llvm/unittests/Support/KnownBitsTest.cpp +++ b/llvm/unittests/Support/KnownBitsTest.cpp @@ -553,17 +553,11 @@ TEST(KnownBitsTest, BinaryExhaustive) { checkCorrectnessOnlyBinary); testBinaryOpExhaustive( KnownBits::mulhs, - [](const APInt &N1, const APInt &N2) { - unsigned Bits = N1.getBitWidth(); - return (N1.sext(2 * Bits) * N2.sext(2 * Bits)).extractBits(Bits, Bits); - }, + [](const APInt &N1, const APInt &N2) { return APIntOps::mulhs(N1, N2); }, checkCorrectnessOnlyBinary); testBinaryOpExhaustive( KnownBits::mulhu, - [](const APInt &N1, const APInt &N2) { - unsigned Bits = N1.getBitWidth(); - return (N1.zext(2 * Bits) * N2.zext(2 * Bits)).extractBits(Bits, Bits); - }, + [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); }, checkCorrectnessOnlyBinary); } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 2f32d9a26e775..0f2811c9e30aa 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -434,9 +434,7 @@ arith::MulSIExtendedOp::fold(FoldAdaptor adaptor, // Invoke the constant fold helper again to calculate the 'high' result. Attribute highAttr = constFoldBinaryOp( adaptor.getOperands(), [](const APInt &a, const APInt &b) { - unsigned bitWidth = a.getBitWidth(); - APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2); - return fullProduct.extractBits(bitWidth, bitWidth); + return llvm::APIntOps::mulhs(a, b); }); assert(highAttr && "Unexpected constant-folding failure"); @@ -491,9 +489,7 @@ arith::MulUIExtendedOp::fold(FoldAdaptor adaptor, // Invoke the constant fold helper again to calculate the 'high' result. Attribute highAttr = constFoldBinaryOp( adaptor.getOperands(), [](const APInt &a, const APInt &b) { - unsigned bitWidth = a.getBitWidth(); - APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2); - return fullProduct.extractBits(bitWidth, bitWidth); + return llvm::APIntOps::mulhu(a, b); }); assert(highAttr && "Unexpected constant-folding failure"); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp index 5f47cff71cba0..b1acfd1a2abed 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -250,14 +250,11 @@ struct MulExtendedFold final : OpRewritePattern { auto highBits = constFoldBinaryOp( {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) { - unsigned bitWidth = a.getBitWidth(); - APInt c; if (IsSigned) { - c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2); + return llvm::APIntOps::mulhs(a, b); } else { - c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2); + return llvm::APIntOps::mulhu(a, b); } - return c.extractBits(bitWidth, bitWidth); // Extract high result }); if (!highBits)