diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h index 921c3c3e8c7db..186e83a57580f 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h @@ -10,10 +10,8 @@ #define MLIR_DIALECT_LINALG_TRANSFORMS_HOISTING_H_ namespace mlir { +class Operation; class RewriterBase; -namespace func { -class FuncOp; -} // namespace func namespace scf { class ForOp; } // namespace scf @@ -43,7 +41,7 @@ namespace linalg { /// /// WARNING: This hoisting does not model parallelism and is generally incorrect /// when used on distributed loops with memref semantics! -void hoistRedundantVectorTransfers(func::FuncOp func); +void hoistRedundantVectorTransfers(Operation *root); } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp index 80ce97ee3437a..34c9b2c282965 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -73,16 +73,16 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, return true; } -void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) { +void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) { bool changed = true; while (changed) { changed = false; // First move loop invariant ops outside of their loop. This needs to be // done before as we cannot move ops without interrupting the function walk. - func.walk( + root->walk( [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); - func.walk([&](vector::TransferReadOp transferRead) { + root->walk([&](vector::TransferReadOp transferRead) { if (!isa(transferRead.getShapedType())) return WalkResult::advance();