-
Notifications
You must be signed in to change notification settings - Fork 13.7k
[mlir][Transform] Provide a minimal set of utils that allow implementing a simple transform dialect interpreter pass #68330
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
nicolasvasilache
merged 1 commit into
llvm:main
from
nicolasvasilache:flush-simplify-transforms-2
Oct 6, 2023
+577
−289
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
89 changes: 89 additions & 0 deletions
89
mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
//===- TransformInterpreterUtils.h - Transform Utils ------------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H | ||
#define MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H | ||
|
||
#include "mlir/Dialect/Transform/IR/TransformDialect.h" | ||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Support/LLVM.h" | ||
#include <memory> | ||
|
||
namespace mlir { | ||
struct LogicalResult; | ||
class MLIRContext; | ||
class ModuleOp; | ||
class Operation; | ||
template <typename> | ||
class OwningOpRef; | ||
class Region; | ||
|
||
namespace transform { | ||
namespace detail { | ||
/// Utility to parse and verify the content of a `transformFileName` MLIR file | ||
/// containing a transform dialect specification. | ||
LogicalResult | ||
parseTransformModuleFromFile(MLIRContext *context, | ||
llvm::StringRef transformFileName, | ||
OwningOpRef<ModuleOp> &transformModule); | ||
|
||
/// Utility to load a transform interpreter `module` from a module that has | ||
/// already been preloaded in the context. | ||
/// This mode is useful in cases where explicit parsing of a transform library | ||
/// from file is expected to be prohibitively expensive. | ||
/// In such cases, the transform module is expected to be found in the preloaded | ||
/// library modules of the transform dialect. | ||
/// Returns null if the module is not found. | ||
ModuleOp getPreloadedTransformModule(MLIRContext *context); | ||
|
||
/// Finds the first TransformOpInterface named `kTransformEntryPointSymbolName` | ||
/// that is either: | ||
/// 1. nested under `root` (takes precedence). | ||
/// 2. nested under `module`, if not found in `root`. | ||
/// Reports errors and returns null if no such operation found. | ||
TransformOpInterface findTransformEntryPoint( | ||
Operation *root, ModuleOp module, | ||
StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName); | ||
|
||
/// Merge all symbols from `other` into `target`. Both ops need to implement the | ||
/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be | ||
/// modified by this function and might not verify after the function returns. | ||
/// Upon merging, private symbols may be renamed in order to avoid collisions in | ||
/// the result. Public symbols may not collide, with the exception of | ||
/// instances of `SymbolOpInterface`, where collisions are allowed if at least | ||
/// one of the two is external, in which case the other op preserved (or any one | ||
/// of the two if both are external). | ||
// TODO: Reconsider cloning individual ops rather than forcing users of the | ||
// function to clone (or move) `other` in order to improve efficiency. | ||
// This might primarily make sense if we can also prune the symbols that | ||
// are merged to a subset (such as those that are actually used). | ||
LogicalResult mergeSymbolsInto(Operation *target, | ||
OwningOpRef<Operation *> other); | ||
} // namespace detail | ||
|
||
/// Standalone util to apply the named sequence `entryPoint` to the payload. | ||
/// This is done in 3 steps: | ||
/// 1. lookup the `entryPoint` symbol in `{payload, sharedTransformModule}` by | ||
/// calling detail::findTransformEntryPoint. | ||
/// 2. if the entry point is found and not nested under | ||
/// `sharedTransformModule`, call `detail::defineDeclaredSymbols` to "link" in | ||
/// the `sharedTransformModule`. Note: this may modify the transform IR | ||
/// embedded with the payload IR. | ||
/// 3. apply the transform IR to the payload IR, relaxing the requirement that | ||
/// the transform IR is a top-level transform op. We are applying a named | ||
/// sequence anyway. | ||
LogicalResult applyTransformNamedSequence( | ||
Operation *payload, ModuleOp transformModule, | ||
const TransformOptions &options, | ||
StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName); | ||
|
||
} // namespace transform | ||
} // namespace mlir | ||
|
||
#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
337 changes: 337 additions & 0 deletions
337
mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,337 @@ | ||
//===- TransformInterpreterUtils.cpp --------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Lightweight transform dialect interpreter utilities. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" | ||
#include "mlir/Dialect/Transform/IR/TransformDialect.h" | ||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" | ||
#include "mlir/Dialect/Transform/IR/TransformOps.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/Verifier.h" | ||
#include "mlir/IR/Visitors.h" | ||
#include "mlir/Interfaces/FunctionInterfaces.h" | ||
#include "mlir/Parser/Parser.h" | ||
#include "mlir/Support/FileUtilities.h" | ||
#include "llvm/ADT/StringRef.h" | ||
#include "llvm/Support/Casting.h" | ||
#include "llvm/Support/Debug.h" | ||
#include "llvm/Support/SourceMgr.h" | ||
#include "llvm/Support/raw_ostream.h" | ||
|
||
using namespace mlir; | ||
|
||
#define DEBUG_TYPE "transform-dialect-interpreter-utils" | ||
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") | ||
|
||
LogicalResult transform::detail::parseTransformModuleFromFile( | ||
MLIRContext *context, llvm::StringRef transformFileName, | ||
OwningOpRef<ModuleOp> &transformModule) { | ||
if (transformFileName.empty()) { | ||
LLVM_DEBUG( | ||
DBGS() << "no transform file name specified, assuming the transform " | ||
"module is embedded in the IR next to the top-level\n"); | ||
return success(); | ||
} | ||
// Parse transformFileName content into a ModuleOp. | ||
std::string errorMessage; | ||
auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage); | ||
if (!memoryBuffer) { | ||
return emitError(FileLineColLoc::get( | ||
StringAttr::get(context, transformFileName), 0, 0)) | ||
<< "failed to open transform file: " << errorMessage; | ||
} | ||
// Tell sourceMgr about this buffer, the parser will pick it up. | ||
llvm::SourceMgr sourceMgr; | ||
sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc()); | ||
transformModule = | ||
OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context)); | ||
return mlir::verify(*transformModule); | ||
} | ||
|
||
ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) { | ||
auto preloadedLibraryRange = | ||
context->getOrLoadDialect<transform::TransformDialect>() | ||
->getLibraryModules(); | ||
if (!preloadedLibraryRange.empty()) | ||
return *preloadedLibraryRange.begin(); | ||
return ModuleOp(); | ||
} | ||
|
||
transform::TransformOpInterface | ||
transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module, | ||
StringRef entryPoint) { | ||
SmallVector<Operation *, 2> l{root}; | ||
if (module) | ||
l.push_back(module); | ||
for (Operation *op : l) { | ||
transform::TransformOpInterface transform = nullptr; | ||
op->walk<WalkOrder::PreOrder>( | ||
[&](transform::NamedSequenceOp namedSequenceOp) { | ||
if (namedSequenceOp.getSymName() == entryPoint) { | ||
transform = cast<transform::TransformOpInterface>( | ||
namedSequenceOp.getOperation()); | ||
return WalkResult::interrupt(); | ||
} | ||
return WalkResult::advance(); | ||
}); | ||
if (transform) | ||
return transform; | ||
} | ||
auto diag = root->emitError() | ||
<< "could not find a nested named sequence with name: " | ||
<< entryPoint; | ||
return nullptr; | ||
} | ||
|
||
/// Return whether `func1` can be merged into `func2`. For that to work `func1` | ||
/// has to be a declaration (aka has to be external) and `func2` either has to | ||
/// be a declaration as well, or it has to be public (otherwise, it wouldn't | ||
/// be visible by `func1`). | ||
static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) { | ||
return func1.isExternal() && (func2.isPublic() || func2.isExternal()); | ||
} | ||
|
||
/// Merge `func1` into `func2`. The two ops must be inside the same parent op | ||
/// and mergable according to `canMergeInto`. The function erases `func1` such | ||
/// that only `func2` exists when the function returns. | ||
static LogicalResult mergeInto(FunctionOpInterface func1, | ||
FunctionOpInterface func2) { | ||
assert(canMergeInto(func1, func2)); | ||
assert(func1->getParentOp() == func2->getParentOp() && | ||
"expected func1 and func2 to be in the same parent op"); | ||
|
||
// Check that function signatures match. | ||
if (func1.getFunctionType() != func2.getFunctionType()) { | ||
return func1.emitError() | ||
<< "external definition has a mismatching signature (" | ||
<< func2.getFunctionType() << ")"; | ||
} | ||
|
||
// Check and merge argument attributes. | ||
MLIRContext *context = func1->getContext(); | ||
auto *td = context->getLoadedDialect<transform::TransformDialect>(); | ||
StringAttr consumedName = td->getConsumedAttrName(); | ||
StringAttr readOnlyName = td->getReadOnlyAttrName(); | ||
for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) { | ||
bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr; | ||
bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr; | ||
bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr; | ||
bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr; | ||
if (!isExternalConsumed && !isExternalReadonly) { | ||
if (isConsumed) | ||
func2.setArgAttr(i, consumedName, UnitAttr::get(context)); | ||
else if (isReadonly) | ||
func2.setArgAttr(i, readOnlyName, UnitAttr::get(context)); | ||
continue; | ||
} | ||
|
||
if ((isExternalConsumed && !isConsumed) || | ||
(isExternalReadonly && !isReadonly)) { | ||
return func1.emitError() | ||
<< "external definition has mismatching consumption " | ||
"annotations for argument #" | ||
<< i; | ||
} | ||
} | ||
|
||
// `func1` is the external one, so we can remove it. | ||
assert(func1.isExternal()); | ||
func1->erase(); | ||
|
||
return success(); | ||
} | ||
|
||
LogicalResult | ||
transform::detail::mergeSymbolsInto(Operation *target, | ||
OwningOpRef<Operation *> other) { | ||
assert(target->hasTrait<OpTrait::SymbolTable>() && | ||
"requires target to implement the 'SymbolTable' trait"); | ||
assert(other->hasTrait<OpTrait::SymbolTable>() && | ||
"requires target to implement the 'SymbolTable' trait"); | ||
|
||
SymbolTable targetSymbolTable(target); | ||
SymbolTable otherSymbolTable(*other); | ||
|
||
// Step 1: | ||
// | ||
// Rename private symbols in both ops in order to resolve conflicts that can | ||
// be resolved that way. | ||
LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n"); | ||
// TODO: Do we *actually* need to test in both directions? | ||
for (auto &&[symbolTable, otherSymbolTable] : llvm::zip( | ||
SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable}, | ||
SmallVector<SymbolTable *, 2>{&otherSymbolTable, | ||
&targetSymbolTable})) { | ||
Operation *symbolTableOp = symbolTable->getOp(); | ||
for (Operation &op : symbolTableOp->getRegion(0).front()) { | ||
auto symbolOp = dyn_cast<SymbolOpInterface>(op); | ||
if (!symbolOp) | ||
continue; | ||
StringAttr name = symbolOp.getNameAttr(); | ||
LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n"); | ||
|
||
// Check if there is a colliding op in the other module. | ||
auto collidingOp = | ||
cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name)); | ||
if (!collidingOp) | ||
continue; | ||
|
||
LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue()); | ||
|
||
// Collisions are fine if both opt are functions and can be merged. | ||
if (auto funcOp = dyn_cast<FunctionOpInterface>(op), | ||
collidingFuncOp = | ||
dyn_cast<FunctionOpInterface>(collidingOp.getOperation()); | ||
funcOp && collidingFuncOp) { | ||
if (canMergeInto(funcOp, collidingFuncOp) || | ||
canMergeInto(collidingFuncOp, funcOp)) { | ||
LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and " | ||
"will be merged\n"); | ||
continue; | ||
} | ||
|
||
// If they can't be merged, proceed like any other collision. | ||
LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions"); | ||
} | ||
|
||
// Collision can be resolved by renaming if one of the ops is private. | ||
auto renameToUnique = | ||
[&](SymbolOpInterface op, SymbolOpInterface otherOp, | ||
SymbolTable &symbolTable, | ||
SymbolTable &otherSymbolTable) -> LogicalResult { | ||
LLVM_DEBUG(llvm::dbgs() << ", renaming\n"); | ||
FailureOr<StringAttr> maybeNewName = | ||
symbolTable.renameToUnique(op, {&otherSymbolTable}); | ||
if (failed(maybeNewName)) { | ||
InFlightDiagnostic diag = op->emitError("failed to rename symbol"); | ||
diag.attachNote(otherOp->getLoc()) | ||
<< "attempted renaming due to collision with this op"; | ||
return diag; | ||
} | ||
LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue() | ||
<< "\n"); | ||
return success(); | ||
}; | ||
|
||
if (symbolOp.isPrivate()) { | ||
if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable, | ||
*otherSymbolTable))) | ||
return failure(); | ||
continue; | ||
} | ||
if (collidingOp.isPrivate()) { | ||
if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable, | ||
*symbolTable))) | ||
return failure(); | ||
continue; | ||
} | ||
LLVM_DEBUG(llvm::dbgs() << ", emitting error\n"); | ||
InFlightDiagnostic diag = symbolOp.emitError() | ||
<< "doubly defined symbol @" << name.getValue(); | ||
diag.attachNote(collidingOp->getLoc()) << "previously defined here"; | ||
return diag; | ||
} | ||
} | ||
|
||
// TODO: This duplicates pass infrastructure. We should split this pass into | ||
// several and let the pass infrastructure do the verification. | ||
for (auto *op : SmallVector<Operation *>{target, *other}) { | ||
if (failed(mlir::verify(op))) | ||
return op->emitError() << "failed to verify input op after renaming"; | ||
} | ||
|
||
// Step 2: | ||
// | ||
// Move all ops from `other` into target and merge public symbols. | ||
LLVM_DEBUG(DBGS() << "moving all symbols into target\n"); | ||
{ | ||
SmallVector<SymbolOpInterface> opsToMove; | ||
for (Operation &op : other->getRegion(0).front()) { | ||
if (auto symbol = dyn_cast<SymbolOpInterface>(op)) | ||
opsToMove.push_back(symbol); | ||
} | ||
|
||
for (SymbolOpInterface op : opsToMove) { | ||
// Remember potentially colliding op in the target module. | ||
auto collidingOp = cast_or_null<SymbolOpInterface>( | ||
targetSymbolTable.lookup(op.getNameAttr())); | ||
|
||
// Move op even if we get a collision. | ||
LLVM_DEBUG(DBGS() << " moving @" << op.getName()); | ||
op->moveBefore(&target->getRegion(0).front(), | ||
target->getRegion(0).front().end()); | ||
|
||
// If there is no collision, we are done. | ||
if (!collidingOp) { | ||
LLVM_DEBUG(llvm::dbgs() << " without collision\n"); | ||
continue; | ||
} | ||
|
||
// The two colliding ops must both be functions because we have already | ||
// emitted errors otherwise earlier. | ||
auto funcOp = cast<FunctionOpInterface>(op.getOperation()); | ||
auto collidingFuncOp = | ||
cast<FunctionOpInterface>(collidingOp.getOperation()); | ||
|
||
// Both ops are in the target module now and can be treated symmetrically, | ||
// so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`. | ||
if (!canMergeInto(funcOp, collidingFuncOp)) { | ||
std::swap(funcOp, collidingFuncOp); | ||
} | ||
assert(canMergeInto(funcOp, collidingFuncOp)); | ||
|
||
LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at " | ||
<< collidingFuncOp.getLoc() << ":\n" | ||
<< collidingFuncOp << "\n"); | ||
|
||
// Update symbol table. This works with or without the previous `swap`. | ||
targetSymbolTable.remove(funcOp); | ||
targetSymbolTable.insert(collidingFuncOp); | ||
assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp); | ||
|
||
// Do the actual merging. | ||
if (failed(mergeInto(funcOp, collidingFuncOp))) { | ||
return failure(); | ||
} | ||
} | ||
} | ||
|
||
if (failed(mlir::verify(target))) | ||
return target->emitError() | ||
<< "failed to verify target op after merging symbols"; | ||
|
||
LLVM_DEBUG(DBGS() << "done merging ops\n"); | ||
return success(); | ||
} | ||
|
||
LogicalResult transform::applyTransformNamedSequence( | ||
Operation *payload, ModuleOp transformModule, | ||
const TransformOptions &options, StringRef entryPoint) { | ||
Operation *transformRoot = | ||
detail::findTransformEntryPoint(payload, transformModule, entryPoint); | ||
if (!transformRoot) | ||
return failure(); | ||
|
||
// `transformModule` may not be modified. | ||
OwningOpRef<Operation *> clonedTransformModule(transformModule->clone()); | ||
if (transformModule && !transformModule->isAncestor(transformRoot)) { | ||
if (failed(detail::mergeSymbolsInto( | ||
SymbolTable::getNearestSymbolTable(transformRoot), | ||
std::move(clonedTransformModule)))) | ||
return failure(); | ||
} | ||
|
||
// Apply the transform to the IR, do not enforce top-level constraints. | ||
RaggedArray<MappedValue> noExtraMappings; | ||
return applyTransforms(payload, cast<TransformOpInterface>(transformRoot), | ||
noExtraMappings, options, | ||
/*enforceToplevelTransformOp=*/false); | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,11 @@ | ||
add_mlir_unittest(MLIRTransformDialectTests | ||
BuildOnlyExtensionTest.cpp | ||
Preload.cpp | ||
) | ||
target_link_libraries(MLIRTransformDialectTests | ||
PRIVATE | ||
MLIRFuncDialect | ||
MLIRTestTransformDialect | ||
MLIRTransformDialect | ||
MLIRTransformDialectTransforms | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
//===- Preload.cpp - Test MlirOptMain parameterization ------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Transform/IR/TransformDialect.h" | ||
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" | ||
#include "mlir/IR/AsmState.h" | ||
#include "mlir/IR/DialectRegistry.h" | ||
#include "mlir/IR/Verifier.h" | ||
#include "mlir/Parser/Parser.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Pass/PassManager.h" | ||
#include "mlir/Support/FileUtilities.h" | ||
#include "mlir/Support/TypeID.h" | ||
#include "mlir/Tools/mlir-opt/MlirOptMain.h" | ||
#include "llvm/Support/MemoryBuffer.h" | ||
#include "llvm/Support/raw_ostream.h" | ||
#include "gtest/gtest.h" | ||
|
||
using namespace mlir; | ||
|
||
namespace mlir { | ||
namespace test { | ||
std::unique_ptr<Pass> createTestTransformDialectInterpreterPass(); | ||
} // namespace test | ||
} // namespace mlir | ||
namespace test { | ||
void registerTestTransformDialectExtension(DialectRegistry ®istry); | ||
} // namespace test | ||
|
||
const static llvm::StringLiteral library = R"MLIR( | ||
module attributes {transform.with_named_sequence} { | ||
transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly}) { | ||
transform.test_print_remark_at_operand %arg0, "from external symbol" : !transform.any_op | ||
transform.yield | ||
} | ||
})MLIR"; | ||
|
||
const static llvm::StringLiteral input = R"MLIR( | ||
module attributes {transform.with_named_sequence} { | ||
transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly}) | ||
transform.sequence failures(propagate) { | ||
^bb0(%arg0: !transform.any_op): | ||
include @__transform_main failures(propagate) (%arg0) : (!transform.any_op) -> () | ||
} | ||
})MLIR"; | ||
|
||
TEST(Preload, ContextPreloadConstructedLibrary) { | ||
registerPassManagerCLOptions(); | ||
|
||
MLIRContext context; | ||
auto *dialect = context.getOrLoadDialect<transform::TransformDialect>(); | ||
DialectRegistry registry; | ||
::test::registerTestTransformDialectExtension(registry); | ||
registry.applyExtensions(&context); | ||
ParserConfig parserConfig(&context); | ||
|
||
OwningOpRef<ModuleOp> inputModule = | ||
parseSourceString<ModuleOp>(input, parserConfig, "<input>"); | ||
EXPECT_TRUE(inputModule) << "failed to parse input module"; | ||
|
||
OwningOpRef<ModuleOp> transformLibrary = | ||
parseSourceString<ModuleOp>(library, parserConfig, "<transform-library>"); | ||
EXPECT_TRUE(transformLibrary) << "failed to parse transform module"; | ||
dialect->registerLibraryModule(std::move(transformLibrary)); | ||
|
||
ModuleOp retrievedTransformLibrary = | ||
transform::detail::getPreloadedTransformModule(&context); | ||
EXPECT_TRUE(retrievedTransformLibrary) | ||
<< "failed to retrieve transform module"; | ||
|
||
transform::TransformOpInterface entryPoint = | ||
transform::detail::findTransformEntryPoint(inputModule->getOperation(), | ||
retrievedTransformLibrary); | ||
EXPECT_TRUE(entryPoint) << "failed to find entry point"; | ||
|
||
OwningOpRef<Operation *> clonedTransformModule( | ||
retrievedTransformLibrary->clone()); | ||
LogicalResult res = transform::detail::mergeSymbolsInto( | ||
inputModule->getOperation(), std::move(clonedTransformModule)); | ||
EXPECT_TRUE(succeeded(res)) << "failed to define declared symbols"; | ||
|
||
transform::TransformOptions options; | ||
res = transform::applyTransformNamedSequence( | ||
inputModule->getOperation(), retrievedTransformLibrary, options); | ||
EXPECT_TRUE(succeeded(res)) << "failed to apply named sequence"; | ||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.