diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index b946fc8875860..195d794e5a835 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -24,10 +24,7 @@ include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" def AlternativesOp : TransformDialectOp<"alternatives", - [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, + [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, IsolatedFromAbove, PossibleTopLevelTransformOpTrait, SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> { @@ -37,8 +34,8 @@ def AlternativesOp : TransformDialectOp<"alternatives", sequence of transform operations to be applied to the same payload IR. The regions are visited in order of appearance, and transforms in them are applied in their respective order of appearance. If one of these transforms - fails to apply, the remaining ops in the same region are skipped an the next - region is attempted. If all transformations in a region succeed, the + fails to apply, the remaining ops in the same region are skipped and the + next region is attempted. If all transformations in a region succeed, the remaining regions are skipped and the entire "alternatives" transformation succeeds. If all regions contained a failing transformation, the entire "alternatives" transformation fails. @@ -90,6 +87,13 @@ def AlternativesOp : TransformDialectOp<"alternatives", transform.yield %arg0 : !transform.any_op } ``` + + Note that this operation does not implement the `RegionBranchOpInterface`. + That interface verifies that the operands and results passed across the + control flow edges are equal (or compatible). In particular, it expects the + result passed from a region to its successor to be the argument of that + region; however, the argument of all `alternatives` regions are always + provided by the parent op and never by the precedessor region. }]; let arguments = (ins Optional:$scope); @@ -610,8 +614,6 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [ def ForeachOp : TransformDialectOp<"foreach", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"::mlir::transform::YieldOp"> ]> { let summary = "Executes the body for each element of the payload"; @@ -646,6 +648,14 @@ def ForeachOp : TransformDialectOp<"foreach", sequence fails immediately with the same failure, leaving the payload IR in a potentially invalid state, i.e., this operation offers no transformation rollback capabilities. + + Note that this operation does not implement the `RegionBranchOpInterface`. + That interface verifies that the operands and results passed across the + control flow edges are equal (or compatible). In particular, it expects the + result passed from a region to the parent to *be* the result of that op; + however, the result of the `body` region only *contributes* to the result + in that the result of the op is an aggregation of of the results of all + iterations of the body. }]; let arguments = (ins Variadic:$targets, @@ -1358,7 +1368,8 @@ def VerifyOp : TransformDialectOp<"verify", } def YieldOp : TransformDialectOp<"yield", - [Terminator, DeclareOpInterfaceMethods]> { + [Terminator, ReturnLike, + DeclareOpInterfaceMethods]> { let summary = "Yields operation handles from a transform IR region"; let description = [{ This terminator operation yields operation handles from regions of the diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 590cae9aa0d66..02b5c312d69d9 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -94,38 +94,6 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, // AlternativesOp //===----------------------------------------------------------------------===// -OperandRange -transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) { - if (!point.isParent() && getOperation()->getNumOperands() == 1) - return getOperation()->getOperands(); - return OperandRange(getOperation()->operand_end(), - getOperation()->operand_end()); -} - -void transform::AlternativesOp::getSuccessorRegions( - RegionBranchPoint point, SmallVectorImpl ®ions) { - for (Region &alternative : llvm::drop_begin( - getAlternatives(), - point.isParent() ? 0 - : point.getRegionOrNull()->getRegionNumber() + 1)) { - regions.emplace_back(&alternative, !getOperands().empty() - ? alternative.getArguments() - : Block::BlockArgListType()); - } - if (!point.isParent()) - regions.emplace_back(getOperation()->getResults()); -} - -void transform::AlternativesOp::getRegionInvocationBounds( - ArrayRef operands, SmallVectorImpl &bounds) { - (void)operands; - // The region corresponding to the first alternative is always executed, the - // remaining may or may not be executed. - bounds.reserve(getNumRegions()); - bounds.emplace_back(1, 1); - bounds.resize(getNumRegions(), InvocationBounds(0, 1)); -} - static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results) { for (const auto &res : block->getParentOp()->getOpResults()) @@ -1500,28 +1468,6 @@ void transform::ForeachOp::getEffects( producesHandle(getOperation()->getOpResults(), effects); } -void transform::ForeachOp::getSuccessorRegions( - RegionBranchPoint point, SmallVectorImpl ®ions) { - Region *bodyRegion = &getBody(); - if (point.isParent()) { - regions.emplace_back(bodyRegion, bodyRegion->getArguments()); - return; - } - - // Branch back to the region or the parent. - assert(point == getBody() && "unexpected region index"); - regions.emplace_back(bodyRegion, bodyRegion->getArguments()); - regions.emplace_back(); -} - -OperandRange -transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) { - // Each block argument handle is mapped to a subset (one op to be precise) - // of the payload of the corresponding `targets` operand of ForeachOp. - assert(point == getBody() && "unexpected region index"); - return getOperation()->getOperands(); -} - transform::YieldOp transform::ForeachOp::getYieldOp() { return cast(getBody().front().getTerminator()); } @@ -2702,16 +2648,8 @@ transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) { void transform::SequenceOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { - if (point.isParent()) { - Region *bodyRegion = &getBody(); - regions.emplace_back(bodyRegion, getNumOperands() != 0 - ? bodyRegion->getArguments() - : Block::BlockArgListType()); - return; - } - - assert(point == getBody() && "unexpected region index"); - regions.emplace_back(getOperation()->getResults()); + if (point.getRegionOrNull() == &getBody()) + regions.emplace_back(getResults()); } void transform::SequenceOp::getRegionInvocationBounds(