Skip to content

Commit 2040139

Browse files
Do not modify the target when loading transform library module.
Until now, if the transform script was embedded into the input IR, the transform dialect interpreter injected the externally resolved symbols into that IR, which then became part of the output. This is not always desirable. This commit is a first step to separate the logic of loading/resolution/ injection from the interpreter. The modification consists of cloning the IR that contains the main transform script if necessary (i.e., if we actually need to load it and it is part of the input op of the pass). The next step will be to introduce a dedicated pass for loading and injecting transform script and or library.
1 parent 8f2800b commit 2040139

File tree

4 files changed

+30
-16
lines changed

4 files changed

+30
-16
lines changed

mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ LogicalResult interpreterBaseRunOnOperationImpl(
6464
/// will be interpreted.
6565
/// - transformLibraryFileName: if non-empty, the name of the file containing
6666
/// definitions of external symbols referenced in the transform script.
67-
/// These definitions will be used to replace declarations.
67+
/// These definitions will be used to resolve declarations.
6868
/// - debugPayloadRootTag: if non-empty, the value of the attribute named
6969
/// `kTransformDialectTagAttrName` indicating the single op that is
7070
/// considered the payload root of the transform interpreter; otherwise, the
@@ -85,7 +85,7 @@ LogicalResult interpreterBaseRunOnOperationImpl(
8585
/// as template arguments. They are *not* expected to to implement `initialize`
8686
/// or `runOnOperation`. They *are* expected to call the copy constructor of
8787
/// this class in their copy constructors, short of which the file-based
88-
/// transform dialect script injection facility will become nonoperational.
88+
/// transform dialect script resolution facility will become non-operational.
8989
///
9090
/// Concrete passes may implement the `runBeforeInterpreter` and
9191
/// `runAfterInterpreter` to customize the behavior of the pass.

mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -413,8 +413,24 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
413413
// transform is embedded in the payload IR. If debugTransformRootTag was
414414
// passed, then we are in user-specified selection of the transforming IR.
415415
// This corresponds to REPL debug mode.
416-
Operation *transformContainer =
417-
hasSharedTransformModule ? sharedTransformModule->get() : target;
416+
417+
OwningOpRef<Operation *> transformContainerClone;
418+
Operation *transformContainer;
419+
if (hasTransformLibraryModule) {
420+
// If we have a library module, then the transform script is embedded in the
421+
// target, which we don't want to modify when loading the library. We thus
422+
// clone the target and use that as transform container.
423+
assert(!hasSharedTransformModule);
424+
transformContainerClone = target->clone();
425+
transformContainer = transformContainerClone.get();
426+
} else {
427+
// If we have a shared library, which is private to us, we can modify it
428+
// when loading the library, so we use that. Otherwise, we don't have any
429+
// library to load, so we can use the target and won't modify it.
430+
transformContainer =
431+
hasSharedTransformModule ? sharedTransformModule->get() : target;
432+
}
433+
418434
Operation *transformRoot =
419435
debugTransformRootTag.empty()
420436
? findTopLevelTransform(transformContainer,
@@ -436,7 +452,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
436452
// concurrent execution (normally, the error shouldn't be triggered unless the
437453
// transform IR modifies itself in a pass, which is also forbidden elsewhere).
438454
if (hasTransformLibraryModule) {
439-
if (!target->isProperAncestor(transformRoot)) {
455+
if (!transformContainer->isProperAncestor(transformRoot)) {
440456
InFlightDiagnostic diag =
441457
transformRoot->emitError()
442458
<< "cannot inject transform definitions next to pass anchor op";

mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,25 @@
11
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
22
// RUN: --verify-diagnostics --split-input-file | FileCheck %s
33

4-
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
5-
// RUN: --verify-diagnostics --split-input-file | FileCheck %s
6-
74
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
85
// RUN: --verify-diagnostics --split-input-file | FileCheck %s
96

107
// The definition of the @foo named sequence is provided in another file. It
11-
// will be included because of the pass option. Repeated application of the
12-
// same pass, with or without the library option, should not be a problem.
8+
// will be available because of the pass option but not included in the output.
9+
// Repeated application of the same pass works, but only if the library is
10+
// provided in both.
1311
// Note that the same diagnostic produced twice at the same location only
1412
// needs to be matched once.
1513

1614
// expected-remark @below {{message}}
1715
// expected-remark @below {{unannotated}}
1816
module attributes {transform.with_named_sequence} {
19-
// CHECK: transform.named_sequence @foo
20-
// CHECK: test_print_remark_at_operand %{{.*}}, "message"
17+
// CHECK: transform.named_sequence private @foo
18+
// CHECK-NOT: test_print_remark_at_operand
2119
transform.named_sequence private @foo(!transform.any_op {transform.readonly})
2220

23-
// CHECK: transform.named_sequence @unannotated
24-
// CHECK: test_print_remark_at_operand %{{.*}}, "unannotated"
21+
// CHECK: transform.named_sequence private @unannotated
22+
// CHECK-NOT: test_print_remark_at_operand
2523
transform.named_sequence private @unannotated(!transform.any_op {transform.readonly})
2624

2725
transform.sequence failures(propagate) {

mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,8 @@ class TestTransformDialectInterpreterPass
219219
Option<std::string> transformLibraryFileName{
220220
*this, "transform-library-file-name", llvm::cl::init(""),
221221
llvm::cl::desc(
222-
"Optional name of the file containing transform dialect symbol "
223-
"definitions to be injected into the transform module.")};
222+
"Optional name of the file providing transform dialect definitions "
223+
"from which declarations in the transform module can be resolved.")};
224224

225225
Option<bool> testModuleGeneration{
226226
*this, "test-module-generation", llvm::cl::init(false),

0 commit comments

Comments
 (0)