From a00b999e3b46dbb184c82c773b7455c071c2ec38 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 6 Dec 2023 12:45:51 +0900 Subject: [PATCH] [mlir][SCF] Retire SCF-specific `to_memref`/`to_tensor` canonicalization patterns The partial bufferization framework has been replaced with One-Shot Bufferize. SCF-specific canonicalization patterns for `to_memref`/`to_tensor` are no longer needed. --- mlir/lib/Dialect/SCF/IR/CMakeLists.txt | 3 +- mlir/lib/Dialect/SCF/IR/SCF.cpp | 132 +----------------- mlir/test/Dialect/SCF/canonicalize.mlir | 50 ------- .../llvm-project-overlay/mlir/BUILD.bazel | 1 - 4 files changed, 4 insertions(+), 182 deletions(-) diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt index 9882b843c285e..423e1c3e1e042 100644 --- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt @@ -11,12 +11,13 @@ add_mlir_dialect_library(MLIRSCFDialect LINK_LIBS PUBLIC MLIRArithDialect - MLIRBufferizationDialect MLIRControlFlowDialect + MLIRDialectUtils MLIRFunctionInterfaces MLIRIR MLIRLoopLikeInterface MLIRSideEffectInterfaces + MLIRTensorDialect MLIRValueBoundsOpInterface ) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 3b55704c4ea07..cf807a2adc10e 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -9,7 +9,6 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" @@ -1082,139 +1081,12 @@ struct ForOpTensorCastFolder : public OpRewritePattern { } }; -/// Canonicalize the iter_args of an scf::ForOp that involve a -/// `bufferization.to_tensor` and for which only the last loop iteration is -/// actually visible outside of the loop. The canonicalization looks for a -/// pattern such as: -/// ``` -/// %t0 = ... : tensor_type -/// %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) { -/// ... -/// // %m is either buffer_cast(%bb00) or defined above the loop -/// %m... : memref_type -/// ... // uses of %m with potential inplace updates -/// %new_tensor = bufferization.to_tensor %m : memref_type -/// ... -/// scf.yield %new_tensor : tensor_type -/// } -/// ``` -/// -/// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a -/// `%m = buffer_cast %bb0` op that feeds into the yielded -/// `bufferization.to_tensor` op. -/// -/// If no aliasing write to the memref `%m`, from which `%new_tensor`is loaded, -/// occurs between `bufferization.to_tensor and yield then the value %0 -/// visible outside of the loop is the last `bufferization.to_tensor` -/// produced in the loop. -/// -/// For now, we approximate the absence of aliasing by only supporting the case -/// when the bufferization.to_tensor is the operation immediately preceding -/// the yield. -// -/// The canonicalization rewrites the pattern as: -/// ``` -/// // %m is either a buffer_cast or defined above -/// %m... : memref_type -/// scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) { -/// ... // uses of %m with potential inplace updates -/// scf.yield %bb0: tensor_type -/// } -/// %0 = bufferization.to_tensor %m : memref_type -/// ``` -/// -/// A later bbArg canonicalization will further rewrite as: -/// ``` -/// // %m is either a buffer_cast or defined above -/// %m... : memref_type -/// scf.for ... { // no iter_args -/// ... // uses of %m with potential inplace updates -/// } -/// %0 = bufferization.to_tensor %m : memref_type -/// ``` -struct LastTensorLoadCanonicalization : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ForOp forOp, - PatternRewriter &rewriter) const override { - assert(std::next(forOp.getRegion().begin()) == forOp.getRegion().end() && - "unexpected multiple blocks"); - - Location loc = forOp.getLoc(); - DenseMap replacements; - for (BlockArgument bbArg : forOp.getRegionIterArgs()) { - unsigned idx = bbArg.getArgNumber() - /*numIv=*/1; - auto yieldOp = - cast(forOp.getRegion().front().getTerminator()); - Value yieldVal = yieldOp->getOperand(idx); - auto tensorLoadOp = yieldVal.getDefiningOp(); - bool isTensor = llvm::isa(bbArg.getType()); - - bufferization::ToMemrefOp tensorToMemref; - // Either bbArg has no use or it has a single buffer_cast use. - if (bbArg.hasOneUse()) - tensorToMemref = - dyn_cast(*bbArg.getUsers().begin()); - if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !tensorToMemref)) - continue; - // If tensorToMemref is present, it must feed into the `ToTensorOp`. - if (tensorToMemref && tensorLoadOp.getMemref() != tensorToMemref) - continue; - // TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp` - // must be before `ToTensorOp` in the block so that the lastWrite - // property is not subject to additional side-effects. - // For now, we only support the case when ToTensorOp appears - // immediately before the terminator. - if (tensorLoadOp->getNextNode() != yieldOp) - continue; - - // Clone the optional tensorToMemref before forOp. - if (tensorToMemref) { - rewriter.setInsertionPoint(forOp); - rewriter.replaceOpWithNewOp( - tensorToMemref, tensorToMemref.getMemref().getType(), - tensorToMemref.getTensor()); - } - - // Clone the tensorLoad after forOp. - rewriter.setInsertionPointAfter(forOp); - Value newTensorLoad = rewriter.create( - loc, tensorLoadOp.getMemref()); - Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1); - replacements.insert(std::make_pair(forOpResult, newTensorLoad)); - - // Make the terminator just yield the bbArg, the old tensorLoadOp + the - // old bbArg (that is now directly yielded) will canonicalize away. - rewriter.startRootUpdate(yieldOp); - yieldOp.setOperand(idx, bbArg); - rewriter.finalizeRootUpdate(yieldOp); - } - if (replacements.empty()) - return failure(); - - // We want to replace a subset of the results of `forOp`. rewriter.replaceOp - // replaces the whole op and erase it unconditionally. This is wrong for - // `forOp` as it generally contains ops with side effects. - // Instead, use `rewriter.replaceOpWithIf`. - SmallVector newResults; - newResults.reserve(forOp.getNumResults()); - for (Value v : forOp.getResults()) { - auto it = replacements.find(v); - newResults.push_back((it != replacements.end()) ? it->second : v); - } - unsigned idx = 0; - rewriter.replaceOpWithIf(forOp, newResults, [&](OpOperand &op) { - return op.get() != newResults[idx++]; - }); - return success(); - } -}; } // namespace void ForOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add( + context); } std::optional ForOp::getConstantStep() { diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 9dbf8d5dab11a..41e028028616a 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -773,56 +773,6 @@ func.func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) { // ----- -func.func private @process(%0 : memref<128x128xf32>) -func.func private @process_tensor(%0 : tensor<128x128xf32>) -> memref<128x128xf32> - -// CHECK-LABEL: last_value -// CHECK-SAME: %[[T0:[0-9a-z]*]]: tensor<128x128xf32> -// CHECK-SAME: %[[T1:[0-9a-z]*]]: tensor<128x128xf32> -// CHECK-SAME: %[[T2:[0-9a-z]*]]: tensor<128x128xf32> -// CHECK-SAME: %[[M0:[0-9a-z]*]]: memref<128x128xf32> -func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>, - %t2: tensor<128x128xf32>, %m0: memref<128x128xf32>, - %lb : index, %ub : index, %step : index) - -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>) -{ - // CHECK-NEXT: %[[M1:.*]] = bufferization.to_memref %[[T1]] : memref<128x128xf32> - // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args(%[[BBARG_T2:.*]] = %[[T2]]) -> (tensor<128x128xf32>) { - %0:3 = scf.for %arg0 = %lb to %ub step %step iter_args(%arg1 = %t0, %arg2 = %t1, %arg3 = %t2) - -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>) - { - %m1 = bufferization.to_memref %arg2 : memref<128x128xf32> - - // CHECK-NEXT: call @process(%[[M0]]) : (memref<128x128xf32>) -> () - func.call @process(%m0) : (memref<128x128xf32>) -> () - - // CHECK-NEXT: call @process(%[[M1]]) : (memref<128x128xf32>) -> () - func.call @process(%m1) : (memref<128x128xf32>) -> () - - // This does not hoist (fails the bbArg has at most a single check). - // CHECK-NEXT: %[[T:.*]] = func.call @process_tensor(%[[BBARG_T2]]) : (tensor<128x128xf32>) -> memref<128x128xf32> - // CHECK-NEXT: %[[YIELD_T:.*]] = bufferization.to_tensor %[[T:.*]] - %m2 = func.call @process_tensor(%arg3): (tensor<128x128xf32>) -> memref<128x128xf32> - %3 = bufferization.to_tensor %m2 : memref<128x128xf32> - - // All this stuff goes away, incrementally - %1 = bufferization.to_tensor %m0 : memref<128x128xf32> - %2 = bufferization.to_tensor %m1 : memref<128x128xf32> - - // CHECK-NEXT: scf.yield %[[YIELD_T]] : tensor<128x128xf32> - scf.yield %1, %2, %3 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32> - - // CHECK-NEXT: } - } - - // CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32> - // CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32> - // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32> - return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32> -} - -// ----- - // CHECK-LABEL: fold_away_iter_with_no_use_and_yielded_input // CHECK-SAME: %[[A0:[0-9a-z]*]]: i32 func.func @fold_away_iter_with_no_use_and_yielded_input(%arg0 : i32, diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 4fb6a50a174c2..2a3ebbba02384 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -3994,7 +3994,6 @@ cc_library( deps = [ ":ArithDialect", ":ArithUtils", - ":BufferizationDialect", ":ControlFlowDialect", ":ControlFlowInterfaces", ":DestinationStyleOpInterface",