Skip to content

Commit b467dae

Browse files
committed
[mlir] Extract forall_to_for logic into reusable function and add pass
1 parent 330d898 commit b467dae

File tree

7 files changed

+160
-25
lines changed

7 files changed

+160
-25
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/Passes.h

+3
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {},
5959
/// loop range.
6060
std::unique_ptr<Pass> createForLoopRangeFoldingPass();
6161

62+
/// Creates a pass that converts SCF forall loops to SCF for loops.
63+
std::unique_ptr<Pass> createForallToForLoopPass();
64+
6265
// Creates a pass which lowers for loops into while loops.
6366
std::unique_ptr<Pass> createForToWhileLoopPass();
6467

mlir/include/mlir/Dialect/SCF/Transforms/Passes.td

+5
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ def SCFForLoopRangeFolding : Pass<"scf-for-loop-range-folding"> {
120120
let constructor = "mlir::createForLoopRangeFoldingPass()";
121121
}
122122

123+
def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
124+
let summary = "Convert SCF forall loops to SCF for loops";
125+
let constructor = "mlir::createForallToForLoopPass()";
126+
}
127+
123128
def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
124129
let summary = "Convert SCF for loops to SCF while loops";
125130
let constructor = "mlir::createForToWhileLoopPass()";

mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h

+7
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,17 @@ class Value;
2828
namespace scf {
2929

3030
class IfOp;
31+
class ForallOp;
3132
class ForOp;
3233
class ParallelOp;
3334
class WhileOp;
3435

36+
/// Try converting scf.forall into a set of nested scf.for loops.
37+
/// The newly created scf.for ops will be returned through the `results`
38+
/// vector if provided.
39+
LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
40+
SmallVectorImpl<Operation *> *results = nullptr);
41+
3542
/// Fuses all adjacent scf.parallel operations with identical bounds and step
3643
/// into one scf.parallel operations. Uses a naive aliasing and dependency
3744
/// analysis.

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

+8-25
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,12 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
6969
return diag;
7070
}
7171

72-
rewriter.setInsertionPoint(target);
73-
7472
if (!target.getOutputs().empty()) {
7573
return emitSilenceableError()
7674
<< "unsupported shared outputs (didn't bufferize?)";
7775
}
7876

7977
SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
80-
SmallVector<OpFoldResult> ubs = target.getMixedUpperBound();
81-
SmallVector<OpFoldResult> steps = target.getMixedStep();
8278

8379
if (getNumResults() != lbs.size()) {
8480
DiagnosedSilenceableFailure diag =
@@ -89,28 +85,15 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
8985
return diag;
9086
}
9187

92-
auto loc = target.getLoc();
93-
SmallVector<Value> ivs;
94-
for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
95-
Value lbValue = getValueOrCreateConstantIndexOp(rewriter, loc, lb);
96-
Value ubValue = getValueOrCreateConstantIndexOp(rewriter, loc, ub);
97-
Value stepValue = getValueOrCreateConstantIndexOp(rewriter, loc, step);
98-
auto loop = rewriter.create<scf::ForOp>(
99-
loc, lbValue, ubValue, stepValue, ValueRange(),
100-
[](OpBuilder &, Location, Value, ValueRange) {});
101-
ivs.push_back(loop.getInductionVar());
102-
rewriter.setInsertionPointToStart(loop.getBody());
103-
rewriter.create<scf::YieldOp>(loc);
104-
rewriter.setInsertionPointToStart(loop.getBody());
88+
SmallVector<Operation *> opResults;
89+
if (failed(scf::forallToForLoop(rewriter, target, &opResults))) {
90+
DiagnosedSilenceableFailure diag = emitSilenceableError()
91+
<< "failed to convert forall into for";
92+
return diag;
10593
}
106-
rewriter.eraseOp(target.getBody()->getTerminator());
107-
rewriter.inlineBlockBefore(target.getBody(), &*rewriter.getInsertionPoint(),
108-
ivs);
109-
rewriter.eraseOp(target);
110-
111-
for (auto &&[i, iv] : llvm::enumerate(ivs)) {
112-
results.set(cast<OpResult>(getTransformed()[i]),
113-
{iv.getParentBlock()->getParentOp()});
94+
95+
for (auto &&[i, res] : llvm::enumerate(opResults)) {
96+
results.set(cast<OpResult>(getTransformed()[i]), {res});
11497
}
11598
return DiagnosedSilenceableFailure::success();
11699
}

mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
22
BufferDeallocationOpInterfaceImpl.cpp
33
BufferizableOpInterfaceImpl.cpp
44
Bufferize.cpp
5+
ForallToFor.cpp
56
ForToWhile.cpp
67
LoopCanonicalization.cpp
78
LoopPipelining.cpp
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
//===- ForallToFor.cpp - scf.forall to scf.for loop conversion ------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Transforms SCF.ForallOp's into SCF.ForOp's.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/SCF/Transforms/Passes.h"
14+
15+
#include "mlir/Dialect/SCF/IR/SCF.h"
16+
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
19+
namespace mlir {
20+
#define GEN_PASS_DEF_SCFFORALLTOFORLOOP
21+
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
22+
} // namespace mlir
23+
24+
using namespace llvm;
25+
using namespace mlir;
26+
using scf::ForallOp;
27+
using scf::ForOp;
28+
using scf::LoopNest;
29+
30+
LogicalResult
31+
mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
32+
SmallVectorImpl<Operation *> *results) {
33+
OpBuilder::InsertionGuard guard(rewriter);
34+
rewriter.setInsertionPoint(forallOp);
35+
36+
Location loc = forallOp.getLoc();
37+
SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
38+
rewriter, loc, forallOp.getMixedLowerBound());
39+
SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
40+
rewriter, loc, forallOp.getMixedUpperBound());
41+
SmallVector<Value> steps =
42+
getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
43+
LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);
44+
45+
SmallVector<Value> ivs = llvm::map_to_vector(
46+
loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); });
47+
48+
Block *innermostBlock = loopNest.loops.back().getBody();
49+
rewriter.eraseOp(forallOp.getBody()->getTerminator());
50+
rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
51+
innermostBlock->getTerminator()->getIterator(),
52+
ivs);
53+
rewriter.eraseOp(forallOp);
54+
55+
if (results) {
56+
llvm::move(loopNest.loops, std::back_inserter(*results));
57+
}
58+
59+
return success();
60+
}
61+
62+
namespace {
63+
struct ForallToForLoop : public impl::SCFForallToForLoopBase<ForallToForLoop> {
64+
void runOnOperation() override {
65+
Operation *parentOp = getOperation();
66+
IRRewriter rewriter(parentOp->getContext());
67+
68+
parentOp->walk([&](scf::ForallOp forallOp) {
69+
if (failed(scf::forallToForLoop(rewriter, forallOp))) {
70+
return signalPassFailure();
71+
}
72+
});
73+
}
74+
};
75+
} // namespace
76+
77+
std::unique_ptr<Pass> mlir::createForallToForLoopPass() {
78+
return std::make_unique<ForallToForLoop>();
79+
}
+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-for))' -split-input-file | FileCheck %s
2+
3+
func.func private @callee(%i: index, %j: index)
4+
5+
// CHECK-LABEL: @two_iters
6+
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
7+
func.func @two_iters(%ub1: index, %ub2: index) {
8+
scf.forall (%i, %j) in (%ub1, %ub2) {
9+
func.call @callee(%i, %j) : (index, index) -> ()
10+
}
11+
// CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
12+
// CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
13+
// CHECK: func.call @callee(%[[IV1]], %[[IV2]])
14+
return
15+
}
16+
17+
// -----
18+
19+
func.func private @callee(%i: index, %j: index)
20+
21+
// CHECK-LABEL: @repeated
22+
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
23+
func.func @repeated(%ub1: index, %ub2: index) {
24+
scf.forall (%i, %j) in (%ub1, %ub2) {
25+
func.call @callee(%i, %j) : (index, index) -> ()
26+
}
27+
// CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
28+
// CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
29+
// CHECK: func.call @callee(%[[IV1]], %[[IV2]])
30+
scf.forall (%i, %j) in (%ub1, %ub2) {
31+
func.call @callee(%i, %j) : (index, index) -> ()
32+
}
33+
// CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
34+
// CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
35+
// CHECK: func.call @callee(%[[IV1]], %[[IV2]])
36+
return
37+
}
38+
39+
// -----
40+
41+
func.func private @callee(%i: index, %j: index, %k: index, %l: index)
42+
43+
// CHECK-LABEL: @nested
44+
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index
45+
func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
46+
// CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
47+
// CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
48+
// CHECK: scf.for %[[IV3:.+]] = %{{.*}} to %[[UB3]]
49+
// CHECK: scf.for %[[IV4:.+]] = %{{.*}} to %[[UB4]]
50+
// CHECK: func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]])
51+
scf.forall (%i, %j) in (%ub1, %ub2) {
52+
scf.forall (%k, %l) in (%ub3, %ub4) {
53+
func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
54+
}
55+
}
56+
return
57+
}

0 commit comments

Comments
 (0)