diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h index 91903e254b0d5..9c67f3af61cc1 100644 --- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h +++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h @@ -64,7 +64,7 @@ LogicalResult interpreterBaseRunOnOperationImpl( /// will be interpreted. /// - transformLibraryFileName: if non-empty, the name of the file containing /// definitions of external symbols referenced in the transform script. -/// These definitions will be used to replace declarations. +/// These definitions will be used to resolve declarations. /// - debugPayloadRootTag: if non-empty, the value of the attribute named /// `kTransformDialectTagAttrName` indicating the single op that is /// considered the payload root of the transform interpreter; otherwise, the @@ -85,7 +85,7 @@ LogicalResult interpreterBaseRunOnOperationImpl( /// as template arguments. They are *not* expected to to implement `initialize` /// or `runOnOperation`. They *are* expected to call the copy constructor of /// this class in their copy constructors, short of which the file-based -/// transform dialect script injection facility will become nonoperational. +/// transform dialect script resolution facility will become non-operational. /// /// Concrete passes may implement the `runBeforeInterpreter` and /// `runAfterInterpreter` to customize the behavior of the pass. diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index 23640c92457a8..db856e9c973db 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -413,8 +413,24 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( // transform is embedded in the payload IR. If debugTransformRootTag was // passed, then we are in user-specified selection of the transforming IR. // This corresponds to REPL debug mode. - Operation *transformContainer = - hasSharedTransformModule ? sharedTransformModule->get() : target; + + OwningOpRef transformContainerClone; + Operation *transformContainer; + if (hasTransformLibraryModule) { + // If we have a library module, then the transform script is embedded in the + // target, which we don't want to modify when loading the library. We thus + // clone the target and use that as transform container. + assert(!hasSharedTransformModule); + transformContainerClone = target->clone(); + transformContainer = transformContainerClone.get(); + } else { + // If we have a shared library, which is private to us, we can modify it + // when loading the library, so we use that. Otherwise, we don't have any + // library to load, so we can use the target and won't modify it. + transformContainer = + hasSharedTransformModule ? sharedTransformModule->get() : target; + } + Operation *transformRoot = debugTransformRootTag.empty() ? findTopLevelTransform(transformContainer, @@ -436,7 +452,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( // concurrent execution (normally, the error shouldn't be triggered unless the // transform IR modifies itself in a pass, which is also forbidden elsewhere). if (hasTransformLibraryModule) { - if (!target->isProperAncestor(transformRoot)) { + if (!transformContainer->isProperAncestor(transformRoot)) { InFlightDiagnostic diag = transformRoot->emitError() << "cannot inject transform definitions next to pass anchor op"; diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir index 04b6c5a02e0ad..076a217109480 100644 --- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir @@ -1,27 +1,25 @@ // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \ // RUN: --verify-diagnostics --split-input-file | FileCheck %s -// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \ -// RUN: --verify-diagnostics --split-input-file | FileCheck %s - // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \ // RUN: --verify-diagnostics --split-input-file | FileCheck %s // The definition of the @foo named sequence is provided in another file. It -// will be included because of the pass option. Repeated application of the -// same pass, with or without the library option, should not be a problem. +// will be available because of the pass option but not included in the output. +// Repeated application of the same pass works, but only if the library is +// provided in both. // Note that the same diagnostic produced twice at the same location only // needs to be matched once. // expected-remark @below {{message}} // expected-remark @below {{unannotated}} module attributes {transform.with_named_sequence} { - // CHECK: transform.named_sequence @foo - // CHECK: test_print_remark_at_operand %{{.*}}, "message" + // CHECK: transform.named_sequence private @foo + // CHECK-NOT: test_print_remark_at_operand transform.named_sequence private @foo(!transform.any_op {transform.readonly}) - // CHECK: transform.named_sequence @unannotated - // CHECK: test_print_remark_at_operand %{{.*}}, "unannotated" + // CHECK: transform.named_sequence private @unannotated + // CHECK-NOT: test_print_remark_at_operand transform.named_sequence private @unannotated(!transform.any_op {transform.readonly}) transform.sequence failures(propagate) { diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp index f73deef9d5fd4..675b5ecd50346 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp @@ -219,8 +219,8 @@ class TestTransformDialectInterpreterPass Option transformLibraryFileName{ *this, "transform-library-file-name", llvm::cl::init(""), llvm::cl::desc( - "Optional name of the file containing transform dialect symbol " - "definitions to be injected into the transform module.")}; + "Optional name of the file providing transform dialect definitions " + "from which declarations in the transform module can be resolved.")}; Option testModuleGeneration{ *this, "test-module-generation", llvm::cl::init(false),