Skip to content

[mlir] Extract forall_to_for logic into reusable function and add pass #89636

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {},
/// loop range.
std::unique_ptr<Pass> createForLoopRangeFoldingPass();

/// Creates a pass that converts SCF forall loops to SCF for loops.
std::unique_ptr<Pass> createForallToForLoopPass();

// Creates a pass which lowers for loops into while loops.
std::unique_ptr<Pass> createForToWhileLoopPass();

Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ def SCFForLoopRangeFolding : Pass<"scf-for-loop-range-folding"> {
let constructor = "mlir::createForLoopRangeFoldingPass()";
}

def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
let summary = "Convert SCF forall loops to SCF for loops";
let constructor = "mlir::createForallToForLoopPass()";
}

def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
let summary = "Convert SCF for loops to SCF while loops";
let constructor = "mlir::createForToWhileLoopPass()";
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,17 @@ class Value;
namespace scf {

class IfOp;
class ForallOp;
class ForOp;
class ParallelOp;
class WhileOp;

/// Try converting scf.forall into a set of nested scf.for loops.
/// The newly created scf.for ops will be returned through the `results`
/// vector if provided.
LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
SmallVectorImpl<Operation *> *results = nullptr);

/// Fuses all adjacent scf.parallel operations with identical bounds and step
/// into one scf.parallel operations. Uses a naive aliasing and dependency
/// analysis.
Expand Down
33 changes: 8 additions & 25 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,12 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
return diag;
}

rewriter.setInsertionPoint(target);

if (!target.getOutputs().empty()) {
return emitSilenceableError()
<< "unsupported shared outputs (didn't bufferize?)";
}

SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
SmallVector<OpFoldResult> ubs = target.getMixedUpperBound();
SmallVector<OpFoldResult> steps = target.getMixedStep();

if (getNumResults() != lbs.size()) {
DiagnosedSilenceableFailure diag =
Expand All @@ -89,28 +85,15 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
return diag;
}

auto loc = target.getLoc();
SmallVector<Value> ivs;
for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
Value lbValue = getValueOrCreateConstantIndexOp(rewriter, loc, lb);
Value ubValue = getValueOrCreateConstantIndexOp(rewriter, loc, ub);
Value stepValue = getValueOrCreateConstantIndexOp(rewriter, loc, step);
auto loop = rewriter.create<scf::ForOp>(
loc, lbValue, ubValue, stepValue, ValueRange(),
[](OpBuilder &, Location, Value, ValueRange) {});
ivs.push_back(loop.getInductionVar());
rewriter.setInsertionPointToStart(loop.getBody());
rewriter.create<scf::YieldOp>(loc);
rewriter.setInsertionPointToStart(loop.getBody());
SmallVector<Operation *> opResults;
if (failed(scf::forallToForLoop(rewriter, target, &opResults))) {
DiagnosedSilenceableFailure diag = emitSilenceableError()
<< "failed to convert forall into for";
return diag;
}
rewriter.eraseOp(target.getBody()->getTerminator());
rewriter.inlineBlockBefore(target.getBody(), &*rewriter.getInsertionPoint(),
ivs);
rewriter.eraseOp(target);

for (auto &&[i, iv] : llvm::enumerate(ivs)) {
results.set(cast<OpResult>(getTransformed()[i]),
{iv.getParentBlock()->getParentOp()});

for (auto &&[i, res] : llvm::enumerate(opResults)) {
results.set(cast<OpResult>(getTransformed()[i]), {res});
}
return DiagnosedSilenceableFailure::success();
}
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
BufferDeallocationOpInterfaceImpl.cpp
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
ForallToFor.cpp
ForToWhile.cpp
LoopCanonicalization.cpp
LoopPipelining.cpp
Expand Down
79 changes: 79 additions & 0 deletions mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
//===- ForallToFor.cpp - scf.forall to scf.for loop conversion ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms SCF.ForallOp's into SCF.ForOp's.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/Passes.h"

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
#define GEN_PASS_DEF_SCFFORALLTOFORLOOP
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir

using namespace llvm;
using namespace mlir;
using scf::ForallOp;
using scf::ForOp;
using scf::LoopNest;

LogicalResult
mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
SmallVectorImpl<Operation *> *results) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(forallOp);

Location loc = forallOp.getLoc();
SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
rewriter, loc, forallOp.getMixedLowerBound());
SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
rewriter, loc, forallOp.getMixedUpperBound());
SmallVector<Value> steps =
getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);

SmallVector<Value> ivs = llvm::map_to_vector(
loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); });

Block *innermostBlock = loopNest.loops.back().getBody();
rewriter.eraseOp(forallOp.getBody()->getTerminator());
rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
innermostBlock->getTerminator()->getIterator(),
ivs);
rewriter.eraseOp(forallOp);

if (results) {
llvm::move(loopNest.loops, std::back_inserter(*results));
}

return success();
}

namespace {
struct ForallToForLoop : public impl::SCFForallToForLoopBase<ForallToForLoop> {
void runOnOperation() override {
Operation *parentOp = getOperation();
IRRewriter rewriter(parentOp->getContext());

parentOp->walk([&](scf::ForallOp forallOp) {
if (failed(scf::forallToForLoop(rewriter, forallOp))) {
return signalPassFailure();
}
});
}
};
} // namespace

std::unique_ptr<Pass> mlir::createForallToForLoopPass() {
return std::make_unique<ForallToForLoop>();
}
57 changes: 57 additions & 0 deletions mlir/test/Dialect/SCF/forall-to-for.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-for))' -split-input-file | FileCheck %s

func.func private @callee(%i: index, %j: index)

// CHECK-LABEL: @two_iters
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
func.func @two_iters(%ub1: index, %ub2: index) {
scf.forall (%i, %j) in (%ub1, %ub2) {
func.call @callee(%i, %j) : (index, index) -> ()
}
// CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
// CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
// CHECK: func.call @callee(%[[IV1]], %[[IV2]])
return
}

// -----

func.func private @callee(%i: index, %j: index)

// CHECK-LABEL: @repeated
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
func.func @repeated(%ub1: index, %ub2: index) {
scf.forall (%i, %j) in (%ub1, %ub2) {
func.call @callee(%i, %j) : (index, index) -> ()
}
// CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
// CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
// CHECK: func.call @callee(%[[IV1]], %[[IV2]])
scf.forall (%i, %j) in (%ub1, %ub2) {
func.call @callee(%i, %j) : (index, index) -> ()
}
// CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
// CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
// CHECK: func.call @callee(%[[IV1]], %[[IV2]])
return
}

// -----

func.func private @callee(%i: index, %j: index, %k: index, %l: index)

// CHECK-LABEL: @nested
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index
func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
// CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
// CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
// CHECK: scf.for %[[IV3:.+]] = %{{.*}} to %[[UB3]]
// CHECK: scf.for %[[IV4:.+]] = %{{.*}} to %[[UB4]]
// CHECK: func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]])
scf.forall (%i, %j) in (%ub1, %ub2) {
scf.forall (%k, %l) in (%ub3, %ub4) {
func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
}
}
return
}
Loading