diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index 8286d8f315bd6..25c27620bbba7 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -587,6 +587,8 @@ def AffineParallelOp : Affine_Op<"parallel", [ImplicitAffineTerminator]> { static StringRef getUpperBoundsMapAttrName() { return "upperBoundsMap"; } static StringRef getStepsAttrName() { return "steps"; } }]; + + let hasCanonicalizer = 1; } def AffinePrefetchOp : Affine_Op<"prefetch"> { diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 27f4450924b66..70ee22484cf61 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2614,6 +2614,33 @@ static LogicalResult verify(AffineVectorStoreOp op) { return success(); } +namespace { +/// This pattern removes affine.parallel ops with no induction variables +struct AffineParallelRank0LoopRemover + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AffineParallelOp op, + PatternRewriter &rewriter) const override { + // Check that there are no induction variables + if (op.lowerBoundsMap().getNumResults() != 0) + return failure(); + // Remove the affine.parallel wrapper, retain the body in the same location + auto &parentOps = rewriter.getInsertionBlock()->getOperations(); + auto ¶llelBodyOps = op.region().front().getOperations(); + parentOps.splice(mlir::Block::iterator(op), parallelBodyOps, + parallelBodyOps.begin(), std::prev(parallelBodyOps.end())); + rewriter.eraseOp(op); + return success(); + } +}; +} // end anonymous namespace + +void AffineParallelOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir index 5c7fba52976a6..bc7bb6c9d2375 100644 --- a/mlir/test/Dialect/Affine/canonicalize.mlir +++ b/mlir/test/Dialect/Affine/canonicalize.mlir @@ -604,3 +604,18 @@ func @drop_duplicate_bounds(%N : index) { } return } + +// ----- + +// CHECK: func @remove_rank0_affine_parallel(%[[OUT:.*]]: memref) +func @remove_rank0_affine_parallel(%out: memref) { + // CHECK-NEXT: %[[CST:.*]] = constant + %cst = constant 0.0 : f32 + // CHECK-NEXT: affine.store %[[CST]], %[[OUT]][] : memref + affine.parallel () = () to () { + affine.parallel () = () to () { + affine.store %cst, %out[] : memref + } + } + return +}