Skip to content

[mlir][transf] Traits and interf. of alternatives, foreach, and yield. #112169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@ include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"

def AlternativesOp : TransformDialectOp<"alternatives",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands", "getSuccessorRegions",
"getRegionInvocationBounds"]>,
DeclareOpInterfaceMethods<TransformOpInterface>,
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
IsolatedFromAbove, PossibleTopLevelTransformOpTrait,
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
Expand All @@ -37,8 +34,8 @@ def AlternativesOp : TransformDialectOp<"alternatives",
sequence of transform operations to be applied to the same payload IR. The
regions are visited in order of appearance, and transforms in them are
applied in their respective order of appearance. If one of these transforms
fails to apply, the remaining ops in the same region are skipped an the next
region is attempted. If all transformations in a region succeed, the
fails to apply, the remaining ops in the same region are skipped and the
next region is attempted. If all transformations in a region succeed, the
remaining regions are skipped and the entire "alternatives" transformation
succeeds. If all regions contained a failing transformation, the entire
"alternatives" transformation fails.
Expand Down Expand Up @@ -90,6 +87,13 @@ def AlternativesOp : TransformDialectOp<"alternatives",
transform.yield %arg0 : !transform.any_op
}
```

Note that this operation does not implement the `RegionBranchOpInterface`.
That interface verifies that the operands and results passed across the
control flow edges are equal (or compatible). In particular, it expects the
result passed from a region to its successor to be the argument of that
region; however, the argument of all `alternatives` regions are always
provided by the parent op and never by the precedessor region.
}];

let arguments = (ins Optional<TransformHandleTypeInterface>:$scope);
Expand Down Expand Up @@ -610,8 +614,6 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
def ForeachOp : TransformDialectOp<"foreach",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<RegionBranchOpInterface, [
"getSuccessorRegions", "getEntrySuccessorOperands"]>,
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
]> {
let summary = "Executes the body for each element of the payload";
Expand Down Expand Up @@ -646,6 +648,14 @@ def ForeachOp : TransformDialectOp<"foreach",
sequence fails immediately with the same failure, leaving the payload IR in
a potentially invalid state, i.e., this operation offers no transformation
rollback capabilities.

Note that this operation does not implement the `RegionBranchOpInterface`.
That interface verifies that the operands and results passed across the
control flow edges are equal (or compatible). In particular, it expects the
result passed from a region to the parent to *be* the result of that op;
however, the result of the `body` region only *contributes* to the result
in that the result of the op is an aggregation of of the results of all
iterations of the body.
}];

let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets,
Expand Down Expand Up @@ -1358,7 +1368,8 @@ def VerifyOp : TransformDialectOp<"verify",
}

def YieldOp : TransformDialectOp<"yield",
[Terminator, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
[Terminator, ReturnLike,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I missed this and noticed on the other commit. I don't think yield is return-like. Is this necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change was proposed after #110322 broke things. But it got since then reversed, so it currently isn't "needed." The more I think about it, the more I doubt that #110322 is actually a good idea: it assumes that FunctionOpInterface has a ReturnLike terminator, which it shouldn't.

DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Yields operation handles from a transform IR region";
let description = [{
This terminator operation yields operation handles from regions of the
Expand Down
66 changes: 2 additions & 64 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,38 +94,6 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
// AlternativesOp
//===----------------------------------------------------------------------===//

OperandRange
transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
if (!point.isParent() && getOperation()->getNumOperands() == 1)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
getOperation()->operand_end());
}

void transform::AlternativesOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
for (Region &alternative : llvm::drop_begin(
getAlternatives(),
point.isParent() ? 0
: point.getRegionOrNull()->getRegionNumber() + 1)) {
regions.emplace_back(&alternative, !getOperands().empty()
? alternative.getArguments()
: Block::BlockArgListType());
}
if (!point.isParent())
regions.emplace_back(getOperation()->getResults());
}

void transform::AlternativesOp::getRegionInvocationBounds(
ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
(void)operands;
// The region corresponding to the first alternative is always executed, the
// remaining may or may not be executed.
bounds.reserve(getNumRegions());
bounds.emplace_back(1, 1);
bounds.resize(getNumRegions(), InvocationBounds(0, 1));
}

static void forwardEmptyOperands(Block *block, transform::TransformState &state,
transform::TransformResults &results) {
for (const auto &res : block->getParentOp()->getOpResults())
Expand Down Expand Up @@ -1500,28 +1468,6 @@ void transform::ForeachOp::getEffects(
producesHandle(getOperation()->getOpResults(), effects);
}

void transform::ForeachOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
Region *bodyRegion = &getBody();
if (point.isParent()) {
regions.emplace_back(bodyRegion, bodyRegion->getArguments());
return;
}

// Branch back to the region or the parent.
assert(point == getBody() && "unexpected region index");
regions.emplace_back(bodyRegion, bodyRegion->getArguments());
regions.emplace_back();
}

OperandRange
transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
// Each block argument handle is mapped to a subset (one op to be precise)
// of the payload of the corresponding `targets` operand of ForeachOp.
assert(point == getBody() && "unexpected region index");
return getOperation()->getOperands();
}

transform::YieldOp transform::ForeachOp::getYieldOp() {
return cast<transform::YieldOp>(getBody().front().getTerminator());
}
Expand Down Expand Up @@ -2702,16 +2648,8 @@ transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {

void transform::SequenceOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
if (point.isParent()) {
Region *bodyRegion = &getBody();
regions.emplace_back(bodyRegion, getNumOperands() != 0
? bodyRegion->getArguments()
: Block::BlockArgListType());
return;
}

assert(point == getBody() && "unexpected region index");
regions.emplace_back(getOperation()->getResults());
if (point.getRegionOrNull() == &getBody())
regions.emplace_back(getResults());
}

void transform::SequenceOp::getRegionInvocationBounds(
Expand Down
Loading