diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index b14c89eadb097..2fd0e80db96fe 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -481,8 +481,16 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [ This operation consumes the operand and produces a new handle associated with the same payload. This is necessary to trigger invalidation of handles to any of the payload operations nested in the payload operations associated - with the operand, as those are likely to be modified by actions. Note that - the root payload operation associated with the operand are not matched. + with the operand, as those are likely to be modified by actions. + + By default, the root payload operation associated with the operand is not + matched. This is to support the conservative case where applied actions may + invalidate the root payload operation. If the optional `restrict_root` + attribute is set, the root operand is guaranteed to not be invalidated by any + of the applied actions. In such cases, the root payload operation is also + matched. This is useful because matching the root payload operation is a + common idiom, when e.g. matching a func.func directly and operations nested + under it. The operation succeeds if none of the matchers produced a definite failure during application and if all of the applied actions produced success. Note @@ -495,13 +503,19 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [ }]; let arguments = (ins TransformHandleTypeInterface:$root, + UnitAttr:$restrict_root, SymbolRefArrayAttr:$matchers, SymbolRefArrayAttr:$actions); let results = (outs TransformHandleTypeInterface:$updated); - let assemblyFormat = - "`in` $root custom($matchers, $actions) " - "attr-dict `:` functional-type($root, $updated)"; + let assemblyFormat = [{ + (`restrict_root` $restrict_root^)? + `in` + $root + custom($matchers, $actions) + attr-dict + `:` functional-type($root, $updated) + }]; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 8db77b6059dd2..514a75b5d5904 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -850,8 +850,9 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, for (Operation *root : state.getPayloadOps(getRoot())) { WalkResult walkResult = root->walk([&](Operation *op) { - // Skip over the root op itself so we don't invalidate it. - if (op == root) + // If getRestrictRoot is not present, skip over the root op itself so we + // don't invalidate it. + if (!getRestrictRoot() && op == root) return WalkResult::advance(); DEBUG_MATCHER({ @@ -1556,10 +1557,10 @@ DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation( ::std::optional<::mlir::Operation *> maybeCurrent, transform::TransformResults &results, transform::TransformState &state) { if (!maybeCurrent.has_value()) { - DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; + DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; }); return DiagnosedSilenceableFailure::success(); } - DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; + DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; }); return emitSilenceableError() << "operation is not empty"; } @@ -1961,7 +1962,8 @@ void transform::NamedSequenceOp::build(OpBuilder &builder, state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(symName)); state.addAttribute(getFunctionTypeAttrName(state.name), - TypeAttr::get(FunctionType::get(builder.getContext(), rootType, resultTypes))); + TypeAttr::get(FunctionType::get(builder.getContext(), + rootType, resultTypes))); state.attributes.append(attrs.begin(), attrs.end()); state.addRegion(); diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir index 9489aadac843d..c88945c8a5c60 100644 --- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir +++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir @@ -100,12 +100,13 @@ module attributes { transform.with_named_sequence } { } transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) { - transform.foreach_match in %arg0 + transform.foreach_match restrict_root in %arg0 @match_structured_suppress -> @do_nothing : (!transform.any_op) -> !transform.any_op transform.yield } + // expected-remark @below {{other}} func.func @payload() attributes { transform.target_tag = "start_here" } { // expected-remark @below {{other}} %D = arith.constant dense<1.0> : tensor<2x4xf32>