diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h index 26cc9b16cd9ca..d14d63e56dc76 100644 --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h @@ -20,6 +20,7 @@ namespace func { class FuncOp; } // namespace func namespace scf { +class ForallOp; class ForOp; class IfOp; } // namespace scf diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td index 6f48b005bfcbf..207a004c54ef5 100644 --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -40,6 +40,34 @@ def ApplySCFStructuralConversionPatternsOp : Op; +def ForallToForOp : Op]> { + let summary = "Converts scf.forall into a nest of scf.for operations"; + let description = [{ + Converts the `scf.forall` operation pointed to by the given handle into a + set of nested `scf.for` operations. Each new operation corresponds to one + induction variable of the original "multifor" loop. + + The operand handle must be associated with exactly one payload operation. + + Loops with shared outputs are currently not supported. + + #### Return Modes + + Consumes the operand handle. Produces a silenceable failure if the operand + is not associated with a single `scf.forall` payload operation. + Returns as many handles as the given `forall` op has induction variables + that are associated with the generated `scf.for` loops. + Produces a silenceable failure if another number of resulting handles is + requested. + }]; + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs Variadic:$transformed); + + let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)"; +} + def GetParentForOp : Op]> { diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 88ddd22eea46b..d7e8c38478ced 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -9,6 +9,8 @@ #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" @@ -17,8 +19,11 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dominance.h" +#include "mlir/IR/OpDefinition.h" using namespace mlir; using namespace mlir::affine; @@ -47,6 +52,7 @@ void transform::ApplySCFStructuralConversionPatternsOp:: //===----------------------------------------------------------------------===// // GetParentForOp //===----------------------------------------------------------------------===// + DiagnosedSilenceableFailure transform::GetParentForOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, @@ -76,6 +82,72 @@ transform::GetParentForOp::apply(transform::TransformRewriter &rewriter, return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// ForallToForOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::ForallToForOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto payload = state.getPayloadOps(getTarget()); + if (!llvm::hasSingleElement(payload)) + return emitSilenceableError() << "expected a single payload op"; + + auto target = dyn_cast(*payload.begin()); + if (!target) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() << "expected the payload to be scf.forall"; + diag.attachNote((*payload.begin())->getLoc()) << "payload op"; + return diag; + } + + rewriter.setInsertionPoint(target); + + if (!target.getOutputs().empty()) { + return emitSilenceableError() + << "unsupported shared outputs (didn't bufferize?)"; + } + + SmallVector lbs = target.getMixedLowerBound(); + SmallVector ubs = target.getMixedUpperBound(); + SmallVector steps = target.getMixedStep(); + + if (getNumResults() != lbs.size()) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() + << "op expects as many results (" << getNumResults() + << ") as payload has induction variables (" << lbs.size() << ")"; + diag.attachNote(target.getLoc()) << "payload op"; + return diag; + } + + auto loc = target.getLoc(); + SmallVector ivs; + for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) { + Value lbValue = getValueOrCreateConstantIndexOp(rewriter, loc, lb); + Value ubValue = getValueOrCreateConstantIndexOp(rewriter, loc, ub); + Value stepValue = getValueOrCreateConstantIndexOp(rewriter, loc, step); + auto loop = rewriter.create( + loc, lbValue, ubValue, stepValue, ValueRange(), + [](OpBuilder &, Location, Value, ValueRange) {}); + ivs.push_back(loop.getInductionVar()); + rewriter.setInsertionPointToStart(loop.getBody()); + rewriter.create(loc); + rewriter.setInsertionPointToStart(loop.getBody()); + } + rewriter.eraseOp(target.getBody()->getTerminator()); + rewriter.inlineBlockBefore(target.getBody(), &*rewriter.getInsertionPoint(), + ivs); + rewriter.eraseOp(target); + + for (auto &&[i, iv] : llvm::enumerate(ivs)) { + results.set(cast(getTransformed()[i]), + {iv.getParentBlock()->getParentOp()}); + } + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // LoopOutlineOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/transform-op-forall-to-for.mlir b/mlir/test/Dialect/SCF/transform-op-forall-to-for.mlir new file mode 100644 index 0000000000000..4b46c68d06d35 --- /dev/null +++ b/mlir/test/Dialect/SCF/transform-op-forall-to-for.mlir @@ -0,0 +1,73 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file --verify-diagnostics | FileCheck %s + +func.func private @callee(%i: index, %j: index) + +// CHECK-LABEL: @two_iters +// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index +func.func @two_iters(%ub1: index, %ub2: index) { + scf.forall (%i, %j) in (%ub1, %ub2) { + func.call @callee(%i, %j) : (index, index) -> () + } + // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]] + // CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]] + // CHECK: func.call @callee(%[[IV1]], %[[IV2]]) + return +} + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.loop.forall_to_for %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +} + +// ----- + +func.func private @callee(%i: index, %j: index) + +func.func @repeated(%ub1: index, %ub2: index) { + scf.forall (%i, %j) in (%ub1, %ub2) { + func.call @callee(%i, %j) : (index, index) -> () + } + scf.forall (%i, %j) in (%ub1, %ub2) { + func.call @callee(%i, %j) : (index, index) -> () + } + return +} + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-error @below {{expected a single payload op}} + transform.loop.forall_to_for %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) +} + +// ----- + +func.func private @callee(%i: index, %j: index) + +func.func @repeated(%ub1: index, %ub2: index) { + // expected-note @below {{payload op}} + scf.forall (%i, %j) in (%ub1, %ub2) { + func.call @callee(%i, %j) : (index, index) -> () + } + return +} + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-error @below {{op expects as many results (1) as payload has induction variables (2)}} + transform.loop.forall_to_for %0 : (!transform.any_op) -> !transform.any_op +} + +// ----- + +// expected-note @below {{payload op}} +func.func private @callee(%i: index, %j: index) + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-error @below {{expected the payload to be scf.forall}} + transform.loop.forall_to_for %0 : (!transform.any_op) -> !transform.any_op +}