From 0f4b5511a12b0e2c60ba5011310f80dd0b314189 Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Thu, 1 May 2025 16:42:35 +0200 Subject: [PATCH 1/7] Add BufferizationState class --- .../Bufferization/IR/BufferizableOpInterface.h | 13 +++++++++++++ .../Bufferization/IR/BufferizableOpInterface.cpp | 4 ++++ 2 files changed, 17 insertions(+) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index cb6ef8bc17220..891a5d9044852 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -578,6 +578,19 @@ class AnalysisState { insideMutuallyExclusiveRegionsCache; }; +/// BufferizationState provides information about the state of the IR during the +/// bufferization process. +class BufferizationState { +public: + /// Get the cached symbol tables. + /// The user is expected to update / invalidate the cached symbol tables if + /// the bufferized operation have the Symbol or SymbolTable traits. + SymbolTableCollection &getSymbolTables(); + +private: + SymbolTableCollection symbolTables; +}; + /// Create an AllocTensorOp for the given shaped value (memref or tensor). /// If `copy` is set, the shaped value is copied. Otherwise, a tensor with /// undefined contents is allocated. diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 1fc34051680f1..14fa4c1ed8159 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -125,6 +125,10 @@ void AnalysisState::resetCache() { insideMutuallyExclusiveRegionsCache.clear(); } +SymbolTableCollection &BufferizationState::getSymbolTables() { + return symbolTables; +} + Region *bufferization::getNextEnclosingRepetitiveRegion( Region *region, const BufferizationOptions &options) { assert(isRepetitiveRegion(region, options) && "expected repetitive region"); From d461cff0cad4722f2977e08ba097d80a64776f8b Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Thu, 1 May 2025 16:43:28 +0200 Subject: [PATCH 2/7] Add BufferizationState as argument to bufferize method --- .../IR/BufferizableOpInterface.td | 3 +- .../Bufferization/IR/BufferizationOps.td | 15 ++++-- .../Bufferization/Transforms/BufferUtils.h | 1 + .../Bufferization/Transforms/Bufferize.h | 1 + .../Transforms/OneShotAnalysis.h | 1 + .../Transforms/OneShotModuleBufferize.h | 4 +- .../Dialect/Linalg/Transforms/Transforms.h | 1 + .../BufferizableOpInterfaceImpl.cpp | 12 +++-- .../Bufferization/IR/BufferizationOps.cpp | 12 +++-- .../BufferizationTransformOps.cpp | 8 +++- .../Bufferization/Transforms/BufferUtils.cpp | 7 +-- .../Bufferization/Transforms/Bufferize.cpp | 10 ++-- .../FuncBufferizableOpInterfaceImpl.cpp | 9 ++-- .../Transforms/OneShotAnalysis.cpp | 9 ++-- .../Transforms/OneShotModuleBufferize.cpp | 12 ++--- .../BufferizableOpInterfaceImpl.cpp | 3 +- .../BufferizableOpInterfaceImpl.cpp | 7 ++- .../Transforms/ConvertToDestinationStyle.cpp | 25 ++++++---- .../BufferizableOpInterfaceImpl.cpp | 17 +++++-- .../BufferizableOpInterfaceImpl.cpp | 27 +++++++---- .../BufferizableOpInterfaceImpl.cpp | 6 ++- .../BufferizableOpInterfaceImpl.cpp | 3 +- .../SparsificationAndBufferizationPass.cpp | 5 +- .../BufferizableOpInterfaceImpl.cpp | 48 ++++++++++++------- .../BufferizableOpInterfaceImpl.cpp | 15 ++++-- 25 files changed, 175 insertions(+), 86 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td index 95022d7d665d2..b599a9f053215 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -426,7 +426,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> { /*retType=*/"::llvm::LogicalResult", /*methodName=*/"bufferize", /*args=*/(ins "::mlir::RewriterBase &":$rewriter, - "const ::mlir::bufferization::BufferizationOptions &":$options), + "const ::mlir::bufferization::BufferizationOptions &":$options, + "::mlir::bufferization::BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ llvm_unreachable("bufferize not implemented"); diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td index 7a1a701bea6dc..dafa4b9b183f2 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -93,7 +93,8 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor", let extraClassDeclaration = [{ LogicalResult bufferize(RewriterBase &rewriter, - const BufferizationOptions &options); + const BufferizationOptions &options, + BufferizationState &state); bool resultBufferizesToMemoryWrite(OpResult opResult, const AnalysisState &state); @@ -282,7 +283,8 @@ def Bufferization_MaterializeInDestinationOp let extraClassDeclaration = [{ LogicalResult bufferize(RewriterBase &rewriter, - const BufferizationOptions &options); + const BufferizationOptions &options, + BufferizationState &state); bool bufferizesToMemoryRead(OpOperand &opOperand, const AnalysisState &state); @@ -375,7 +377,8 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor", } LogicalResult bufferize(RewriterBase &rewriter, - const BufferizationOptions &options); + const BufferizationOptions &options, + BufferizationState &state); }]; } @@ -458,7 +461,8 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [ //===------------------------------------------------------------------===// LogicalResult bufferize(RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { // to_tensor/to_buffer pairs fold away after bufferization. return success(); } @@ -550,7 +554,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [ } LogicalResult bufferize(RewriterBase &rewriter, - const BufferizationOptions &options); + const BufferizationOptions &options, + BufferizationState &state); }]; let assemblyFormat = [{ diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h index e5f3b6d571f43..adeb52cf9d7e6 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h @@ -122,6 +122,7 @@ class BufferPlacementTransformationBase { // Globals are created lazily at the top of the enclosing ModuleOp with pretty // names. Duplicates are avoided. FailureOr getGlobalFor(arith::ConstantOp constantOp, + SymbolTableCollection &symbolTables, uint64_t alignment, Attribute memorySpace = {}); diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h index d5cb8d8eb673c..70e3defee0867 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -45,6 +45,7 @@ struct BufferizationStatistics { /// additional buffer copies or set "options.copyBeforeWrite = true". The /// general bufferization entry point is `runOneShotBufferize`. LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, + BufferizationState &bufferizationState, BufferizationStatistics *statistics = nullptr); /// Bufferize the signature of `block` and its callers (i.e., ops that have the diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h index 673027f76190d..15189d2c1cb87 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h @@ -270,6 +270,7 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, /// Run One-Shot Bufferize on the given op: Analysis + Bufferization LogicalResult runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options, + BufferizationState &state, BufferizationStatistics *statistics = nullptr); } // namespace bufferization diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h index 4e5f5e9c730fa..2cf801dd1d951 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h @@ -20,6 +20,7 @@ namespace bufferization { struct BufferizationStatistics; class OneShotAnalysisState; struct OneShotBufferizationOptions; +class BufferizationState; /// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in /// `state`. @@ -38,6 +39,7 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, /// will be inserted only to these FuncOps. llvm::LogicalResult bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options, + BufferizationState &state, BufferizationStatistics *statistics = nullptr); /// Remove bufferization attributes on every FuncOp arguments in the ModuleOp. @@ -50,7 +52,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp); llvm::LogicalResult runOneShotModuleBufferize( ModuleOp moduleOp, const bufferization::OneShotBufferizationOptions &options, - BufferizationStatistics *statistics = nullptr); + BufferizationState &state, BufferizationStatistics *statistics = nullptr); } // namespace bufferization } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 4f90fc8831bc6..2eef0a06d0eb4 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -30,6 +30,7 @@ namespace mlir { namespace bufferization { class AllocTensorOp; class OneShotAnalysisState; +class BufferizationState; } // namespace bufferization namespace linalg { diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp index 5e69a98db8f1e..f646326ffc58f 100644 --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -24,7 +24,8 @@ struct ConstantOpInterface : public BufferizableOpInterface::ExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto constantOp = cast(op); auto type = dyn_cast(constantOp.getType()); @@ -46,7 +47,8 @@ struct ConstantOpInterface // Create global memory segment and replace tensor with memref pointing to // that memory segment. FailureOr globalOp = - getGlobalFor(constantOp, options.bufferAlignment, memorySpace); + getGlobalFor(constantOp, state.getSymbolTables(), + options.bufferAlignment, memorySpace); if (failed(globalOp)) return failure(); memref::GlobalOp globalMemref = *globalOp; @@ -83,7 +85,8 @@ struct IndexCastOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto castOp = cast(op); auto resultTensorType = cast(castOp.getType()); @@ -131,7 +134,8 @@ struct SelectOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto selectOp = cast(op); Location loc = selectOp.getLoc(); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index ecd2ef15546a4..91eccb0ab7430 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -149,7 +149,8 @@ void mlir::bufferization::populateDynamicDimSizes( //===----------------------------------------------------------------------===// LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter, - const BufferizationOptions &options) { + const BufferizationOptions &options, + BufferizationState &state) { OpBuilder::InsertionGuard g(rewriter); Location loc = getLoc(); @@ -529,7 +530,8 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results, //===----------------------------------------------------------------------===// LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter, - const BufferizationOptions &options) { + const BufferizationOptions &options, + BufferizationState &state) { FailureOr buffer = getBuffer(rewriter, getTensor(), options); if (failed(buffer)) return failure(); @@ -576,7 +578,8 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand, LogicalResult MaterializeInDestinationOp::bufferize(RewriterBase &rewriter, - const BufferizationOptions &options) { + const BufferizationOptions &options, + BufferizationState &state) { bool tensorDest = isa(getDest().getType()); Value buffer; if (tensorDest) { @@ -861,7 +864,8 @@ void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results, } LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter, - const BufferizationOptions &options) { + const BufferizationOptions &options, + BufferizationState &state) { // Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary. (void)foldToBufferToTensorPair(rewriter, *this, options); // Note: The return value of `bufferize` indicates whether there was an error diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp index a1d7bb995fc73..8bb7942304274 100644 --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -83,6 +83,8 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, } auto payloadOps = state.getPayloadOps(getTarget()); + BufferizationState bufferizationState; + for (Operation *target : payloadOps) { if (!isa(target)) return emitSilenceableError() << "expected module or function target"; @@ -90,10 +92,12 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, if (options.bufferizeFunctionBoundaries) { if (!moduleOp) return emitSilenceableError() << "expected module target"; - if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options))) + if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options, + bufferizationState))) return emitSilenceableError() << "bufferization failed"; } else { - if (failed(bufferization::runOneShotBufferize(target, options))) + if (failed(bufferization::runOneShotBufferize(target, options, + bufferizationState))) return emitSilenceableError() << "bufferization failed"; } } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp index c2e90764b1335..bb21f642ac077 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp @@ -103,8 +103,9 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase( //===----------------------------------------------------------------------===// FailureOr -bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, - Attribute memorySpace) { +bufferization::getGlobalFor(arith::ConstantOp constantOp, + SymbolTableCollection &symbolTables, + uint64_t alignment, Attribute memorySpace) { auto type = cast(constantOp.getType()); auto moduleOp = constantOp->getParentOfType(); if (!moduleOp) @@ -127,7 +128,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, // Create a builder without an insertion point. We will insert using the // symbol table to guarantee unique names. OpBuilder globalBuilder(moduleOp.getContext()); - SymbolTable symbolTable(moduleOp); + SymbolTable &symbolTable = symbolTables.getSymbolTable(moduleOp); // Create a pretty name. SmallString<64> buf; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 824b505517119..67f373d912dd4 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -161,10 +161,12 @@ struct OneShotBufferizePass return signalPassFailure(); } + BufferizationState state; BufferizationStatistics statistics; ModuleOp moduleOp = getOperation(); if (opt.bufferizeFunctionBoundaries) { - if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) { + if (failed( + runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) { signalPassFailure(); return; } @@ -175,7 +177,7 @@ struct OneShotBufferizePass "'bufferize-function-boundaries'"); return signalPassFailure(); } - if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) { + if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) { signalPassFailure(); return; } @@ -275,6 +277,7 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener { LogicalResult bufferization::bufferizeOp(Operation *op, const BufferizationOptions &options, + BufferizationState &bufferizationState, BufferizationStatistics *statistics) { if (options.copyBeforeWrite) { AnalysisState state(options); @@ -331,7 +334,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op, << "//===-------------------------------------------===//\n" << "IR after bufferizing: " << nextOp->getName() << "\n"); rewriter.setInsertionPoint(nextOp); - if (failed(bufferizableOp.bufferize(rewriter, options))) { + if (failed( + bufferizableOp.bufferize(rewriter, options, bufferizationState))) { LLVM_DEBUG(llvm::dbgs() << "failed to bufferize\n" << "//===-------------------------------------------===//\n"); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index 755477713668e..080796208bfc1 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -239,7 +239,8 @@ struct CallOpInterface /// All function arguments are writable. It is the responsibility of the /// CallOp to insert buffer copies where necessary. LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { func::CallOp callOp = cast(op); // 1. Compute the result types of the new CallOp. @@ -349,7 +350,8 @@ struct ReturnOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { #ifndef NDEBUG auto returnOp = cast(op); assert(isa(returnOp->getParentOp()) && @@ -418,7 +420,8 @@ struct FuncOpInterface /// All function bbArgs are writable unless they are explicitly marked as /// read-only. Callers must insert copies when needed. LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto funcOp = cast(op); FunctionType funcType = funcOp.getFunctionType(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index 6e93b36d2d5a2..de820e9c8f8af 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -1365,10 +1365,9 @@ LogicalResult bufferization::analyzeOp(Operation *op, return success(!failedAnalysis); } -LogicalResult -bufferization::runOneShotBufferize(Operation *op, - const OneShotBufferizationOptions &options, - BufferizationStatistics *statistics) { +LogicalResult bufferization::runOneShotBufferize( + Operation *op, const OneShotBufferizationOptions &options, + BufferizationState &state, BufferizationStatistics *statistics) { // copy-before-write deactivates the analysis. It cannot be used together with // test-analysis-only. assert(!(options.copyBeforeWrite && options.testAnalysisOnly) && @@ -1391,5 +1390,5 @@ bufferization::runOneShotBufferize(Operation *op, // Bufferize the op and its nested ops. If options.copyBeforeWrite is set, // a new buffer copy is allocated every time a buffer is written to. - return bufferizeOp(op, options, statistics); + return bufferizeOp(op, options, state, statistics); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index a025da8635135..90ceea4d69680 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -512,7 +512,7 @@ void mlir::bufferization::removeBufferizationAttributesInModule( LogicalResult mlir::bufferization::bufferizeModuleOp( ModuleOp moduleOp, const OneShotBufferizationOptions &options, - BufferizationStatistics *statistics) { + BufferizationState &state, BufferizationStatistics *statistics) { assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); IRRewriter rewriter(moduleOp.getContext()); @@ -548,10 +548,10 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( // Buffer copies must be inserted before every write. OneShotBufferizationOptions updatedOptions = options; updatedOptions.copyBeforeWrite = true; - if (failed(bufferizeOp(funcOp, updatedOptions, statistics))) + if (failed(bufferizeOp(funcOp, updatedOptions, state, statistics))) return failure(); } else { - if (failed(bufferizeOp(funcOp, options, statistics))) + if (failed(bufferizeOp(funcOp, options, state, statistics))) return failure(); } @@ -565,7 +565,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( // Functions were already bufferized. if (isa(&op) || op.hasTrait()) continue; - if (failed(bufferizeOp(&op, options, statistics))) + if (failed(bufferizeOp(&op, options, state, statistics))) return failure(); } @@ -577,7 +577,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( LogicalResult mlir::bufferization::runOneShotModuleBufferize( ModuleOp moduleOp, const OneShotBufferizationOptions &options, - BufferizationStatistics *statistics) { + BufferizationState &state, BufferizationStatistics *statistics) { assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); assert(!(options.copyBeforeWrite && options.testAnalysisOnly) && @@ -606,7 +606,7 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize( } if (options.testAnalysisOnly) return success(); - if (failed(bufferizeModuleOp(moduleOp, options, statistics))) + if (failed(bufferizeModuleOp(moduleOp, options, state, statistics))) return failure(); return success(); } diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp index 72f4a1a4f4c66..6a1546fb48683 100644 --- a/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp @@ -43,7 +43,8 @@ struct BranchLikeOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { // The operands of this op are bufferized together with the block signature. return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp index be158af09d398..b6a498a57c036 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -148,7 +148,8 @@ struct LinalgOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { return bufferizeDestinationStyleOpInterface( rewriter, cast(op), options); } @@ -174,7 +175,8 @@ struct SoftmaxOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto softmaxOp = cast(op); FailureOr inputBuffer = getBuffer(rewriter, softmaxOp.getInput(), options); @@ -202,6 +204,7 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels( LinalgOpInterfaceHelper< #define GET_OP_LIST #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" + >::registerOpInterface(ctx); SoftmaxOp::attachInterface(*ctx); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp index b1340be04e011..f18a31b97967b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -263,7 +263,11 @@ Value linalg::bufferizeToAllocation( assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 && "expected single masked op"); OpBuilder::InsertionGuard g(rewriter); + + // Should the bufferization options and state be function arguments? bufferization::BufferizationOptions bufferizationOptions; + bufferization::BufferizationState bufferizationState; + Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator(); assert(isa(yieldOp) && "expected yield op terminator"); @@ -279,7 +283,7 @@ Value linalg::bufferizeToAllocation( // Bufferize terminator. rewriter.setInsertionPoint(yieldOp); if (failed(cast(yieldOp).bufferize( - rewriter, bufferizationOptions))) + rewriter, bufferizationOptions, bufferizationState))) return nullptr; // Erase dead to_tensor ops inside of the mask op. This is necessary because @@ -300,8 +304,9 @@ Value linalg::bufferizeToAllocation( for (OpOperand &use : result.getUses()) resultUses.push_back(&use); rewriter.setInsertionPoint(maskOp); - if (failed(cast(maskOp.getOperation()) - .bufferize(rewriter, bufferizationOptions))) + if (failed( + cast(maskOp.getOperation()) + .bufferize(rewriter, bufferizationOptions, bufferizationState))) return nullptr; // Set "restrict" attribute, indicating that no other tensor aliases with @@ -484,8 +489,11 @@ Value linalg::bufferizeToAllocation( auto bufferizableOp = dyn_cast(op); if (!bufferizableOp) return nullptr; + + // Should the bufferization options and states be function arguments? BufferizationOptions bufferizationOptions; - AnalysisState state(bufferizationOptions); + AnalysisState analysisState(bufferizationOptions); + BufferizationState bufferizationState; #ifndef NDEBUG if (!options.bufferizeDestinationOnly) { @@ -527,7 +535,7 @@ Value linalg::bufferizeToAllocation( }; for (OpResult result : tensorResults) { AliasingOpOperandList aliasingOperands = - state.getAliasingOpOperands(result); + analysisState.getAliasingOpOperands(result); for (const AliasingOpOperand &operand : aliasingOperands) { addOutOfPlaceOperand(operand.opOperand); for (OpOperand &resultUse : result.getUses()) @@ -535,7 +543,7 @@ Value linalg::bufferizeToAllocation( } } for (OpOperand &operand : op->getOpOperands()) { - if (!state.bufferizesToMemoryWrite(operand)) + if (!analysisState.bufferizesToMemoryWrite(operand)) continue; if (!isa(operand.get().getType())) continue; @@ -553,7 +561,7 @@ Value linalg::bufferizeToAllocation( Value alloc = createAllocationForTensor( rewriter, op->getLoc(), operand->get(), options, memorySpace); allocs.push_back(alloc); - if (!state.findDefinitions(operand).empty()) { + if (!analysisState.findDefinitions(operand).empty()) { // Initialize buffer with a copy of the operand data. Not needed if the // tensor is uninitialized. createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options); @@ -575,7 +583,8 @@ Value linalg::bufferizeToAllocation( // Bufferize the op. rewriter.setInsertionPoint(op); - if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions))) + if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions, + bufferizationState))) return nullptr; // Set "restrict" attribute, indicating that no other tensor aliases with diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp index 926d580ac7852..104ec3e1449e5 100644 --- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp @@ -52,15 +52,21 @@ struct GlobalOpInterface bool hasTensorSemantics(Operation *) const { return true; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &) const { + const BufferizationOptions &, + BufferizationState &state) const { auto globalOp = cast(op); if (!globalOp.getValue().has_value()) return globalOp.emitError("global op must have a value"); + SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable( + globalOp->getParentWithTrait()); + + symbolTable.remove(globalOp); + auto tensorType = cast(globalOp.getType()); auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); - replaceOpWithNewBufferizedOp( + auto replacement = replaceOpWithNewBufferizedOp( rewriter, globalOp, globalOp.getSymName(), /*sym_visibility=*/globalOp.getSymVisibilityAttr(), /*type=*/cast(memrefType), @@ -68,6 +74,7 @@ struct GlobalOpInterface /*constant=*/!globalOp.getIsMutable(), /*alignment=*/nullptr); + symbolTable.insert(replacement); return success(); } }; @@ -91,7 +98,8 @@ struct GlobalLoadOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &) const { + const BufferizationOptions &, + BufferizationState &state) const { auto globalLoadOp = cast(op); auto tensorType = cast(globalLoadOp.getType()); @@ -121,7 +129,8 @@ struct GlobalStoreOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto globalStoreOp = cast(op); auto tensorType = cast(globalStoreOp.getValue().getType()); diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index d6a9d8f6401f1..3ff1f5c49aece 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -95,7 +95,8 @@ struct ConditionOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto conditionOp = cast(op); auto whileOp = cast(conditionOp->getParentOp()); @@ -181,7 +182,8 @@ struct ExecuteRegionOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto executeRegionOp = cast(op); auto yieldOp = getUniqueYieldOp(executeRegionOp); TypeRange newResultTypes(yieldOp.getResults()); @@ -237,7 +239,8 @@ struct IfOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { OpBuilder::InsertionGuard g(rewriter); auto ifOp = cast(op); @@ -347,7 +350,8 @@ struct IndexSwitchOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { OpBuilder::InsertionGuard g(rewriter); auto switchOp = cast(op); @@ -722,7 +726,8 @@ struct ForOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto forOp = cast(op); Block *oldLoopBody = forOp.getBody(); @@ -939,7 +944,8 @@ struct WhileOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto whileOp = cast(op); // Indices of all bbArgs that have tensor type. These are the ones that @@ -1144,7 +1150,8 @@ struct YieldOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto yieldOp = cast(op); if (!isa(yieldOp->getParentOp())) @@ -1220,7 +1227,8 @@ struct ForallOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { OpBuilder::InsertionGuard guard(rewriter); auto forallOp = cast(op); int64_t rank = forallOp.getRank(); @@ -1327,7 +1335,8 @@ struct InParallelOpInterface : public BufferizableOpInterface::ExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &b, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { llvm_unreachable("op does not have any tensor OpOperands / OpResults"); return failure(); } diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp index 6c3b23937f98f..e8cab76d3c753 100644 --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -47,7 +47,8 @@ struct AssumingOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto assumingOp = cast(op); assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) && "only 1 block supported"); @@ -112,7 +113,8 @@ struct AssumingYieldOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto yieldOp = cast(op); SmallVector newResults; for (Value value : yieldOp.getOperands()) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp index 7734d1d258453..f952b68ba7e67 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -30,7 +30,8 @@ template struct SparseBufferizableOpInterfaceExternalModel : public BufferizableOpInterface::ExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { return op->emitError( "sparse_tensor ops must be bufferized with the sparsifier"); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp index 6e882a8d0ff30..7c7c64f2aef01 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -114,8 +114,11 @@ class SparsificationAndBufferizationPass return false; }); + bufferization::BufferizationState bufferizationState; + if (failed(bufferization::bufferizeModuleOp(cast(getOperation()), - updatedOptions))) + updatedOptions, + bufferizationState))) return failure(); bufferization::removeBufferizationAttributesInModule(getOperation()); diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index c0e697292d2a0..ac1e90b9f9b35 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -83,7 +83,8 @@ struct CastOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto castOp = cast(op); // The result buffer still has the old (pre-cast) type. @@ -162,7 +163,8 @@ struct CollapseShapeOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto collapseShapeOp = cast(op); RankedTensorType tensorResultType = collapseShapeOp.getResultType(); FailureOr maybeBuffer = @@ -247,7 +249,8 @@ struct DimOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto dimOp = cast(op); FailureOr v = getBuffer(rewriter, dimOp.getSource(), options); if (failed(v)) @@ -271,7 +274,8 @@ struct EmptyOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto emptyOp = cast(op); // Optimization: Fold away the op if it has no uses. @@ -329,7 +333,8 @@ struct ExpandShapeOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto expandShapeOp = cast(op); auto tensorResultType = expandShapeOp.getResultType(); FailureOr buffer = @@ -367,7 +372,8 @@ struct ExtractSliceOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto extractSliceOp = cast(op); SmallVector mixedOffsets = extractSliceOp.getMixedOffsets(); SmallVector mixedSizes = extractSliceOp.getMixedSizes(); @@ -432,7 +438,8 @@ struct ExtractOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto extractOp = cast(op); FailureOr srcMemref = getBuffer(rewriter, extractOp.getTensor(), options); @@ -474,7 +481,8 @@ struct FromElementsOpInterface bool bufferizesToAllocation(Operation *op, Value value) const { return true; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto fromElementsOp = cast(op); auto tensorType = cast(fromElementsOp.getType()); @@ -586,7 +594,8 @@ struct GenerateOpInterface bool bufferizesToAllocation(Operation *op, Value value) const { return true; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto generateOp = cast(op); auto type = generateOp.getResult().getType(); @@ -620,7 +629,8 @@ struct InsertOpInterface : public DstBufferizableOpInterfaceExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto insertOp = cast(op); FailureOr destMemref = getBuffer(rewriter, insertOp.getDest(), options); @@ -670,7 +680,8 @@ struct InsertSliceOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { // insert_slice ops arise from tiling and bufferizing them out-of-place is // generally a deal breaker. When used with loops, this ends up cloning the // whole tensor on every single iteration and is a symptom of a @@ -752,7 +763,8 @@ struct PadOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto padOp = cast(op); Location loc = padOp.getLoc(); RankedTensorType resultType = padOp.getResultType(); @@ -831,7 +843,8 @@ struct RankOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto rankOp = cast(op); FailureOr v = getBuffer(rewriter, rankOp.getTensor(), options); if (failed(v)) @@ -868,7 +881,8 @@ struct ReshapeOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto reshapeOp = cast(op); FailureOr srcBuffer = getBuffer(rewriter, reshapeOp.getSource(), options); @@ -940,7 +954,8 @@ struct ParallelInsertSliceOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { OpBuilder::InsertionGuard g(rewriter); auto parallelInsertSliceOp = cast(op); ParallelCombiningOpInterface parallelCombiningParent = @@ -1015,7 +1030,8 @@ struct SplatOpInterface bool bufferizesToAllocation(Operation *op, Value value) const { return true; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { OpBuilder::InsertionGuard g(rewriter); auto splatOp = cast(op); diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp index b2272c5fda876..45b6e7c512947 100644 --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -48,7 +48,8 @@ struct TransferReadOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto readOp = cast(op); assert(isa(readOp.getShapedType()) && "only tensor types expected"); @@ -103,7 +104,8 @@ struct TransferWriteOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto writeOp = cast(op); assert(isa(writeOp.getShapedType()) && "only tensor types expected"); @@ -148,7 +150,8 @@ struct GatherOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto gatherOp = cast(op); assert(isa(gatherOp.getBaseType()) && "only tensor types expected"); @@ -202,7 +205,8 @@ struct MaskOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto maskOp = cast(op); // Do not bufferize if the masked op is not bufferizable. @@ -279,7 +283,8 @@ struct YieldOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { + const BufferizationOptions &options, + BufferizationState &state) const { auto yieldOp = cast(op); // Only supported as a vector.mask terminator. From 07344323181d67c16fe435810e1900fd20e6bf89 Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Sat, 17 May 2025 14:26:21 +0200 Subject: [PATCH 3/7] Add extension mechanism to BufferizationState --- .../IR/BufferizableOpInterface.h | 75 +++++++++++++++++-- .../Bufferization/Transforms/BufferUtils.h | 10 +++ .../BufferizableOpInterfaceImpl.cpp | 3 +- .../IR/BufferizableOpInterface.cpp | 4 - .../Bufferization/Transforms/BufferUtils.cpp | 39 ++++++++++ .../BufferizableOpInterfaceImpl.cpp | 8 +- 6 files changed, 123 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index 891a5d9044852..e2c75b9b230fa 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -582,12 +582,77 @@ class AnalysisState { /// bufferization process. class BufferizationState { public: - /// Get the cached symbol tables. - /// The user is expected to update / invalidate the cached symbol tables if - /// the bufferized operation have the Symbol or SymbolTable traits. - SymbolTableCollection &getSymbolTables(); + /// Base class for BufferizationState extensions that allow BufferizationState + /// to contain user-specified information in the state object. The extension + /// mechanism of BufferizationState mirrors the one of OneShotAnalysisState. + class Extension { + public: + /// Base virtual destructor. + // Out-of-line definition ensures symbols are emitted in a single object + // file. + virtual ~Extension(); + + protected: + /// Constructs an extension of the given state object. + Extension(BufferizationState &state) : state(state) {} + + /// Provides read-only access to the parent OneShotAnalysisState object. + const BufferizationState &getBufferizationState() const { return state; } + + private: + /// Back-reference to the state that is being extended. + BufferizationState &state; + }; -private: + /// Adds a new Extension of the type specified as template parameter, + /// constructing it with the arguments provided. The extension is owned by the + /// BufferizationState. It is expected that the state does not already have an + /// extension of the same type. Extension constructors are expected to take a + /// reference to BufferizationState as first argument, automatically supplied + /// by this call. + template + Ty &addExtension(Args &&...args) { + static_assert(std::is_base_of::value, + "only a class derived from " + "BufferizationState::Extension is allowed"); + auto ptr = std::make_unique(*this, std::forward(args)...); + auto result = extensions.try_emplace(TypeID::get(), std::move(ptr)); + assert(result.second && "extension already added"); + return *static_cast(result.first->second.get()); + } + + /// Returns the extension of the specified type. + template + Ty *getExtension() { + static_assert(std::is_base_of::value, + "only a class derived from " + "BufferizationState::Extension is allowed"); + auto iter = extensions.find(TypeID::get()); + if (iter == extensions.end()) + return nullptr; + return static_cast(iter->second.get()); + } + + /// Returns the extension of the specified type. + template + const Ty *getExtension() const { + return const_cast(this)->getExtension(); + } + + /// Extensions attached to the state, identified by the TypeID of their type. + /// Only one extension of any given type is allowed. + DenseMap> extensions; +}; + +/// Extra bufferization state that is required for bufferization of operations +/// declaring a symbol or a symbol table. +struct SymbolBufferizationState : public BufferizationState::Extension { + SymbolBufferizationState(BufferizationState &state) + : BufferizationState::Extension(state) {} + + /// The cached symbol tables. + /// The user is expected to update / invalidate the cached symbol tables if + /// the bufferized operation has the Symbol or SymbolTable traits. SymbolTableCollection symbolTables; }; diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h index adeb52cf9d7e6..da0cbe31b0420 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h @@ -29,6 +29,7 @@ class GlobalOp; } // namespace memref namespace bufferization { +class BufferizationState; /// A simple analysis that detects allocation operations. class BufferPlacementAllocs { @@ -126,6 +127,15 @@ FailureOr getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, Attribute memorySpace = {}); +FailureOr getGlobalFor(arith::ConstantOp op, + BufferizationState &state, + uint64_t alignment, + Attribute memorySpace); + +void removeSymbol(Operation *op, BufferizationState &state); + +void insertSymbol(Operation *op, BufferizationState &state); + } // namespace bufferization } // namespace mlir diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp index f646326ffc58f..1eabafaca261a 100644 --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -47,8 +47,7 @@ struct ConstantOpInterface // Create global memory segment and replace tensor with memref pointing to // that memory segment. FailureOr globalOp = - getGlobalFor(constantOp, state.getSymbolTables(), - options.bufferAlignment, memorySpace); + getGlobalFor(constantOp, state, options.bufferAlignment, memorySpace); if (failed(globalOp)) return failure(); memref::GlobalOp globalMemref = *globalOp; diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 14fa4c1ed8159..1fc34051680f1 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -125,10 +125,6 @@ void AnalysisState::resetCache() { insideMutuallyExclusiveRegionsCache.clear(); } -SymbolTableCollection &BufferizationState::getSymbolTables() { - return symbolTables; -} - Region *bufferization::getNextEnclosingRepetitiveRegion( Region *region, const BufferizationOptions &options) { assert(isRepetitiveRegion(region, options) && "expected repetitive region"); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp index bb21f642ac077..a5aeb2d1ebb08 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp @@ -159,3 +159,42 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, global->moveBefore(&moduleOp.front()); return global; } + +namespace mlir::bufferization { +FailureOr getGlobalFor(arith::ConstantOp op, + BufferizationState &state, + uint64_t alignment, + Attribute memorySpace) { + if (auto *symbolBufferizationState = + state.getExtension()) { + // Use the cached symbol tables. + return getGlobalFor(op, symbolBufferizationState->symbolTables, alignment, + memorySpace); + } + + SymbolTableCollection symbolTables; + return getGlobalFor(op, symbolTables, alignment, memorySpace); +} + +void removeSymbol(Operation *op, BufferizationState &state) { + if (auto *symbolBufferizationState = + state.getExtension()) { + SymbolTable &symbolTable = + symbolBufferizationState->symbolTables.getSymbolTable( + op->getParentWithTrait()); + + symbolTable.remove(op); + } +} + +void insertSymbol(Operation *op, BufferizationState &state) { + if (auto *symbolBufferizationState = + state.getExtension()) { + SymbolTable &symbolTable = + symbolBufferizationState->symbolTables.getSymbolTable( + op->getParentWithTrait()); + + symbolTable.insert(op); + } +} +} // namespace mlir::bufferization diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp index 104ec3e1449e5..a69bc9e5088ae 100644 --- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -58,10 +59,7 @@ struct GlobalOpInterface if (!globalOp.getValue().has_value()) return globalOp.emitError("global op must have a value"); - SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable( - globalOp->getParentWithTrait()); - - symbolTable.remove(globalOp); + bufferization::removeSymbol(globalOp, state); auto tensorType = cast(globalOp.getType()); auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType); @@ -74,7 +72,7 @@ struct GlobalOpInterface /*constant=*/!globalOp.getIsMutable(), /*alignment=*/nullptr); - symbolTable.insert(replacement); + bufferization::insertSymbol(replacement, state); return success(); } }; From 8bd6a16b6e6552385052c9129dbe5dd5f3034e0a Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Sat, 17 May 2025 15:12:55 +0200 Subject: [PATCH 4/7] Add missing implementation for Extension destructor --- mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 1fc34051680f1..0da720ad6da28 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -125,6 +125,8 @@ void AnalysisState::resetCache() { insideMutuallyExclusiveRegionsCache.clear(); } +BufferizationState::Extension::~Extension() = default; + Region *bufferization::getNextEnclosingRepetitiveRegion( Region *region, const BufferizationOptions &options) { assert(isRepetitiveRegion(region, options) && "expected repetitive region"); From 45e03837c11c263f21805974a205a813bae2b849 Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Sat, 17 May 2025 15:17:01 +0200 Subject: [PATCH 5/7] Add option to enable caching of symbol tables --- .../mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h | 6 ++++++ .../TransformOps/BufferizationTransformOps.cpp | 5 +++++ mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp | 5 +++++ .../Transforms/SparsificationAndBufferizationPass.cpp | 5 +++++ 4 files changed, 21 insertions(+) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h index 15189d2c1cb87..fa6a08320bd60 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h @@ -52,6 +52,12 @@ struct OneShotBufferizationOptions : public BufferizationOptions { /// `AnalysisHeuristic::Fuzzer`. The fuzzer should be used only with /// `testAnalysisOnly = true`. unsigned analysisFuzzerSeed = 0; + + /// Enable caching of symbol tables. If enabled, the SymbolBufferizationState + /// class is attached to the bufferization state and the user is required to + /// keep the cached symbol tables consistent with respect to the performed + /// bufferizations. + bool cacheSymbolTables = false; }; /// State for analysis-enabled bufferization. This class keeps track of alias diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp index 8bb7942304274..a6cae1f4dda33 100644 --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -85,6 +85,10 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, auto payloadOps = state.getPayloadOps(getTarget()); BufferizationState bufferizationState; + if (options.cacheSymbolTables) { + bufferizationState.addExtension(); + } + for (Operation *target : payloadOps) { if (!isa(target)) return emitSilenceableError() << "expected module or function target"; @@ -166,6 +170,7 @@ class BufferizationTransformDialectExtension registerTransformOps< #define GET_OP_LIST #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" + >(); } }; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 67f373d912dd4..3f094684aa9f8 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -162,6 +162,11 @@ struct OneShotBufferizePass } BufferizationState state; + + if (opt.cacheSymbolTables) { + state.addExtension(); + } + BufferizationStatistics statistics; ModuleOp moduleOp = getOperation(); if (opt.bufferizeFunctionBoundaries) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp index 7c7c64f2aef01..663f5e420b953 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -116,6 +116,11 @@ class SparsificationAndBufferizationPass bufferization::BufferizationState bufferizationState; + if (updatedOptions.cacheSymbolTables) { + bufferizationState + .addExtension(); + } + if (failed(bufferization::bufferizeModuleOp(cast(getOperation()), updatedOptions, bufferizationState))) From 019f5b96d1858965cd50bf6667b2d3e24196ec55 Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Sun, 18 May 2025 18:43:02 +0200 Subject: [PATCH 6/7] Remove caching option and separate extension --- .../IR/BufferizableOpInterface.h | 11 +++--- .../Bufferization/Transforms/BufferUtils.h | 5 --- .../Transforms/OneShotAnalysis.h | 6 ---- .../BufferizableOpInterfaceImpl.cpp | 3 +- .../IR/BufferizableOpInterface.cpp | 4 +++ .../BufferizationTransformOps.cpp | 4 --- .../Bufferization/Transforms/BufferUtils.cpp | 35 ++++--------------- .../Bufferization/Transforms/Bufferize.cpp | 4 --- .../SparsificationAndBufferizationPass.cpp | 5 --- 9 files changed, 16 insertions(+), 61 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index e2c75b9b230fa..d644f49573a35 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -639,16 +639,13 @@ class BufferizationState { return const_cast(this)->getExtension(); } + /// Get a reference to the collection of cached symbol tables. + SymbolTableCollection &getSymbolTables(); + +private: /// Extensions attached to the state, identified by the TypeID of their type. /// Only one extension of any given type is allowed. DenseMap> extensions; -}; - -/// Extra bufferization state that is required for bufferization of operations -/// declaring a symbol or a symbol table. -struct SymbolBufferizationState : public BufferizationState::Extension { - SymbolBufferizationState(BufferizationState &state) - : BufferizationState::Extension(state) {} /// The cached symbol tables. /// The user is expected to update / invalidate the cached symbol tables if diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h index da0cbe31b0420..c08bd6c436133 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h @@ -127,11 +127,6 @@ FailureOr getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, Attribute memorySpace = {}); -FailureOr getGlobalFor(arith::ConstantOp op, - BufferizationState &state, - uint64_t alignment, - Attribute memorySpace); - void removeSymbol(Operation *op, BufferizationState &state); void insertSymbol(Operation *op, BufferizationState &state); diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h index fa6a08320bd60..15189d2c1cb87 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h @@ -52,12 +52,6 @@ struct OneShotBufferizationOptions : public BufferizationOptions { /// `AnalysisHeuristic::Fuzzer`. The fuzzer should be used only with /// `testAnalysisOnly = true`. unsigned analysisFuzzerSeed = 0; - - /// Enable caching of symbol tables. If enabled, the SymbolBufferizationState - /// class is attached to the bufferization state and the user is required to - /// keep the cached symbol tables consistent with respect to the performed - /// bufferizations. - bool cacheSymbolTables = false; }; /// State for analysis-enabled bufferization. This class keeps track of alias diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp index 1eabafaca261a..f646326ffc58f 100644 --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -47,7 +47,8 @@ struct ConstantOpInterface // Create global memory segment and replace tensor with memref pointing to // that memory segment. FailureOr globalOp = - getGlobalFor(constantOp, state, options.bufferAlignment, memorySpace); + getGlobalFor(constantOp, state.getSymbolTables(), + options.bufferAlignment, memorySpace); if (failed(globalOp)) return failure(); memref::GlobalOp globalMemref = *globalOp; diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 0da720ad6da28..d6224b012ac95 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -127,6 +127,10 @@ void AnalysisState::resetCache() { BufferizationState::Extension::~Extension() = default; +SymbolTableCollection &BufferizationState::getSymbolTables() { + return symbolTables; +} + Region *bufferization::getNextEnclosingRepetitiveRegion( Region *region, const BufferizationOptions &options) { assert(isRepetitiveRegion(region, options) && "expected repetitive region"); diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp index a6cae1f4dda33..db1eb20512033 100644 --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -85,10 +85,6 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, auto payloadOps = state.getPayloadOps(getTarget()); BufferizationState bufferizationState; - if (options.cacheSymbolTables) { - bufferizationState.addExtension(); - } - for (Operation *target : payloadOps) { if (!isa(target)) return emitSilenceableError() << "expected module or function target"; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp index a5aeb2d1ebb08..ff2c83d228dbb 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp @@ -161,40 +161,17 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, } namespace mlir::bufferization { -FailureOr getGlobalFor(arith::ConstantOp op, - BufferizationState &state, - uint64_t alignment, - Attribute memorySpace) { - if (auto *symbolBufferizationState = - state.getExtension()) { - // Use the cached symbol tables. - return getGlobalFor(op, symbolBufferizationState->symbolTables, alignment, - memorySpace); - } - - SymbolTableCollection symbolTables; - return getGlobalFor(op, symbolTables, alignment, memorySpace); -} - void removeSymbol(Operation *op, BufferizationState &state) { - if (auto *symbolBufferizationState = - state.getExtension()) { - SymbolTable &symbolTable = - symbolBufferizationState->symbolTables.getSymbolTable( - op->getParentWithTrait()); + SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable( + op->getParentWithTrait()); - symbolTable.remove(op); - } + symbolTable.remove(op); } void insertSymbol(Operation *op, BufferizationState &state) { - if (auto *symbolBufferizationState = - state.getExtension()) { - SymbolTable &symbolTable = - symbolBufferizationState->symbolTables.getSymbolTable( - op->getParentWithTrait()); + SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable( + op->getParentWithTrait()); - symbolTable.insert(op); - } + symbolTable.insert(op); } } // namespace mlir::bufferization diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 3f094684aa9f8..38de525316f7a 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -163,10 +163,6 @@ struct OneShotBufferizePass BufferizationState state; - if (opt.cacheSymbolTables) { - state.addExtension(); - } - BufferizationStatistics statistics; ModuleOp moduleOp = getOperation(); if (opt.bufferizeFunctionBoundaries) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp index 663f5e420b953..7c7c64f2aef01 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -116,11 +116,6 @@ class SparsificationAndBufferizationPass bufferization::BufferizationState bufferizationState; - if (updatedOptions.cacheSymbolTables) { - bufferizationState - .addExtension(); - } - if (failed(bufferization::bufferizeModuleOp(cast(getOperation()), updatedOptions, bufferizationState))) From 21006bc58e5befbd5b07286715b36abd3d2bac34 Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Thu, 22 May 2025 07:49:07 +0200 Subject: [PATCH 7/7] Remove extension mechanism from BUfferizationState --- .../IR/BufferizableOpInterface.h | 61 ------------------- .../IR/BufferizableOpInterface.cpp | 2 - 2 files changed, 63 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index d644f49573a35..43c97d57e1834 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -582,71 +582,10 @@ class AnalysisState { /// bufferization process. class BufferizationState { public: - /// Base class for BufferizationState extensions that allow BufferizationState - /// to contain user-specified information in the state object. The extension - /// mechanism of BufferizationState mirrors the one of OneShotAnalysisState. - class Extension { - public: - /// Base virtual destructor. - // Out-of-line definition ensures symbols are emitted in a single object - // file. - virtual ~Extension(); - - protected: - /// Constructs an extension of the given state object. - Extension(BufferizationState &state) : state(state) {} - - /// Provides read-only access to the parent OneShotAnalysisState object. - const BufferizationState &getBufferizationState() const { return state; } - - private: - /// Back-reference to the state that is being extended. - BufferizationState &state; - }; - - /// Adds a new Extension of the type specified as template parameter, - /// constructing it with the arguments provided. The extension is owned by the - /// BufferizationState. It is expected that the state does not already have an - /// extension of the same type. Extension constructors are expected to take a - /// reference to BufferizationState as first argument, automatically supplied - /// by this call. - template - Ty &addExtension(Args &&...args) { - static_assert(std::is_base_of::value, - "only a class derived from " - "BufferizationState::Extension is allowed"); - auto ptr = std::make_unique(*this, std::forward(args)...); - auto result = extensions.try_emplace(TypeID::get(), std::move(ptr)); - assert(result.second && "extension already added"); - return *static_cast(result.first->second.get()); - } - - /// Returns the extension of the specified type. - template - Ty *getExtension() { - static_assert(std::is_base_of::value, - "only a class derived from " - "BufferizationState::Extension is allowed"); - auto iter = extensions.find(TypeID::get()); - if (iter == extensions.end()) - return nullptr; - return static_cast(iter->second.get()); - } - - /// Returns the extension of the specified type. - template - const Ty *getExtension() const { - return const_cast(this)->getExtension(); - } - /// Get a reference to the collection of cached symbol tables. SymbolTableCollection &getSymbolTables(); private: - /// Extensions attached to the state, identified by the TypeID of their type. - /// Only one extension of any given type is allowed. - DenseMap> extensions; - /// The cached symbol tables. /// The user is expected to update / invalidate the cached symbol tables if /// the bufferized operation has the Symbol or SymbolTable traits. diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index d6224b012ac95..14fa4c1ed8159 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -125,8 +125,6 @@ void AnalysisState::resetCache() { insideMutuallyExclusiveRegionsCache.clear(); } -BufferizationState::Extension::~Extension() = default; - SymbolTableCollection &BufferizationState::getSymbolTables() { return symbolTables; }