[mlir] Add arith-int-range-narrowing pass #112404


Merged: 20 commits merged into llvm:main from truncate-equiv on Nov 5, 2024

Conversation

@Hardcode84 (Contributor)

This pass is intended to narrow integer calculations to a specific bitwidth, using `IntegerRangeAnalysis`.
We already have the `arith-int-narrowing` pass, but it mostly does local analysis, while `IntegerRangeAnalysis` analyzes the entire program. They should ideally be unified, but that is a task for the future.
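
For example (a minimal sketch mirroring the `i64` test added in this PR): when the analysis proves that the operands and result of an `arith.addi` fit into 32 bits, the op is rewritten to truncate its operands, add at `i32`, and zero-extend the result back to the original width:

```mlir
// Before: range analysis has proven %0 and %1 (and their sum) fit in 32 bits.
%2 = arith.addi %0, %1 : i64

// After: the addition happens at the target bitwidth.
%lhs = arith.trunci %0 : i64 to i32
%rhs = arith.trunci %1 : i64 to i32
%add = arith.addi %lhs, %rhs : i32
%res = arith.extui %add : i32 to i64
```

The pass is invoked as in the new test's RUN line: `mlir-opt --arith-int-range-narrowing="target-bitwidth=32"`.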

@llvmbot (Member) commented Oct 15, 2024

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

This pass is intended to narrow integer calculations to a specific bitwidth, using `IntegerRangeAnalysis`.
We already have the `arith-int-narrowing` pass, but it mostly does local analysis, while `IntegerRangeAnalysis` analyzes the entire program. They should ideally be unified, but that is a task for the future.


Full diff: https://github.com/llvm/llvm-project/pull/112404.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.h (+6)
  • (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.td (+19)
  • (modified) mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (+260)
  • (added) mlir/test/Dialect/Arith/int-range-narrowing.mlir (+57)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index aee64475171a43..ce8e819708b190 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -73,6 +73,12 @@ void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
 /// Create a pass which does optimizations based on integer range analysis.
 std::unique_ptr<Pass> createIntRangeOptimizationsPass();
 
+/// Add patterns for int range based narrowing.
+void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
+                                       DataFlowSolver &solver,
+                                       unsigned targetBitwidth);
+
+// TODO: merge these two narrowing passes.
 /// Add patterns for integer bitwidth narrowing.
 void populateArithIntNarrowingPatterns(RewritePatternSet &patterns,
                                        const ArithIntNarrowingOptions &options);
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 1517f71f1a7c90..8c565f6489638e 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -50,6 +50,25 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
   ];
 }
 
+def ArithIntRangeNarrowing : Pass<"arith-int-range-narrowing"> {
+  let summary = "Reduce integer operations bitwidth based on integer range analysis";
+  let description = [{
+    This pass runs integer range analysis and tries to narrow arith ops to the
+    specified bitwidth based on the analysis results.
+  }];
+
+  let options = [
+    Option<"targetBitwidth", "target-bitwidth", "unsigned",
+           /*default=*/"32", "Target bitwidth this pass will try to narrow to">,
+  ];
+
+  // Explicitly depend on "arith" because this pass could create operations in
+  // `arith` out of thin air in some cases.
+  let dependentDialects = [
+    "::mlir::arith::ArithDialect"
+  ];
+}
+
 def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
   let summary = "Emulate operations on unsupported floats with extf/truncf";
   let description = [{
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 521138c1f6f4cd..4e7b44c939e796 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -15,8 +15,10 @@
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -24,6 +26,9 @@
 namespace mlir::arith {
 #define GEN_PASS_DEF_ARITHINTRANGEOPTS
 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+
+#define GEN_PASS_DEF_ARITHINTRANGENARROWING
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
 } // namespace mlir::arith
 
 using namespace mlir;
@@ -184,6 +189,223 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
   DataFlowSolver &solver;
 };
 
+static Type checkArithType(Type type, unsigned targetBitwidth) {
+  type = getElementTypeOrSelf(type);
+  if (isa<IndexType>(type))
+    return type;
+
+  if (auto intType = dyn_cast<IntegerType>(type))
+    if (intType.getWidth() > targetBitwidth)
+      return type;
+
+  return nullptr;
+}
+
+static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
+  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
+    return nullptr;
+
+  Type type;
+  for (auto range :
+       {ValueRange(op->getOperands()), ValueRange(op->getResults())}) {
+    for (Value val : range) {
+      if (!type) {
+        type = val.getType();
+        continue;
+      } else if (type != val.getType()) {
+        return nullptr;
+      }
+    }
+  }
+
+  return checkArithType(type, targetBitwidth);
+}
+
+static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
+                                                         ValueRange results) {
+  std::optional<ConstantIntRanges> ret;
+  for (Value value : results) {
+    auto *maybeInferredRange =
+        solver.lookupState<IntegerValueRangeLattice>(value);
+    if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+      return std::nullopt;
+
+    const ConstantIntRanges &inferredRange =
+        maybeInferredRange->getValue().getValue();
+
+    if (!ret) {
+      ret = inferredRange;
+    } else {
+      ret = ret->rangeUnion(inferredRange);
+    }
+  }
+  return ret;
+}
+
+static Type getTargetType(Type srcType, unsigned targetBitwidth) {
+  auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
+  if (auto shaped = dyn_cast<ShapedType>(srcType))
+    return shaped.clone(dstType);
+
+  assert(srcType.isIntOrIndex() && "Invalid src type");
+  return dstType;
+}
+
+static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
+                       APInt umin, APInt umax) {
+  auto sge = [](APInt val1, APInt val2) -> bool {
+    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+    val1 = val1.sext(width);
+    val2 = val2.sext(width);
+    return val1.sge(val2);
+  };
+  auto sle = [](APInt val1, APInt val2) -> bool {
+    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+    val1 = val1.sext(width);
+    val2 = val2.sext(width);
+    return val1.sle(val2);
+  };
+  auto uge = [](APInt val1, APInt val2) -> bool {
+    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+    val1 = val1.zext(width);
+    val2 = val2.zext(width);
+    return val1.uge(val2);
+  };
+  auto ule = [](APInt val1, APInt val2) -> bool {
+    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+    val1 = val1.zext(width);
+    val2 = val2.zext(width);
+    return val1.ule(val2);
+  };
+  return sge(range.smin(), smin) && sle(range.smax(), smax) &&
+         uge(range.umin(), umin) && ule(range.umax(), umax);
+}
+
+static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
+  Type srcType = src.getType();
+  assert(srcType.isIntOrIndex() && "Invalid src type");
+  assert(dstType.isIntOrIndex() && "Invalid dst type");
+  if (srcType == dstType)
+    return src;
+
+  if (isa<IndexType>(srcType) || isa<IndexType>(dstType))
+    return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
+
+  auto srcInt = cast<IntegerType>(srcType);
+  auto dstInt = cast<IntegerType>(dstType);
+  if (dstInt.getWidth() < srcInt.getWidth()) {
+    return builder.create<arith::TruncIOp>(loc, dstType, src);
+  } else {
+    return builder.create<arith::ExtUIOp>(loc, dstType, src);
+  }
+}
+
+struct NarrowElementwise final
+    : public OpTraitRewritePattern<OpTrait::Elementwise> {
+  NarrowElementwise(MLIRContext *context, DataFlowSolver &s, unsigned target)
+      : OpTraitRewritePattern<OpTrait::Elementwise>(context), solver(s),
+        targetBitwidth(target) {}
+
+  using OpTraitRewritePattern::OpTraitRewritePattern;
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    Type srcType = checkElementwiseOpType(op, targetBitwidth);
+    if (!srcType)
+      return failure();
+
+    std::optional<ConstantIntRanges> range =
+        getOperandsRange(solver, op->getResults());
+    if (!range)
+      return failure();
+
+    // We are truncating op args to the desired bitwidth before the op and then
+    // extending op results back to the original width after.
+    // extui and extsi will produce different results for negative values, so
+    // limit signed range to non-negative values.
+    auto smin = APInt::getZero(targetBitwidth);
+    auto smax = APInt::getSignedMaxValue(targetBitwidth);
+    auto umin = APInt::getMinValue(targetBitwidth);
+    auto umax = APInt::getMaxValue(targetBitwidth);
+    if (!checkRange(*range, smin, smax, umin, umax))
+      return failure();
+
+    Type targetType = getTargetType(srcType, targetBitwidth);
+    if (targetType == srcType)
+      return failure();
+
+    Location loc = op->getLoc();
+    IRMapping mapping;
+    for (Value arg : op->getOperands()) {
+      Value newArg = doCast(rewriter, loc, arg, targetType);
+      mapping.map(arg, newArg);
+    }
+
+    Operation *newOp = rewriter.clone(*op, mapping);
+    rewriter.modifyOpInPlace(newOp, [&]() {
+      for (OpResult res : newOp->getResults()) {
+        res.setType(targetType);
+      }
+    });
+    SmallVector<Value> newResults;
+    for (Value res : newOp->getResults())
+      newResults.emplace_back(doCast(rewriter, loc, res, srcType));
+
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
+
+private:
+  DataFlowSolver &solver;
+  unsigned targetBitwidth;
+};
+
+struct NarrowCmpi final : public OpRewritePattern<arith::CmpIOp> {
+  NarrowCmpi(MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s,
+             unsigned target)
+      : OpRewritePattern(context, benefit), solver(s), targetBitwidth(target) {}
+
+  LogicalResult matchAndRewrite(arith::CmpIOp op,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+
+    Type srcType = checkArithType(lhs.getType(), targetBitwidth);
+    if (!srcType)
+      return failure();
+
+    std::optional<ConstantIntRanges> range =
+        getOperandsRange(solver, {lhs, rhs});
+    if (!range)
+      return failure();
+
+    auto smin = APInt::getSignedMinValue(targetBitwidth);
+    auto smax = APInt::getSignedMaxValue(targetBitwidth);
+    auto umin = APInt::getMinValue(targetBitwidth);
+    auto umax = APInt::getMaxValue(targetBitwidth);
+    if (!checkRange(*range, smin, smax, umin, umax))
+      return failure();
+
+    Type targetType = getTargetType(srcType, targetBitwidth);
+    if (targetType == srcType)
+      return failure();
+
+    Location loc = op->getLoc();
+    IRMapping mapping;
+    for (Value arg : op->getOperands()) {
+      Value newArg = doCast(rewriter, loc, arg, targetType);
+      mapping.map(arg, newArg);
+    }
+
+    Operation *newOp = rewriter.clone(*op, mapping);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+
+private:
+  DataFlowSolver &solver;
+  unsigned targetBitwidth;
+};
+
 struct IntRangeOptimizationsPass
     : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
 
@@ -208,6 +430,32 @@ struct IntRangeOptimizationsPass
       signalPassFailure();
   }
 };
+
+struct IntRangeNarrowingPass
+    : public arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
+  using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+    DataFlowSolver solver;
+    solver.load<DeadCodeAnalysis>();
+    solver.load<IntegerRangeAnalysis>();
+    if (failed(solver.initializeAndRun(op)))
+      return signalPassFailure();
+
+    DataFlowListener listener(solver);
+
+    RewritePatternSet patterns(ctx);
+    populateIntRangeNarrowingPatterns(patterns, solver, this->targetBitwidth);
+
+    GreedyRewriteConfig config;
+    config.listener = &listener;
+
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+      signalPassFailure();
+  }
+};
 } // namespace
 
 void mlir::arith::populateIntRangeOptimizationsPatterns(
@@ -216,6 +464,18 @@ void mlir::arith::populateIntRangeOptimizationsPatterns(
                DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
 }
 
+void mlir::arith::populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
+                                                    DataFlowSolver &solver,
+                                                    unsigned targetBitwidth) {
+  // Cmpi uses the ranges of its args instead of its results, so run it with a
+  // higher benefit, as its arguments can potentially be replaced.
+  patterns.add<NarrowCmpi>(patterns.getContext(), /*benefit*/ 10, solver,
+                           targetBitwidth);
+
+  patterns.add<NarrowElementwise>(patterns.getContext(), solver,
+                                  targetBitwidth);
+}
+
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
   return std::make_unique<IntRangeOptimizationsPass>();
 }
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
new file mode 100644
index 00000000000000..d85cb3384061be
--- /dev/null
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt --arith-int-range-narrowing="target-bitwidth=32"  %s | FileCheck %s
+
+// Do not truncate negative values
+// CHECK-LABEL: func @test_addi_neg
+//       CHECK:  %[[RES:.*]] = arith.addi %{{.*}}, %{{.*}} : index
+//       CHECK:  return %[[RES]] : index
+func.func @test_addi_neg() -> index {
+  %0 = test.with_bounds { umin = 0 : index, umax = 1 : index, smin = 0 : index, smax = 1 : index } : index
+  %1 = test.with_bounds { umin = 0 : index, umax = -1 : index, smin = -1 : index, smax = 0 : index } : index
+  %2 = arith.addi %0, %1 : index
+  return %2 : index
+}
+
+// CHECK-LABEL: func @test_addi
+//       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : index
+//       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : index
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
+//       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
+//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : i32 to index
+//       CHECK:  return %[[RES_CASTED]] : index
+func.func @test_addi() -> index {
+  %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
+  %1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : index
+  %2 = arith.addi %0, %1 : index
+  return %2 : index
+}
+
+
+// CHECK-LABEL: func @test_addi_i64
+//       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : i64, smin = 4 : i64, umax = 5 : i64, umin = 4 : i64} : i64
+//       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : i64, smin = 6 : i64, umax = 7 : i64, umin = 6 : i64} : i64
+//       CHECK:  %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i32
+//       CHECK:  %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i32
+//       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
+//       CHECK:  %[[RES_CASTED:.*]] = arith.extui %[[RES]] : i32 to i64
+//       CHECK:  return %[[RES_CASTED]] : i64
+func.func @test_addi_i64() -> i64 {
+  %0 = test.with_bounds { umin = 4 : i64, umax = 5 : i64, smin = 4 : i64, smax = 5 : i64 } : i64
+  %1 = test.with_bounds { umin = 6 : i64, umax = 7 : i64, smin = 6 : i64, smax = 7 : i64 } : i64
+  %2 = arith.addi %0, %1 : i64
+  return %2 : i64
+}
+
+// CHECK-LABEL: func @test_cmpi
+//       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
+//       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
+//       CHECK:  %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i32
+//       CHECK:  return %[[RES]] : i1
+func.func @test_cmpi() -> i1 {
+  %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+  %1 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+  %2 = arith.cmpi slt, %0, %1 : index
+  return %2 : i1
+}

@kuhar (Member) left a comment

> We already have the `arith-int-narrowing` pass, but it mostly does local analysis, while `IntegerRangeAnalysis` analyzes the entire program. They should ideally be unified, but that is a task for the future.

AFAIK, arith-int-narrowing doesn't have any use(r)s and could be dropped. I think it would be better to replace it with the new implementation from this PR.

@Hardcode84 (Contributor, Author)

> We already have the `arith-int-narrowing` pass, but it mostly does local analysis, while `IntegerRangeAnalysis` analyzes the entire program. They should ideally be unified, but that is a task for the future.
>
> AFAIK, arith-int-narrowing doesn't have any use(r)s and could be dropped. I think it would be better to replace it with the new implementation from this PR.

I've added some tests from the old `arith-int-narrowing` pass, but it cannot be replaced 1-to-1 yet: vector support needs more work (#112292), and in some cases the new pass chooses a more conservative truncation.

I would prefer to merge this one in its current state and improve incrementally.
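
For instance (mirroring the `test_addi_neg` case in the new test file): the elementwise pattern re-extends narrowed results with `extui`, so any value whose range includes negatives is deliberately left untouched, even where a sign-extending rewrite could have narrowed it:

```mlir
// Not narrowed: %1 may be -1, and zero-extending a negative i32 result
// back to index would change its value.
%0 = test.with_bounds { umin = 0 : index, umax = 1 : index,
                        smin = 0 : index, smax = 1 : index } : index
%1 = test.with_bounds { umin = 0 : index, umax = -1 : index,
                        smin = -1 : index, smax = 0 : index } : index
%2 = arith.addi %0, %1 : index
```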

@kuhar (Member) commented Oct 16, 2024

My suggestion is not to have two passes that effectively do the same thing, but one unused/unmaintained and one incomplete. I'd rather delete the old one, even if we don't have a full replacement from day 0.

@kuhar (Member) commented Oct 16, 2024

If you would like to eventually support some rewrites from the original pass, we can open a tracking issue and capture the relevant test cases there (permalinks to the current code).

@Hardcode84 (Contributor, Author)

> My suggestion is not to have two passes that effectively do the same thing, but one unused/unmaintained and one incomplete. I'd rather delete the old one, even if we don't have a full replacement from day 0.

Removing the pass is easy; my only concern is that it will break someone downstream that we don't know about.

@kuhar (Member) commented Oct 16, 2024

We can start a discourse thread with a deprecation notice. But this should be easy for downstream users to copy-paste if they need it -- the code is self-contained.

@joker-eph (Collaborator)

> Removing the pass is easy; my only concern is that it will break someone downstream that we don't know about.

They can complain then and reinstate the pass, or fork it downstream, or explain why they can't migrate to the new one?

@Hardcode84 (Contributor, Author)

OK, removed the old pass. I still want to use the new pass name, as any potential users will at least know that something has changed.

This pass is intended to narrow integer calculations to a specific bitwidth, using `IntegerRangeAnalysis`.
We already have the `arith-int-narrowing` pass, but it mostly does local analysis, while `IntegerRangeAnalysis` analyzes the entire program.
They should ideally be unified, but that is a task for the future.
@Hardcode84 (Contributor, Author)

Updated; added vector type support.
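
A hedged sketch of the vector case (the types and bounds here are illustrative, not taken from the test file): since `getTargetType` clones the narrower element type onto shaped types, vectors narrow the same way scalars do when every element's range fits:

```mlir
// Illustrative only: assumes range analysis proves all lanes fit in 32 bits.
%lhs = arith.trunci %a : vector<4xi64> to vector<4xi32>
%rhs = arith.trunci %b : vector<4xi64> to vector<4xi32>
%add = arith.addi %lhs, %rhs : vector<4xi32>
%res = arith.extui %add : vector<4xi32> to vector<4xi64>
```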

@kuhar (Member) left a comment

Looks good overall, just circling back to a couple of old comments.

@kuhar (Member) left a comment

LGTM

@Hardcode84 merged commit 9f0f6df into llvm:main on Nov 5, 2024
8 checks passed
@Hardcode84 deleted the truncate-equiv branch on November 5, 2024, 11:30
PhilippRados pushed a commit to PhilippRados/llvm-project that referenced this pull request Nov 6, 2024
This pass is intended to narrow integer calculations to a specific
bitwidth, using `IntegerRangeAnalysis`.
We already have the `arith-int-narrowing` pass, but it mostly does
local analysis, while `IntegerRangeAnalysis` analyzes the entire program.
They should ideally be unified, but that is a task for the future.