Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,23 @@ class TransformOptions {
return *this;
}

// Ensures that only a single top-level transform op is present in the IR.
TransformOptions &enableEnforceSingleToplevelTransformOp(bool enable = true) {
enforceSingleToplevelTransformOp = enable;
return *this;
}

/// Returns true if the expensive checks are requested.
bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }

// Returns true if enforcing a single top-level transform op is requested.
bool getEnforceSingleToplevelTransformOp() const {
return enforceSingleToplevelTransformOp;
}

private:
bool expensiveChecksEnabled = true;
bool enforceSingleToplevelTransformOp = true;
};

/// Entry point to the Transform dialect infrastructure. Applies the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,11 @@ constexpr static llvm::StringLiteral
/// Reports an error if there is more than one such operation and returns the
/// first one found. Reports an error returns nullptr if no such operation
/// found.
static Operation *findTopLevelTransform(Operation *root,
StringRef filenameOption) {
static Operation *
findTopLevelTransform(Operation *root, StringRef filenameOption,
mlir::transform::TransformOptions options) {
::mlir::transform::TransformOpInterface topLevelTransform = nullptr;
WalkResult walkResult = root->walk<WalkOrder::PreOrder>(
root->walk<WalkOrder::PreOrder>(
[&](::mlir::transform::TransformOpInterface transformOp) {
if (!transformOp
->hasTrait<transform::PossibleTopLevelTransformOpTrait>())
Expand All @@ -68,14 +69,15 @@ static Operation *findTopLevelTransform(Operation *root,
topLevelTransform = transformOp;
return WalkResult::skip();
}
auto diag = transformOp.emitError()
<< "more than one top-level transform op";
diag.attachNote(topLevelTransform.getLoc())
<< "previous top-level transform op";
return WalkResult::interrupt();
if (options.getEnforceSingleToplevelTransformOp()) {
auto diag = transformOp.emitError()
<< "more than one top-level transform op";
diag.attachNote(topLevelTransform.getLoc())
<< "previous top-level transform op";
return WalkResult::interrupt();
}
return WalkResult::skip();
});
if (walkResult.wasInterrupted())
return nullptr;
if (!topLevelTransform) {
auto diag = root->emitError()
<< "could not find a nested top-level transform op";
Expand Down Expand Up @@ -310,7 +312,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
Operation *transformRoot =
debugTransformRootTag.empty()
? findTopLevelTransform(transformContainer,
transformFileName.getArgStr())
transformFileName.getArgStr(), options)
: findOpWithTag(transformContainer, kTransformDialectTagAttrName,
debugTransformRootTag);
if (!transformRoot)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// RUN: mlir-opt %s --test-transform-dialect-interpreter='enforce-single-top-level-transform-op=0' -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s

transform.sequence failures(propagate) {
// CHECK: transform.sequence
^bb0(%arg0: !transform.any_op):
}

transform.sequence failures(propagate) {
// CHECK: transform.sequence
^bb0(%arg0: !transform.any_op):
}

// -----

transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
%match = transform.structured.match ops{["transform.get_parent_op"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.test_print_remark_at_operand %match, "found get_parent_op" : !transform.any_op
}

transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
%op = transform.structured.match ops{[]} in %arg0 : (!transform.any_op) -> !transform.any_op
// expected-remark @below{{found get_parent_op}}
%1 = transform.get_parent_op %op : (!transform.any_op) -> !transform.any_op
}
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ class TestTransformDialectInterpreterPass
}

options = options.enableExpensiveChecks(enableExpensiveChecks);
options = options.enableEnforceSingleToplevelTransformOp(
enforceSingleToplevelTransformOp);
if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
getOperation(), getArgument(), getSharedTransformModule(),
getTransformLibraryModule(), extraMapping, options,
Expand All @@ -170,6 +172,10 @@ class TestTransformDialectInterpreterPass
*this, "enable-expensive-checks", llvm::cl::init(false),
llvm::cl::desc("perform expensive checks to better report errors in the "
"transform IR")};
Option<bool> enforceSingleToplevelTransformOp{
*this, "enforce-single-top-level-transform-op", llvm::cl::init(true),
llvm::cl::desc("Ensure that only a single top-level transform op is "
"present in the IR.")};

Option<std::string> bindFirstExtraToOps{
*this, "bind-first-extra-to-ops",
Expand Down