Skip to content

[mlir][Arith] ValueBoundsOpInterface: Support arith.select #87870

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

This commit adds a ValueBoundsOpInterface implementation for arith.select. The implementation is almost identical to scf.if (#85895), but there is one special case: if the condition is a shaped value, the selection is applied element-wise and the result shape can be inferred from either operand.

Note: This is a re-upload of #86383.

This commit adds a `ValueBoundsOpInterface` implementation for
`arith.select`. The implementation is almost identical to `scf.if`
(#85895), but there is one special case: if the condition is a shaped
value, the selection is applied element-wise and the result shape can be
inferred from either operand.
@llvmbot
Copy link
Member

llvmbot commented Apr 6, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a ValueBoundsOpInterface implementation for arith.select. The implementation is almost identical to scf.if (#85895), but there is one special case: if the condition is a shaped value, the selection is applied element-wise and the result shape can be inferred from either operand.

Note: This is a re-upload of #86383.


Full diff: https://github.com/llvm/llvm-project/pull/87870.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp (+70)
  • (modified) mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir (+31)
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
index 90895e381c74b5..f0d43808bc45df 100644
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -75,6 +75,75 @@ struct MulIOpInterface
   }
 };
 
+struct SelectOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
+                                                   SelectOp> {
+
+  static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
+                             ValueBoundsConstraintSet &cstr) {
+    Value value = selectOp.getResult();
+    Value condition = selectOp.getCondition();
+    Value trueValue = selectOp.getTrueValue();
+    Value falseValue = selectOp.getFalseValue();
+
+    if (isa<ShapedType>(condition.getType())) {
+      // If the condition is a shaped type, the condition is applied
+      // element-wise. All three operands must have the same shape.
+      cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim);
+      cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim);
+      cstr.bound(value)[*dim] == cstr.getExpr(condition, dim);
+      return;
+    }
+
+    // Populate constraints for the true/false values (and all values on the
+    // backward slice, as long as the current stop condition is not satisfied).
+    cstr.populateConstraints(trueValue, dim);
+    cstr.populateConstraints(falseValue, dim);
+    auto boundsBuilder = cstr.bound(value);
+    if (dim)
+      boundsBuilder[*dim];
+
+    // Compare yielded values.
+    // If trueValue <= falseValue:
+    // * result <= falseValue
+    // * result >= trueValue
+    if (cstr.compare(trueValue, dim,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     falseValue, dim)) {
+      if (dim) {
+        cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
+        cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
+      } else {
+        cstr.bound(value) >= trueValue;
+        cstr.bound(value) <= falseValue;
+      }
+    }
+    // If falseValue <= trueValue:
+    // * result <= trueValue
+    // * result >= falseValue
+    if (cstr.compare(falseValue, dim,
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     trueValue, dim)) {
+      if (dim) {
+        cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
+        cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
+      } else {
+        cstr.bound(value) >= falseValue;
+        cstr.bound(value) <= trueValue;
+      }
+    }
+  }
+
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
+  }
+
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<SelectOp>(op), dim, cstr);
+  }
+};
 } // namespace
 } // namespace arith
 } // namespace mlir
@@ -86,5 +155,6 @@ void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
     arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
     arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
     arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
+    arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
index 83d5f1c9c9e86c..8fb3ba1a1eccef 100644
--- a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
@@ -74,3 +74,34 @@ func.func @arith_const() -> index {
   %0 = "test.reify_bound"(%c5) : (index) -> (index)
   return %0 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @arith_select(
+func.func @arith_select(%c: i1) -> (index, index) {
+  // CHECK: arith.constant 5 : index
+  %c5 = arith.constant 5 : index
+  // CHECK: arith.constant 9 : index
+  %c9 = arith.constant 9 : index
+  %r = arith.select %c, %c5, %c9 : index
+  // CHECK: %[[c5:.*]] = arith.constant 5 : index
+  // CHECK: %[[c10:.*]] = arith.constant 10 : index
+  %0 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
+  %1 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
+  // CHECK: return %[[c5]], %[[c10]]
+  return %0, %1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @arith_select_elementwise(
+//  CHECK-SAME:     %[[a:.*]]: tensor<?xf32>, %[[b:.*]]: tensor<?xf32>, %[[c:.*]]: tensor<?xi1>)
+func.func @arith_select_elementwise(%a: tensor<?xf32>, %b: tensor<?xf32>, %c: tensor<?xi1>) -> index {
+  %r = arith.select %c, %a, %b : tensor<?xi1>, tensor<?xf32>
+  // CHECK: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK: %[[dim:.*]] = tensor.dim %[[a]], %[[c0]]
+  %0 = "test.reify_bound"(%r) {type = "EQ", dim = 0}
+      : (tensor<?xf32>) -> (index)
+  // CHECK: return %[[dim]]
+  return %0 : index
+}

@matthias-springer matthias-springer merged commit c459a36 into main Apr 7, 2024
@matthias-springer matthias-springer deleted the users/matthias-springer/value_bounds_arith_select2 branch April 7, 2024 00:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants