diff --git a/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h index 2abbabc5bb286..5d4774861bdfd 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h @@ -9,11 +9,25 @@ #ifndef MLIR_DIALECT_AFFINE_IR_VALUEBOUNDSOPINTERFACEIMPL_H #define MLIR_DIALECT_AFFINE_IR_VALUEBOUNDSOPINTERFACEIMPL_H +#include "mlir/Support/LogicalResult.h" + namespace mlir { class DialectRegistry; +class Value; namespace affine { void registerValueBoundsOpInterfaceExternalModels(DialectRegistry ®istry); + +/// Compute whether the given values are equal. Return "failure" if equality +/// could not be determined. `value1`/`value2` must be index-typed. +/// +/// This function is similar to `ValueBoundsConstraintSet::areEqual`. To work +/// around limitations in `FlatLinearConstraints`, this function fully composes +/// `value1` and `value2` (if they are the result of affine.apply ops) before +/// populating the constraint set. The folding/composing logic can see +/// opportunities for simplifications that the constraint set implementation +/// cannot see. +FailureOr fullyComposeAndCheckIfEqual(Value value1, Value value2); } // namespace affine } // namespace mlir diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp index 97dd70e4f1d2b..d47c8eb8ccb42 100644 --- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp @@ -27,12 +27,22 @@ struct AffineApplyOpInterface assert(applyOp.getAffineMap().getNumResults() == 1 && "expected single result"); + // Fully compose this affine.apply with other ops because the folding logic + // can see opportunities for simplifying the affine map that + // `FlatLinearConstraints` can currently not see. + AffineMap map = applyOp.getAffineMap(); + SmallVector operands = llvm::to_vector(applyOp.getOperands()); + fullyComposeAffineMapAndOperands(&map, &operands); + // Align affine map result with dims/symbols in the constraint set. - AffineExpr expr = applyOp.getAffineMap().getResult(0); - SmallVector dimReplacements = llvm::to_vector(llvm::map_range( - applyOp.getDimOperands(), [&](Value v) { return cstr.getExpr(v); })); - SmallVector symReplacements = llvm::to_vector(llvm::map_range( - applyOp.getSymbolOperands(), [&](Value v) { return cstr.getExpr(v); })); + AffineExpr expr = map.getResult(0); + SmallVector dimReplacements, symReplacements; + for (int64_t i = 0, e = map.getNumDims(); i < e; ++i) + dimReplacements.push_back(cstr.getExpr(operands[i])); + for (int64_t i = map.getNumDims(), + e = map.getNumDims() + map.getNumSymbols(); + i < e; ++i) + symReplacements.push_back(cstr.getExpr(operands[i])); AffineExpr bound = expr.replaceDimsAndSymbols(dimReplacements, symReplacements); cstr.bound(value) == bound; @@ -92,3 +102,30 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels( AffineMinOp::attachInterface(*ctx); }); } + +FailureOr mlir::affine::fullyComposeAndCheckIfEqual(Value value1, + Value value2) { + assert(value1.getType().isIndex() && "expected index type"); + assert(value2.getType().isIndex() && "expected index type"); + + // Subtract the two values/dimensions from each other. If the result is 0, + // both are equal. + Builder b(value1.getContext()); + AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, + b.getAffineDimExpr(0) - b.getAffineDimExpr(1)); + // Fully compose the affine map with other ops because the folding logic + // can see opportunities for simplifying the affine map that + // `FlatLinearConstraints` can currently not see. + SmallVector mapOperands; + mapOperands.push_back(value1); + mapOperands.push_back(value2); + affine::fullyComposeAffineMapAndOperands(&map, &mapOperands); + ValueDimList valueDims; + for (Value v : mapOperands) + valueDims.push_back({v, std::nullopt}); + FailureOr bound = ValueBoundsConstraintSet::computeConstantBound( + presburger::BoundType::EQ, map, valueDims); + if (failed(bound)) + return failure(); + return *bound == 0; +} diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir index 338c48c5b210b..8acf358c887a9 100644 --- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir @@ -58,3 +58,35 @@ func.func @affine_min_lb(%a: index) -> (index) { %2 = "test.reify_bound"(%1) {type = "LB"}: (index) -> (index) return %2 : index } + +// ----- + +// CHECK-LABEL: func @composed_affine_apply( +// CHECK: %[[cst:.*]] = arith.constant -8 : index +// CHECK: return %[[cst]] +func.func @composed_affine_apply(%i1 : index) -> (index) { + // The ValueBoundsOpInterface implementation of affine.apply fully composes + // the affine map (and its operands) with other affine.apply ops drawn from + // its operands before adding it to the constraint set. This is to work + // around a limitation in `FlatLinearConstraints`, which can currently not + // compute a constant bound for %s. (The affine map simplification logic can + // simplify %s to -8.) + %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1) + %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1) + %s = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%i2, %i3] + %reified = "test.reify_constant_bound"(%s) {type = "EQ"} : (index) -> (index) + return %reified : index +} + + +// ----- + +// Test for affine::fullyComposeAndCheckIfEqual +func.func @composed_are_equal(%i1 : index) { + %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1) + %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1) + %s = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%i2, %i3] + // expected-remark @below{{different}} + "test.are_equal"(%i2, %i3) {compose} : (index, index) -> () + return +} diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp index ad017cef1b9ba..6e3c3dff759a2 100644 --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Affine/Transforms/Transforms.h" #include "mlir/Dialect/Arith/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -186,8 +187,14 @@ static LogicalResult testEquality(func::FuncOp funcOp) { op->emitOpError("invalid op"); return WalkResult::skip(); } - FailureOr equal = ValueBoundsConstraintSet::areEqual( - op->getOperand(0), op->getOperand(1)); + FailureOr equal = failure(); + if (op->hasAttr("compose")) { + equal = affine::fullyComposeAndCheckIfEqual(op->getOperand(0), + op->getOperand(1)); + } else { + equal = ValueBoundsConstraintSet::areEqual(op->getOperand(0), + op->getOperand(1)); + } if (failed(equal)) { op->emitError("could not determine equality"); } else if (*equal) {