diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h index 3d3316db6b093..cab997e1aff29 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -63,19 +63,12 @@ void populateEliminateBufferizeMaterializationsPatterns( BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns); /// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. -/// If `copyBeforeWrite`, buffers are duplicated and copied before any tensor -/// use that bufferizes to a memory write. /// -/// Note: In the general case, it unsafe to run with `copyBeforeWrite = false` -/// because read-after-write conflicts may materialize during bufferization. -/// `copyBeforeWrite = false` is safe only if the input IR is guaranteed to -/// *not* require any out-of-place bufferization. -/// -/// Note: This function bufferizes ops without utilizing analysis results. It -/// can be used to implement partial bufferization passes. +/// Note: This function does not resolve read-after-write conflicts. Use this +/// function only if it is guaranteed that the input IR can bufferize without +/// additional buffer copies or set "options.copyBeforeWrite = true". The +/// general bufferization entry point is `runOneShotBufferize`. LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, - bool copyBeforeWrite = true, - const OpFilter *opFilter = nullptr, BufferizationStatistics *statistics = nullptr); /// Bufferize the signature of `block` and its callers (i.e., ops that have the @@ -94,6 +87,9 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options); +/// Return `BufferizationOptions` such that the `bufferizeOp` behaves like the +/// old (deprecated) partial, dialect conversion-based bufferization passes. A +/// copy will be inserted before every buffer write. BufferizationOptions getPartialBufferizationOptions(); } // namespace bufferization diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 045dea5d2b85f..f2125feeda541 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -383,11 +383,9 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener { DenseSet &toMemrefOps, SmallVector &worklist, const BufferizationOptions &options, - const OpFilter *opFilter, BufferizationStatistics *statistics) : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps), - worklist(worklist), analysisState(options), opFilter(opFilter), - statistics(statistics) { + worklist(worklist), analysisState(options), statistics(statistics) { setListener(this); } @@ -424,7 +422,7 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener { // Skip ops that are not allowed to be bufferized. auto const &options = analysisState.getOptions(); - if (!options.isOpAllowed(op) || (opFilter && !opFilter->isOpAllowed(op))) + if (!options.isOpAllowed(op)) return; // Add op to worklist. @@ -445,9 +443,6 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener { /// bufferization options. const AnalysisState analysisState; - /// An extra op filter for bufferization. - const OpFilter *opFilter; - /// Bufferization statistics for debugging. BufferizationStatistics *statistics; }; @@ -455,10 +450,8 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener { LogicalResult bufferization::bufferizeOp(Operation *op, const BufferizationOptions &options, - bool copyBeforeWrite, - const OpFilter *opFilter, BufferizationStatistics *statistics) { - if (copyBeforeWrite) { + if (options.copyBeforeWrite) { AnalysisState state(options); if (failed(insertTensorCopies(op, state))) return failure(); @@ -486,7 +479,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op, // Bufferize all ops. BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps, - worklist, options, opFilter, statistics); + worklist, options, statistics); for (unsigned i = 0; i < worklist.size(); ++i) { Operation *nextOp = worklist[i]; // Skip ops that were erased. @@ -496,7 +489,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op, auto bufferizableOp = options.dynCastBufferizableOp(nextOp); if (!bufferizableOp) continue; - if (opFilter && !opFilter->isOpAllowed(nextOp)) + if (!options.isOpAllowed(nextOp)) continue; // Skip ops that no longer have tensor semantics. if (!hasTensorSemantics(nextOp)) @@ -558,8 +551,6 @@ LogicalResult bufferization::bufferizeOp(Operation *op, // Continue ops that are not allowed. if (!options.isOpAllowed(op)) continue; - if (opFilter && !opFilter->isOpAllowed(op)) - continue; // Ops without any uses and no side effects will fold away. if (op->getUses().empty() && isMemoryEffectFree(op)) continue; @@ -662,6 +653,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, BufferizationOptions bufferization::getPartialBufferizationOptions() { BufferizationOptions options; options.allowUnknownOps = true; + options.copyBeforeWrite = true; options.enforceAliasingInvariants = false; options.unknownTypeConverterFn = [](Value value, Attribute memorySpace, const BufferizationOptions &options) { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index 1c85dbb5688be..8887b7c57933e 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -40,8 +40,8 @@ #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" -#include #include +#include #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -1323,6 +1323,5 @@ bufferization::runOneShotBufferize(Operation *op, } if (options.testAnalysisOnly) return success(); - return bufferizeOp(op, options, /*copyBeforeWrite=*/options.copyBeforeWrite, - /*opFilter=*/nullptr, statistics); + return bufferizeOp(op, options, statistics); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index 417f457c8910c..66c123ad2cefc 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -238,7 +238,8 @@ static void removeBufferizationAttributes(BlockArgument bbArg) { /// Return the func::FuncOp called by `callOp`. static func::FuncOp getCalledFunction(func::CallOp callOp) { - SymbolRefAttr sym = llvm::dyn_cast_if_present(callOp.getCallableForCallee()); + SymbolRefAttr sym = + llvm::dyn_cast_if_present(callOp.getCallableForCallee()); if (!sym) return nullptr; return dyn_cast_or_null( @@ -426,12 +427,19 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( for (func::FuncOp funcOp : orderedFuncOps) { // Note: It would be good to apply cleanups here but we cannot as aliasInfo // would be invalidated. - bool copyBeforeWrite = - options.copyBeforeWrite || - llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName()); - if (failed(bufferizeOp(funcOp, options, copyBeforeWrite, - /*opFilter=*/nullptr, statistics))) - return failure(); + + if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) { + // This function was not analyzed and RaW conflicts were not resolved. + // Buffer copies must be inserted before every write. + OneShotBufferizationOptions updatedOptions = options; + updatedOptions.copyBeforeWrite = true; + if (failed(bufferizeOp(funcOp, updatedOptions, statistics))) + return failure(); + } else { + if (failed(bufferizeOp(funcOp, options, statistics))) + return failure(); + } + // Change buffer return types to more precise layout maps. if (options.inferFunctionResultLayout) foldMemRefCasts(funcOp); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp index 9b5567814a75f..6fca8f82e3566 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -81,23 +81,23 @@ class SparsificationAndBufferizationPass /// and that all required buffer copies were already inserted by /// `insertTensorCopies` in the form of `bufferization.alloc_tensor` ops. LogicalResult runDenseBufferization() { - bufferization::OpFilter denseOpFilter; - denseOpFilter.allowOperation([&](Operation *op) { + bufferization::OneShotBufferizationOptions updatedOptions = + bufferizationOptions; + // Skip all sparse ops. + updatedOptions.opFilter.denyOperation([&](Operation *op) { if (containsSparseTensor(TypeRange(op->getResults())) || containsSparseTensor(TypeRange(op->getOperands()))) - return false; + return true; if (auto funcOp = dyn_cast(op)) { FunctionType funcType = funcOp.getFunctionType(); if (containsSparseTensor(funcType.getInputs()) || containsSparseTensor(funcType.getResults())) - return false; + return true; } - return true; + return false; }); - if (failed(bufferization::bufferizeOp(getOperation(), bufferizationOptions, - /*copyBeforeWrite=*/false, - &denseOpFilter))) + if (failed(bufferization::bufferizeOp(getOperation(), updatedOptions))) return failure(); bufferization::removeBufferizationAttributesInModule(getOperation());