Commit 91c1157

Revert "[MLIR] Make OneShotModuleBufferize use OpInterface (#110322)" (#113124)
This reverts commit 2026501. Failing bot: * https://lab.llvm.org/staging/#/builders/125/builds/389
1 parent a6d6c00 commit 91c1157

11 files changed, +281 -316 lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 3 additions & 4 deletions
```diff
@@ -11,7 +11,6 @@
 
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMapInfoVariant.h"
 #include "llvm/ADT/SetVector.h"
@@ -261,9 +260,9 @@ struct BufferizationOptions {
   using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
   /// Tensor -> MemRef type converter.
   /// Parameters: Value, memory space, func op, bufferization options
-  using FunctionArgTypeConverterFn = std::function<BaseMemRefType(
-      TensorType, Attribute memorySpace, FunctionOpInterface,
-      const BufferizationOptions &)>;
+  using FunctionArgTypeConverterFn =
+      std::function<BaseMemRefType(TensorType, Attribute memorySpace,
+                                   func::FuncOp, const BufferizationOptions &)>;
   /// Tensor -> MemRef type converter.
   /// Parameters: Value, memory space, bufferization options
   using UnknownTypeConverterFn = std::function<BaseMemRefType(
```
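Note: for downstream users, the practical effect of this hunk is that a custom `FunctionArgTypeConverterFn` is once again written against `func::FuncOp` instead of the generic `FunctionOpInterface`. A minimal sketch of installing such a converter, assuming an MLIR build with the Bufferization and Func dialects; the helper name and the layout policy are illustrative, not part of this commit:

```cpp
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

using namespace mlir;
using namespace mlir::bufferization;

// Sketch: install a converter matching the post-revert typedef, which takes a
// func::FuncOp parameter. The policy below simply picks a fully dynamic layout.
static void installConverter(BufferizationOptions &options) {
  options.functionArgTypeConverterFn =
      [](TensorType tensorType, Attribute memorySpace, func::FuncOp funcOp,
         const BufferizationOptions &opts) -> BaseMemRefType {
    (void)funcOp;
    (void)opts;
    return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
  };
}
```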

mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h

Lines changed: 6 additions & 6 deletions
```diff
@@ -50,24 +50,24 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension {
 
   /// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
   /// indices.
-  DenseMap<FunctionOpInterface, IndexMapping> equivalentFuncArgs;
+  DenseMap<FuncOp, IndexMapping> equivalentFuncArgs;
 
   /// A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
-  DenseMap<FunctionOpInterface, IndexToIndexListMapping> aliasingReturnVals;
+  DenseMap<FuncOp, IndexToIndexListMapping> aliasingReturnVals;
 
   /// A set of all read BlockArguments of FuncOps.
-  DenseMap<FunctionOpInterface, BbArgIndexSet> readBbArgs;
+  DenseMap<FuncOp, BbArgIndexSet> readBbArgs;
 
   /// A set of all written-to BlockArguments of FuncOps.
-  DenseMap<FunctionOpInterface, BbArgIndexSet> writtenBbArgs;
+  DenseMap<FuncOp, BbArgIndexSet> writtenBbArgs;
 
   /// Keep track of which FuncOps are fully analyzed or currently being
   /// analyzed.
-  DenseMap<FunctionOpInterface, FuncOpAnalysisState> analyzedFuncOps;
+  DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
 
   /// This function is called right before analyzing the given FuncOp. It
   /// initializes the data structures for the FuncOp in this state object.
-  void startFunctionAnalysis(FunctionOpInterface funcOp);
+  void startFunctionAnalysis(FuncOp funcOp);
 };
 
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
```
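Note: with this revert, every map in `FuncAnalysisState` is keyed by `func::FuncOp` again, so consumers look up entries with a concrete FuncOp handle. A hypothetical helper illustrating such a lookup (not part of this commit); per the doc comment above, `IndexMapping` relates return operand indices to equivalent block-argument indices:

```cpp
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::bufferization::func_ext;

// Sketch: print the equivalence info recorded for one analyzed function.
static void printEquivalentArgs(const FuncAnalysisState &funcState,
                                func::FuncOp funcOp) {
  auto it = funcState.equivalentFuncArgs.find(funcOp);
  if (it == funcState.equivalentFuncArgs.end())
    return; // Function not analyzed (yet).
  for (const auto &entry : it->second)
    llvm::outs() << "return operand #" << entry.first
                 << " is equivalent to bbArg #" << entry.second << "\n";
}
```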

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 2 additions & 3 deletions
```diff
@@ -18,7 +18,6 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
-#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Debug.h"
 
@@ -315,7 +314,7 @@ namespace {
 /// Default function arg type converter: Use a fully dynamic layout map.
 BaseMemRefType
 defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
-                                FunctionOpInterface funcOp,
+                                func::FuncOp funcOp,
                                 const BufferizationOptions &options) {
   return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
 }
@@ -362,7 +361,7 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
 void BufferizationOptions::setFunctionBoundaryTypeConversion(
     LayoutMapOption layoutMapOption) {
   functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
-                                   FunctionOpInterface funcOp,
+                                   func::FuncOp funcOp,
                                    const BufferizationOptions &options) {
     if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
       return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
```
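Note: `setFunctionBoundaryTypeConversion`, touched in the last hunk, is the usual way to select the function-boundary layout policy. A short usage sketch; the `bufferizeFunctionBoundaries` setting and the chosen enum value are illustrative assumptions, not taken from this commit:

```cpp
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"

using namespace mlir;
using namespace mlir::bufferization;

// Sketch: request static identity layout maps at function boundaries, which
// routes through the lambda shown in the hunk above.
static OneShotBufferizationOptions makeOptions() {
  OneShotBufferizationOptions options;
  options.bufferizeFunctionBoundaries = true;
  options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
  return options;
}
```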

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -22,7 +22,7 @@ namespace mlir {
 namespace bufferization {
 namespace func_ext {
 
-void FuncAnalysisState::startFunctionAnalysis(FunctionOpInterface funcOp) {
+void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
   analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
   auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping());
   auto createdAliasingResults =
```
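Note: the body shown above relies on `DenseMap::try_emplace`, which only inserts when the key is absent, so repeated calls for the same function keep the existing entries. A small self-contained illustration of that idiom (plain LLVM ADT, not MLIR-specific):

```cpp
#include "llvm/ADT/DenseMap.h"
#include <cassert>

int main() {
  llvm::DenseMap<int, int> m;
  auto first = m.try_emplace(1, 10);  // inserts {1, 10}
  auto second = m.try_emplace(1, 20); // key exists: no insertion, value stays 10
  assert(first.second && !second.second);
  assert(m.lookup(1) == 10);
  return 0;
}
```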

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 56 additions & 56 deletions
```diff
@@ -75,7 +75,7 @@ using namespace mlir::bufferization;
 using namespace mlir::bufferization::func_ext;
 
 /// A mapping of FuncOps to their callers.
-using FuncCallerMap = DenseMap<FunctionOpInterface, DenseSet<Operation *>>;
+using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;
 
 /// Get or create FuncAnalysisState.
 static FuncAnalysisState &
@@ -88,11 +88,10 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
 
 /// Return the unique ReturnOp that terminates `funcOp`.
 /// Return nullptr if there is no such unique ReturnOp.
-static Operation *getAssumedUniqueReturnOp(FunctionOpInterface funcOp) {
-  Operation *returnOp = nullptr;
-  for (Block &b : funcOp.getFunctionBody()) {
-    auto candidateOp = b.getTerminator();
-    if (candidateOp && candidateOp->hasTrait<OpTrait::ReturnLike>()) {
+static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
+  func::ReturnOp returnOp;
+  for (Block &b : funcOp.getBody()) {
+    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
       if (returnOp)
         return nullptr;
       returnOp = candidateOp;
```
```diff
@@ -127,16 +126,16 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
 /// Store function BlockArguments that are equivalent to/aliasing a returned
 /// value in FuncAnalysisState.
 static LogicalResult
-aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp,
-                             OneShotAnalysisState &state,
+aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
                              FuncAnalysisState &funcState) {
-  if (funcOp.getFunctionBody().empty()) {
+  if (funcOp.getBody().empty()) {
     // No function body available. Conservatively assume that every tensor
     // return value may alias with any tensor bbArg.
-    for (const auto &inputIt : llvm::enumerate(funcOp.getArgumentTypes())) {
+    FunctionType type = funcOp.getFunctionType();
+    for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
       if (!isa<TensorType>(inputIt.value()))
         continue;
-      for (const auto &resultIt : llvm::enumerate(funcOp.getResultTypes())) {
+      for (const auto &resultIt : llvm::enumerate(type.getResults())) {
         if (!isa<TensorType>(resultIt.value()))
           continue;
         int64_t returnIdx = resultIt.index();
@@ -148,7 +147,7 @@ aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp,
   }
 
   // Support only single return-terminated block in the function.
-  Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
+  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
   assert(returnOp && "expected func with single return op");
 
   for (OpOperand &returnVal : returnOp->getOpOperands())
@@ -169,8 +168,8 @@ aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp,
   return success();
 }
 
-static void annotateFuncArgAccess(FunctionOpInterface funcOp, int64_t idx,
-                                  bool isRead, bool isWritten) {
+static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
+                                  bool isWritten) {
   OpBuilder b(funcOp.getContext());
   Attribute accessType;
   if (isRead && isWritten) {
@@ -190,12 +189,12 @@ static void annotateFuncArgAccess(FunctionOpInterface funcOp, int64_t idx,
 /// function with unknown ops, we conservatively assume that such ops bufferize
 /// to a read + write.
 static LogicalResult
-funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp,
-                             OneShotAnalysisState &state,
+funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
                              FuncAnalysisState &funcState) {
-  for (int64_t idx = 0, e = funcOp.getNumArguments(); idx < e; ++idx) {
+  for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
+       ++idx) {
     // Skip non-tensor arguments.
-    if (!isa<TensorType>(funcOp.getArgumentTypes()[idx]))
+    if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
       continue;
     bool isRead;
     bool isWritten;
@@ -205,7 +204,7 @@ funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp,
       StringRef str = accessAttr.getValue();
       isRead = str == "read" || str == "read-write";
       isWritten = str == "write" || str == "read-write";
-    } else if (funcOp.getFunctionBody().empty()) {
+    } else if (funcOp.getBody().empty()) {
       // If the function has no body, conservatively assume that all args are
       // read + written.
       isRead = true;
```
```diff
@@ -231,33 +230,33 @@ funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp,
 
 /// Remove bufferization attributes on FuncOp arguments.
 static void removeBufferizationAttributes(BlockArgument bbArg) {
-  auto funcOp = cast<FunctionOpInterface>(bbArg.getOwner()->getParentOp());
+  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
   funcOp.removeArgAttr(bbArg.getArgNumber(),
                        BufferizationDialect::kBufferLayoutAttrName);
   funcOp.removeArgAttr(bbArg.getArgNumber(),
                        BufferizationDialect::kWritableAttrName);
 }
 
-static FunctionOpInterface getCalledFunction(CallOpInterface callOp) {
+/// Return the func::FuncOp called by `callOp`.
+static func::FuncOp getCalledFunction(func::CallOp callOp) {
   SymbolRefAttr sym =
       llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
-  return dyn_cast_or_null<FunctionOpInterface>(
+  return dyn_cast_or_null<func::FuncOp>(
       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
 }
 
 /// Gather equivalence info of CallOps.
 /// Note: This only adds new equivalence info if the called function was already
 /// analyzed.
 // TODO: This does not handle cyclic function call graphs etc.
-static void equivalenceAnalysis(FunctionOpInterface funcOp,
+static void equivalenceAnalysis(func::FuncOp funcOp,
                                 OneShotAnalysisState &state,
                                 FuncAnalysisState &funcState) {
-  funcOp->walk([&](CallOpInterface callOp) {
-    FunctionOpInterface calledFunction = getCalledFunction(callOp);
-    if (!calledFunction)
-      return WalkResult::skip();
+  funcOp->walk([&](func::CallOp callOp) {
+    func::FuncOp calledFunction = getCalledFunction(callOp);
+    assert(calledFunction && "could not retrieved called func::FuncOp");
 
     // No equivalence info available for the called function.
     if (!funcState.equivalentFuncArgs.count(calledFunction))
@@ -268,7 +267,7 @@ static void equivalenceAnalysis(FunctionOpInterface funcOp,
       int64_t bbargIdx = it.second;
       if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
         continue;
-      Value returnVal = callOp->getResult(returnIdx);
+      Value returnVal = callOp.getResult(returnIdx);
       Value argVal = callOp->getOperand(bbargIdx);
       state.unionEquivalenceClasses(returnVal, argVal);
     }
```
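Note: the behavioral difference in this hunk is that the reverted code walked `CallOpInterface` and skipped calls whose callee could not be resolved, while the restored code walks `func::CallOp` and asserts that the callee resolves to a `func::FuncOp`. A hypothetical standalone sketch of the same SymbolTable-based resolution pattern (it mirrors, but is not, the patch code):

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/SymbolTable.h"

using namespace mlir;

// Sketch: resolve every direct func.call in a module to its callee FuncOp.
static void visitDirectCalls(ModuleOp moduleOp) {
  moduleOp.walk([&](func::CallOp callOp) {
    auto sym = llvm::dyn_cast_if_present<SymbolRefAttr>(
        callOp.getCallableForCallee());
    if (!sym)
      return;
    if (auto callee = dyn_cast_or_null<func::FuncOp>(
            SymbolTable::lookupNearestSymbolFrom(callOp, sym))) {
      // e.g. record `callee` in a caller map here.
      (void)callee;
    }
  });
}
```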
```diff
@@ -278,9 +277,11 @@ static void equivalenceAnalysis(FunctionOpInterface funcOp,
 }
 
 /// Return "true" if the given function signature has tensor semantics.
-static bool hasTensorSignature(FunctionOpInterface funcOp) {
-  return llvm::any_of(funcOp.getArgumentTypes(), llvm::IsaPred<TensorType>) ||
-         llvm::any_of(funcOp.getResultTypes(), llvm::IsaPred<TensorType>);
+static bool hasTensorSignature(func::FuncOp funcOp) {
+  return llvm::any_of(funcOp.getFunctionType().getInputs(),
+                      llvm::IsaPred<TensorType>) ||
+         llvm::any_of(funcOp.getFunctionType().getResults(),
+                      llvm::IsaPred<TensorType>);
 }
 
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
@@ -290,16 +291,16 @@ static bool hasTensorSignature(FunctionOpInterface funcOp) {
 /// retrieve the called FuncOp from any func::CallOp.
 static LogicalResult
 getFuncOpsOrderedByCalls(ModuleOp moduleOp,
-                         SmallVectorImpl<FunctionOpInterface> &orderedFuncOps,
+                         SmallVectorImpl<func::FuncOp> &orderedFuncOps,
                          FuncCallerMap &callerMap) {
   // For each FuncOp, the set of functions called by it (i.e. the union of
   // symbols of all nested func::CallOp).
-  DenseMap<FunctionOpInterface, DenseSet<FunctionOpInterface>> calledBy;
+  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
   // For each FuncOp, the number of func::CallOp it contains.
-  DenseMap<FunctionOpInterface, unsigned> numberCallOpsContainedInFuncOp;
-  WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult {
-    if (!funcOp.getFunctionBody().empty()) {
-      Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
+  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
+  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
+    if (!funcOp.getBody().empty()) {
+      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
       if (!returnOp)
         return funcOp->emitError()
                << "cannot bufferize a FuncOp with tensors and "
@@ -308,10 +309,9 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
 
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
-    return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
-      FunctionOpInterface calledFunction = getCalledFunction(callOp);
-      if (!calledFunction)
-        return WalkResult::skip();
+    return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
+      func::FuncOp calledFunction = getCalledFunction(callOp);
+      assert(calledFunction && "could not retrieved called func::FuncOp");
       // If the called function does not have any tensors in its signature, then
       // it is not necessary to bufferize the callee before the caller.
       if (!hasTensorSignature(calledFunction))
```
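Note: `calledBy` and `numberCallOpsContainedInFuncOp` feed an ordering in which callees come before their callers, so callees are analyzed and bufferized first. A minimal, hypothetical sketch of that bottom-up ordering idea (a Kahn-style worklist; this is not the code from this file, and it assumes every function has an entry in `numCalls`):

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Sketch: `calledBy[callee]` holds the callee's callers; `numCalls[caller]`
// counts the caller's not-yet-emitted callees. Functions with a count of zero
// are ready; emitting one may make its callers ready in turn.
static SmallVector<func::FuncOp>
orderBottomUp(const llvm::DenseMap<func::FuncOp, llvm::DenseSet<func::FuncOp>>
                  &calledBy,
              llvm::DenseMap<func::FuncOp, unsigned> numCalls) {
  SmallVector<func::FuncOp> ordered, worklist;
  for (const auto &it : numCalls)
    if (it.second == 0)
      worklist.push_back(it.first);
  while (!worklist.empty()) {
    func::FuncOp callee = worklist.pop_back_val();
    ordered.push_back(callee);
    auto callers = calledBy.find(callee);
    if (callers == calledBy.end())
      continue;
    for (func::FuncOp caller : callers->second)
      if (--numCalls[caller] == 0) // caller ready once all callees are emitted
        worklist.push_back(caller);
  }
  return ordered; // callees come before their callers
}
```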
```diff
@@ -349,11 +349,11 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
 /// most generic layout map as function return types. After bufferizing the
 /// entire function body, a more concise memref type can potentially be used for
 /// the return type of the function.
-static void foldMemRefCasts(FunctionOpInterface funcOp) {
-  if (funcOp.getFunctionBody().empty())
+static void foldMemRefCasts(func::FuncOp funcOp) {
+  if (funcOp.getBody().empty())
     return;
 
-  Operation *returnOp = getAssumedUniqueReturnOp(funcOp);
+  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
   SmallVector<Type> resultTypes;
 
   for (OpOperand &operand : returnOp->getOpOperands()) {
@@ -365,8 +365,8 @@ static void foldMemRefCasts(FunctionOpInterface funcOp) {
     }
   }
 
-  auto newFuncType = FunctionType::get(funcOp.getContext(),
-                                       funcOp.getArgumentTypes(), resultTypes);
+  auto newFuncType = FunctionType::get(
+      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
   funcOp.setType(newFuncType);
 }
 
@@ -379,7 +379,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
   FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
 
   // A list of functions in the order in which they are analyzed + bufferized.
-  SmallVector<FunctionOpInterface> orderedFuncOps;
+  SmallVector<func::FuncOp> orderedFuncOps;
 
   // A mapping of FuncOps to their callers.
   FuncCallerMap callerMap;
@@ -388,7 +388,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
     return failure();
 
   // Analyze ops.
-  for (FunctionOpInterface funcOp : orderedFuncOps) {
+  for (func::FuncOp funcOp : orderedFuncOps) {
     if (!state.getOptions().isOpAllowed(funcOp))
       continue;
 
@@ -416,7 +416,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
 
 void mlir::bufferization::removeBufferizationAttributesInModule(
     ModuleOp moduleOp) {
-  moduleOp.walk([&](FunctionOpInterface op) {
+  moduleOp.walk([&](func::FuncOp op) {
     for (BlockArgument bbArg : op.getArguments())
       removeBufferizationAttributes(bbArg);
   });
@@ -430,7 +430,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   IRRewriter rewriter(moduleOp.getContext());
 
   // A list of functions in the order in which they are analyzed + bufferized.
-  SmallVector<FunctionOpInterface> orderedFuncOps;
+  SmallVector<func::FuncOp> orderedFuncOps;
 
   // A mapping of FuncOps to their callers.
   FuncCallerMap callerMap;
@@ -439,11 +439,11 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
     return failure();
 
   // Bufferize functions.
-  for (FunctionOpInterface funcOp : orderedFuncOps) {
+  for (func::FuncOp funcOp : orderedFuncOps) {
     // Note: It would be good to apply cleanups here but we cannot as aliasInfo
     // would be invalidated.
 
-    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getName())) {
+    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
       // This function was not analyzed and RaW conflicts were not resolved.
       // Buffer copies must be inserted before every write.
       OneShotBufferizationOptions updatedOptions = options;
@@ -463,7 +463,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   // Bufferize all other ops.
   for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
     // Functions were already bufferized.
-    if (isa<FunctionOpInterface>(&op))
+    if (isa<func::FuncOp>(&op))
       continue;
     if (failed(bufferizeOp(&op, options, statistics)))
       return failure();
@@ -490,12 +490,12 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
   // FuncOps whose names are specified in options.noAnalysisFuncFilter will
   // not be analyzed. Ops in these FuncOps will not be analyzed as well.
   OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
-    auto func = dyn_cast<FunctionOpInterface>(op);
+    auto func = dyn_cast<func::FuncOp>(op);
    if (!func)
-      func = op->getParentOfType<FunctionOpInterface>();
+      func = op->getParentOfType<func::FuncOp>();
    if (func)
      return llvm::is_contained(options.noAnalysisFuncFilter,
-                                func.getName());
+                                func.getSymName());
    return false;
  };
  OneShotBufferizationOptions updatedOptions(options);
```
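Note: taken together, the entry points touched in this file are normally driven from a pass. A hedged usage sketch; the filtered function name, the `bufferizeFunctionBoundaries` setting, and the wrapper itself are illustrative assumptions, not taken from this commit:

```cpp
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;
using namespace mlir::bufferization;

// Sketch: run One-Shot Module Bufferize over a module, skipping analysis for
// one function by name (it then gets copy-before-write treatment, as in the
// noAnalysisFuncFilter handling above).
static LogicalResult bufferizeWholeModule(ModuleOp moduleOp) {
  OneShotBufferizationOptions options;
  options.bufferizeFunctionBoundaries = true;
  options.noAnalysisFuncFilter.push_back("external_callee");
  return runOneShotModuleBufferize(moduleOp, options);
}
```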

0 commit comments
