Skip to content

[mlir][Transform] Provide a minimal set of utils that allow implementing a simple transform dialect interpreter pass #68330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
Original file line number Diff line number Diff line change
@@ -20,10 +20,14 @@ def Transform_Dialect : Dialect {

let hasOperationAttrVerify = 1;
let extraClassDeclaration = [{
/// Symbol name for the default entry point "named sequence".
constexpr const static ::llvm::StringLiteral
kTransformEntryPointSymbolName = "__transform_main";

/// Name of the attribute attachable to the symbol table operation
/// containing named sequences. This is used to trigger verification.
constexpr const static ::llvm::StringLiteral kWithNamedSequenceAttrName =
"transform.with_named_sequence";
constexpr const static ::llvm::StringLiteral
kWithNamedSequenceAttrName = "transform.with_named_sequence";

/// Name of the attribute attachable to an operation so it can be
/// identified as root by the default interpreter pass.
@@ -74,6 +78,22 @@ def Transform_Dialect : Dialect {
using ExtensionTypePrintingHook =
std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;

/// Appends the given module as a transform symbol library available to
/// all dialect users.
void registerLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
library) {
libraryModules.push_back(std::move(library));
}

/// Returns a range of registered library modules.
auto getLibraryModules() const {
return ::llvm::map_range(
libraryModules,
[](const ::mlir::OwningOpRef<::mlir::ModuleOp> &library) {
return library.get();
});
}

private:
/// Registers operations specified as template parameters with this
/// dialect. Checks that they implement the required interfaces.
@@ -132,6 +152,11 @@ def Transform_Dialect : Dialect {
/// lookups when the type is fully constructed.
::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
typePrintingHooks;

/// Modules containing symbols, e.g. named sequences, that will be
/// resolved by the interpreter when used.
::llvm::SmallVector<::mlir::OwningOpRef<::mlir::ModuleOp>, 2>
libraryModules;
}];
}

5 changes: 3 additions & 2 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
Original file line number Diff line number Diff line change
@@ -111,7 +111,8 @@ class TransformOptions {
LogicalResult
applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping = {},
const TransformOptions &options = TransformOptions());
const TransformOptions &options = TransformOptions(),
bool enforceToplevelTransformOp = true);

/// The state maintained across applications of various ops implementing the
/// TransformOpInterface. The operations implementing this interface and the
@@ -193,7 +194,7 @@ class TransformState {

friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
const RaggedArray<MappedValue> &,
const TransformOptions &);
const TransformOptions &, bool);

friend TransformState
detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
//===- TransformInterpreterUtils.h - Transform Utils ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
#define MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H

#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include <memory>

namespace mlir {
struct LogicalResult;
class MLIRContext;
class ModuleOp;
class Operation;
template <typename>
class OwningOpRef;
class Region;

namespace transform {
namespace detail {
/// Utility to parse and verify the content of a `transformFileName` MLIR file
/// containing a transform dialect specification.
LogicalResult
parseTransformModuleFromFile(MLIRContext *context,
llvm::StringRef transformFileName,
OwningOpRef<ModuleOp> &transformModule);

/// Utility to load a transform interpreter `module` from a module that has
/// already been preloaded in the context.
/// This mode is useful in cases where explicit parsing of a transform library
/// from file is expected to be prohibitively expensive.
/// In such cases, the transform module is expected to be found in the preloaded
/// library modules of the transform dialect.
/// Returns null if the module is not found.
ModuleOp getPreloadedTransformModule(MLIRContext *context);

/// Finds the first TransformOpInterface named `kTransformEntryPointSymbolName`
/// that is either:
/// 1. nested under `root` (takes precedence).
/// 2. nested under `module`, if not found in `root`.
/// Reports errors and returns null if no such operation found.
TransformOpInterface findTransformEntryPoint(
Operation *root, ModuleOp module,
StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);

/// Merge all symbols from `other` into `target`. Both ops need to implement the
/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
/// modified by this function and might not verify after the function returns.
/// Upon merging, private symbols may be renamed in order to avoid collisions in
/// the result. Public symbols may not collide, with the exception of
/// instances of `SymbolOpInterface`, where collisions are allowed if at least
/// one of the two is external, in which case the other op preserved (or any one
/// of the two if both are external).
// TODO: Reconsider cloning individual ops rather than forcing users of the
// function to clone (or move) `other` in order to improve efficiency.
// This might primarily make sense if we can also prune the symbols that
// are merged to a subset (such as those that are actually used).
LogicalResult mergeSymbolsInto(Operation *target,
OwningOpRef<Operation *> other);
} // namespace detail

/// Standalone util to apply the named sequence `entryPoint` to the payload.
/// This is done in 3 steps:
/// 1. lookup the `entryPoint` symbol in `{payload, sharedTransformModule}` by
/// calling detail::findTransformEntryPoint.
/// 2. if the entry point is found and not nested under
/// `sharedTransformModule`, call `detail::defineDeclaredSymbols` to "link" in
/// the `sharedTransformModule`. Note: this may modify the transform IR
/// embedded with the payload IR.
/// 3. apply the transform IR to the payload IR, relaxing the requirement that
/// the transform IR is a top-level transform op. We are applying a named
/// sequence anyway.
LogicalResult applyTransformNamedSequence(
Operation *payload, ModuleOp transformModule,
const TransformOptions &options,
StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);

} // namespace transform
} // namespace mlir

#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
26 changes: 13 additions & 13 deletions mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
@@ -2079,20 +2079,20 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
// Entry point.
//===----------------------------------------------------------------------===//

LogicalResult
transform::applyTransforms(Operation *payloadRoot,
TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping,
const TransformOptions &options) {
#ifndef NDEBUG
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
transform->getNumOperands() != 0) {
transform->emitError()
<< "expected transform to start at the top-level transform op";
llvm::report_fatal_error("could not run transforms",
/*gen_crash_diag=*/false);
LogicalResult transform::applyTransforms(
Operation *payloadRoot, TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping,
const TransformOptions &options, bool enforceToplevelTransformOp) {
if (enforceToplevelTransformOp) {
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
transform->getNumOperands() != 0) {
return transform->emitError()
<< "expected transform to start at the top-level transform op";
}
} else if (failed(
detail::verifyPossibleTopLevelTransformOpTrait(transform))) {
return failure();
}
#endif // NDEBUG

TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
options);
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTransformDialectTransforms
CheckUses.cpp
InferEffects.cpp
TransformInterpreterPassBase.cpp
TransformInterpreterUtils.cpp

DEPENDS
MLIRTransformDialectTransformsIncGen
284 changes: 12 additions & 272 deletions mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
@@ -51,34 +52,6 @@ constexpr static llvm::StringLiteral kTransformDialectTagPayloadRootValue =
constexpr static llvm::StringLiteral
kTransformDialectTagTransformContainerValue = "transform_container";

/// Utility to parse the content of a `transformFileName` MLIR file containing
/// a transform dialect specification.
static LogicalResult
parseTransformModuleFromFile(MLIRContext *context,
llvm::StringRef transformFileName,
OwningOpRef<ModuleOp> &transformModule) {
if (transformFileName.empty()) {
LLVM_DEBUG(
DBGS() << "no transform file name specified, assuming the transform "
"module is embedded in the IR next to the top-level\n");
return success();
}
// Parse transformFileName content into a ModuleOp.
std::string errorMessage;
auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
if (!memoryBuffer) {
return emitError(FileLineColLoc::get(
StringAttr::get(context, transformFileName), 0, 0))
<< "failed to open transform file: " << errorMessage;
}
// Tell sourceMgr about this buffer, the parser will pick it up.
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
transformModule =
OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
return success();
}

/// Finds the single top-level transform operation with `root` as ancestor.
/// Reports an error if there is more than one such operation and returns the
/// first one found. Reports an error returns nullptr if no such operation
@@ -295,239 +268,6 @@ static void performOptionalDebugActions(
transform->removeAttr(kTransformDialectTagAttrName);
}

/// Return whether `func1` can be merged into `func2`. For that to work `func1`
/// has to be a declaration (aka has to be external) and `func2` either has to
/// be a declaration as well, or it has to be public (otherwise, it wouldn't
/// be visible by `func1`).
static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
return func1.isExternal() && (func2.isPublic() || func2.isExternal());
}

/// Merge `func1` into `func2`. The two ops must be inside the same parent op
/// and mergable according to `canMergeInto`. The function erases `func1` such
/// that only `func2` exists when the function returns.
static LogicalResult mergeInto(FunctionOpInterface func1,
FunctionOpInterface func2) {
assert(canMergeInto(func1, func2));
assert(func1->getParentOp() == func2->getParentOp() &&
"expected func1 and func2 to be in the same parent op");

// Check that function signatures match.
if (func1.getFunctionType() != func2.getFunctionType()) {
return func1.emitError()
<< "external definition has a mismatching signature ("
<< func2.getFunctionType() << ")";
}

// Check and merge argument attributes.
MLIRContext *context = func1->getContext();
auto *td = context->getLoadedDialect<transform::TransformDialect>();
StringAttr consumedName = td->getConsumedAttrName();
StringAttr readOnlyName = td->getReadOnlyAttrName();
for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
if (!isExternalConsumed && !isExternalReadonly) {
if (isConsumed)
func2.setArgAttr(i, consumedName, UnitAttr::get(context));
else if (isReadonly)
func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
continue;
}

if ((isExternalConsumed && !isConsumed) ||
(isExternalReadonly && !isReadonly)) {
return func1.emitError()
<< "external definition has mismatching consumption "
"annotations for argument #"
<< i;
}
}

// `func1` is the external one, so we can remove it.
assert(func1.isExternal());
func1->erase();

return success();
}

/// Merge all symbols from `other` into `target`. Both ops need to implement the
/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
/// modified by this function and might not verify after the function returns.
/// Upon merging, private symbols may be renamed in order to avoid collisions in
/// the result. Public symbols may not collide, with the exception of
/// instances of `SymbolOpInterface`, where collisions are allowed if at least
/// one of the two is external, in which case the other op preserved (or any one
/// of the two if both are external).
// TODO: Reconsider cloning individual ops rather than forcing users of the
// function to clone (or move) `other` in order to improve efficiency.
// This might primarily make sense if we can also prune the symbols that
// are merged to a subset (such as those that are actually used).
static LogicalResult mergeSymbolsInto(Operation *target,
OwningOpRef<Operation *> other) {
assert(target->hasTrait<OpTrait::SymbolTable>() &&
"requires target to implement the 'SymbolTable' trait");
assert(other->hasTrait<OpTrait::SymbolTable>() &&
"requires target to implement the 'SymbolTable' trait");

SymbolTable targetSymbolTable(target);
SymbolTable otherSymbolTable(*other);

// Step 1:
//
// Rename private symbols in both ops in order to resolve conflicts that can
// be resolved that way.
LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
// TODO: Do we *actually* need to test in both directions?
for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
SmallVector<SymbolTable *, 2>{&otherSymbolTable,
&targetSymbolTable})) {
Operation *symbolTableOp = symbolTable->getOp();
for (Operation &op : symbolTableOp->getRegion(0).front()) {
auto symbolOp = dyn_cast<SymbolOpInterface>(op);
if (!symbolOp)
continue;
StringAttr name = symbolOp.getNameAttr();
LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");

// Check if there is a colliding op in the other module.
auto collidingOp =
cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
if (!collidingOp)
continue;

LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());

// Collisions are fine if both opt are functions and can be merged.
if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
collidingFuncOp =
dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
funcOp && collidingFuncOp) {
if (canMergeInto(funcOp, collidingFuncOp) ||
canMergeInto(collidingFuncOp, funcOp)) {
LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
"will be merged\n");
continue;
}

// If they can't be merged, proceed like any other collision.
LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
}

// Collision can be resolved by renaming if one of the ops is private.
auto renameToUnique =
[&](SymbolOpInterface op, SymbolOpInterface otherOp,
SymbolTable &symbolTable,
SymbolTable &otherSymbolTable) -> LogicalResult {
LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
FailureOr<StringAttr> maybeNewName =
symbolTable.renameToUnique(op, {&otherSymbolTable});
if (failed(maybeNewName)) {
InFlightDiagnostic diag = op->emitError("failed to rename symbol");
diag.attachNote(otherOp->getLoc())
<< "attempted renaming due to collision with this op";
return diag;
}
LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
<< "\n");
return success();
};

if (symbolOp.isPrivate()) {
if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
*otherSymbolTable)))
return failure();
continue;
}
if (collidingOp.isPrivate()) {
if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
*symbolTable)))
return failure();
continue;
}

LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
InFlightDiagnostic diag = symbolOp.emitError()
<< "doubly defined symbol @" << name.getValue();
diag.attachNote(collidingOp->getLoc()) << "previously defined here";
return diag;
}
}

// TODO: This duplicates pass infrastructure. We should split this pass into
// several and let the pass infrastructure do the verification.
for (auto *op : SmallVector<Operation *>{target, *other}) {
if (failed(mlir::verify(op)))
return op->emitError() << "failed to verify input op after renaming";
}

// Step 2:
//
// Move all ops from `other` into target and merge public symbols.
LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
{
SmallVector<SymbolOpInterface> opsToMove;
for (Operation &op : other->getRegion(0).front()) {
if (auto symbol = dyn_cast<SymbolOpInterface>(op))
opsToMove.push_back(symbol);
}

for (SymbolOpInterface op : opsToMove) {
// Remember potentially colliding op in the target module.
auto collidingOp = cast_or_null<SymbolOpInterface>(
targetSymbolTable.lookup(op.getNameAttr()));

// Move op even if we get a collision.
LLVM_DEBUG(DBGS() << " moving @" << op.getName());
op->moveBefore(&target->getRegion(0).front(),
target->getRegion(0).front().end());

// If there is no collision, we are done.
if (!collidingOp) {
LLVM_DEBUG(llvm::dbgs() << " without collision\n");
continue;
}

// The two colliding ops must both be functions because we have already
// emitted errors otherwise earlier.
auto funcOp = cast<FunctionOpInterface>(op.getOperation());
auto collidingFuncOp =
cast<FunctionOpInterface>(collidingOp.getOperation());

// Both ops are in the target module now and can be treated symmetrically,
// so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`.
if (!canMergeInto(funcOp, collidingFuncOp)) {
std::swap(funcOp, collidingFuncOp);
}
assert(canMergeInto(funcOp, collidingFuncOp));

LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
<< collidingFuncOp.getLoc() << ":\n"
<< collidingFuncOp << "\n");

// Update symbol table. This works with or without the previous `swap`.
targetSymbolTable.remove(funcOp);
targetSymbolTable.insert(collidingFuncOp);
assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);

// Do the actual merging.
if (failed(mergeInto(funcOp, collidingFuncOp))) {
return failure();
}
}
}

if (failed(mlir::verify(target)))
return target->emitError()
<< "failed to verify target op after merging symbols";

LLVM_DEBUG(DBGS() << "done merging ops\n");
return success();
}

LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
Operation *target, StringRef passName,
const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
@@ -595,9 +335,9 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
diag.attachNote(target->getLoc()) << "pass anchor op";
return diag;
}
if (failed(
mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot),
transformLibraryModule->get()->clone())))
if (failed(detail::mergeSymbolsInto(
SymbolTable::getNearestSymbolTable(transformRoot),
transformLibraryModule->get()->clone())))
return emitError(transformRoot->getLoc(),
"failed to merge library symbols into transform root");
}
@@ -683,8 +423,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
OwningOpRef<ModuleOp> moduleFromFile;
{
auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
if (failed(parseTransformModuleFromFile(context, transformFileName,
moduleFromFile)))
if (failed(detail::parseTransformModuleFromFile(context, transformFileName,
moduleFromFile)))
return emitError(loc) << "failed to parse transform module";
if (moduleFromFile && failed(mlir::verify(*moduleFromFile)))
return emitError(loc) << "failed to verify transform module";
@@ -701,8 +441,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
for (const std::string &libraryFileName : libraryFileNames) {
OwningOpRef<ModuleOp> parsedLibrary;
auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0);
if (failed(parseTransformModuleFromFile(context, libraryFileName,
parsedLibrary)))
if (failed(detail::parseTransformModuleFromFile(context, libraryFileName,
parsedLibrary)))
return emitError(loc) << "failed to parse transform library module";
if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
return emitError(loc) << "failed to verify transform library module";
@@ -741,8 +481,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
IRRewriter rewriter(context);
// TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
if (failed(mergeSymbolsInto(mergedParsedLibraries.get(),
std::move(parsedLibrary))))
if (failed(detail::mergeSymbolsInto(mergedParsedLibraries.get(),
std::move(parsedLibrary))))
return mergedParsedLibraries->emitError()
<< "failed to verify merged transform module";
}
@@ -751,8 +491,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
// Use parsed libaries to resolve symbols in shared transform module or return
// as separate library module.
if (sharedTransformModule && *sharedTransformModule) {
if (failed(mergeSymbolsInto(sharedTransformModule->get(),
std::move(mergedParsedLibraries))))
if (failed(detail::mergeSymbolsInto(sharedTransformModule->get(),
std::move(mergedParsedLibraries))))
return (*sharedTransformModule)->emitError()
<< "failed to merge symbols from library files "
"into shared transform module";
337 changes: 337 additions & 0 deletions mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,337 @@
//===- TransformInterpreterUtils.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lightweight transform dialect interpreter utilities.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

#define DEBUG_TYPE "transform-dialect-interpreter-utils"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

LogicalResult transform::detail::parseTransformModuleFromFile(
MLIRContext *context, llvm::StringRef transformFileName,
OwningOpRef<ModuleOp> &transformModule) {
if (transformFileName.empty()) {
LLVM_DEBUG(
DBGS() << "no transform file name specified, assuming the transform "
"module is embedded in the IR next to the top-level\n");
return success();
}
// Parse transformFileName content into a ModuleOp.
std::string errorMessage;
auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
if (!memoryBuffer) {
return emitError(FileLineColLoc::get(
StringAttr::get(context, transformFileName), 0, 0))
<< "failed to open transform file: " << errorMessage;
}
// Tell sourceMgr about this buffer, the parser will pick it up.
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
transformModule =
OwningOpRef<ModuleOp>(parseSourceFile<ModuleOp>(sourceMgr, context));
return mlir::verify(*transformModule);
}

ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
auto preloadedLibraryRange =
context->getOrLoadDialect<transform::TransformDialect>()
->getLibraryModules();
if (!preloadedLibraryRange.empty())
return *preloadedLibraryRange.begin();
return ModuleOp();
}

transform::TransformOpInterface
transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
StringRef entryPoint) {
SmallVector<Operation *, 2> l{root};
if (module)
l.push_back(module);
for (Operation *op : l) {
transform::TransformOpInterface transform = nullptr;
op->walk<WalkOrder::PreOrder>(
[&](transform::NamedSequenceOp namedSequenceOp) {
if (namedSequenceOp.getSymName() == entryPoint) {
transform = cast<transform::TransformOpInterface>(
namedSequenceOp.getOperation());
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (transform)
return transform;
}
auto diag = root->emitError()
<< "could not find a nested named sequence with name: "
<< entryPoint;
return nullptr;
}

/// Return whether `func1` can be merged into `func2`. For that to work `func1`
/// has to be a declaration (aka has to be external) and `func2` either has to
/// be a declaration as well, or it has to be public (otherwise, it wouldn't
/// be visible by `func1`).
static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
return func1.isExternal() && (func2.isPublic() || func2.isExternal());
}

/// Merge `func1` into `func2`. The two ops must be inside the same parent op
/// and mergable according to `canMergeInto`. The function erases `func1` such
/// that only `func2` exists when the function returns.
static LogicalResult mergeInto(FunctionOpInterface func1,
FunctionOpInterface func2) {
assert(canMergeInto(func1, func2));
assert(func1->getParentOp() == func2->getParentOp() &&
"expected func1 and func2 to be in the same parent op");

// Check that function signatures match.
if (func1.getFunctionType() != func2.getFunctionType()) {
return func1.emitError()
<< "external definition has a mismatching signature ("
<< func2.getFunctionType() << ")";
}

// Check and merge argument attributes.
MLIRContext *context = func1->getContext();
auto *td = context->getLoadedDialect<transform::TransformDialect>();
StringAttr consumedName = td->getConsumedAttrName();
StringAttr readOnlyName = td->getReadOnlyAttrName();
for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
if (!isExternalConsumed && !isExternalReadonly) {
if (isConsumed)
func2.setArgAttr(i, consumedName, UnitAttr::get(context));
else if (isReadonly)
func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
continue;
}

if ((isExternalConsumed && !isConsumed) ||
(isExternalReadonly && !isReadonly)) {
return func1.emitError()
<< "external definition has mismatching consumption "
"annotations for argument #"
<< i;
}
}

// `func1` is the external one, so we can remove it.
assert(func1.isExternal());
func1->erase();

return success();
}

LogicalResult
transform::detail::mergeSymbolsInto(Operation *target,
OwningOpRef<Operation *> other) {
assert(target->hasTrait<OpTrait::SymbolTable>() &&
"requires target to implement the 'SymbolTable' trait");
assert(other->hasTrait<OpTrait::SymbolTable>() &&
"requires target to implement the 'SymbolTable' trait");

SymbolTable targetSymbolTable(target);
SymbolTable otherSymbolTable(*other);

// Step 1:
//
// Rename private symbols in both ops in order to resolve conflicts that can
// be resolved that way.
LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
// TODO: Do we *actually* need to test in both directions?
for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
SmallVector<SymbolTable *, 2>{&otherSymbolTable,
&targetSymbolTable})) {
Operation *symbolTableOp = symbolTable->getOp();
for (Operation &op : symbolTableOp->getRegion(0).front()) {
auto symbolOp = dyn_cast<SymbolOpInterface>(op);
if (!symbolOp)
continue;
StringAttr name = symbolOp.getNameAttr();
LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");

// Check if there is a colliding op in the other module.
auto collidingOp =
cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
if (!collidingOp)
continue;

LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());

// Collisions are fine if both opt are functions and can be merged.
if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
collidingFuncOp =
dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
funcOp && collidingFuncOp) {
if (canMergeInto(funcOp, collidingFuncOp) ||
canMergeInto(collidingFuncOp, funcOp)) {
LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
"will be merged\n");
continue;
}

// If they can't be merged, proceed like any other collision.
LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
}

// Collision can be resolved by renaming if one of the ops is private.
auto renameToUnique =
[&](SymbolOpInterface op, SymbolOpInterface otherOp,
SymbolTable &symbolTable,
SymbolTable &otherSymbolTable) -> LogicalResult {
LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
FailureOr<StringAttr> maybeNewName =
symbolTable.renameToUnique(op, {&otherSymbolTable});
if (failed(maybeNewName)) {
InFlightDiagnostic diag = op->emitError("failed to rename symbol");
diag.attachNote(otherOp->getLoc())
<< "attempted renaming due to collision with this op";
return diag;
}
LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
<< "\n");
return success();
};

if (symbolOp.isPrivate()) {
if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
*otherSymbolTable)))
return failure();
continue;
}
if (collidingOp.isPrivate()) {
if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
*symbolTable)))
return failure();
continue;
}
LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
InFlightDiagnostic diag = symbolOp.emitError()
<< "doubly defined symbol @" << name.getValue();
diag.attachNote(collidingOp->getLoc()) << "previously defined here";
return diag;
}
}

// TODO: This duplicates pass infrastructure. We should split this pass into
// several and let the pass infrastructure do the verification.
for (auto *op : SmallVector<Operation *>{target, *other}) {
if (failed(mlir::verify(op)))
return op->emitError() << "failed to verify input op after renaming";
}

// Step 2:
//
// Move all ops from `other` into target and merge public symbols.
LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
{
SmallVector<SymbolOpInterface> opsToMove;
for (Operation &op : other->getRegion(0).front()) {
if (auto symbol = dyn_cast<SymbolOpInterface>(op))
opsToMove.push_back(symbol);
}

for (SymbolOpInterface op : opsToMove) {
// Remember potentially colliding op in the target module.
auto collidingOp = cast_or_null<SymbolOpInterface>(
targetSymbolTable.lookup(op.getNameAttr()));

// Move op even if we get a collision.
LLVM_DEBUG(DBGS() << " moving @" << op.getName());
op->moveBefore(&target->getRegion(0).front(),
target->getRegion(0).front().end());

// If there is no collision, we are done.
if (!collidingOp) {
LLVM_DEBUG(llvm::dbgs() << " without collision\n");
continue;
}

// The two colliding ops must both be functions because we have already
// emitted errors otherwise earlier.
auto funcOp = cast<FunctionOpInterface>(op.getOperation());
auto collidingFuncOp =
cast<FunctionOpInterface>(collidingOp.getOperation());

// Both ops are in the target module now and can be treated symmetrically,
// so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`.
if (!canMergeInto(funcOp, collidingFuncOp)) {
std::swap(funcOp, collidingFuncOp);
}
assert(canMergeInto(funcOp, collidingFuncOp));

LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
<< collidingFuncOp.getLoc() << ":\n"
<< collidingFuncOp << "\n");

// Update symbol table. This works with or without the previous `swap`.
targetSymbolTable.remove(funcOp);
targetSymbolTable.insert(collidingFuncOp);
assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);

// Do the actual merging.
if (failed(mergeInto(funcOp, collidingFuncOp))) {
return failure();
}
}
}

if (failed(mlir::verify(target)))
return target->emitError()
<< "failed to verify target op after merging symbols";

LLVM_DEBUG(DBGS() << "done merging ops\n");
return success();
}

LogicalResult transform::applyTransformNamedSequence(
Operation *payload, ModuleOp transformModule,
const TransformOptions &options, StringRef entryPoint) {
Operation *transformRoot =
detail::findTransformEntryPoint(payload, transformModule, entryPoint);
if (!transformRoot)
return failure();

// `transformModule` may not be modified.
OwningOpRef<Operation *> clonedTransformModule(transformModule->clone());
if (transformModule && !transformModule->isAncestor(transformRoot)) {
if (failed(detail::mergeSymbolsInto(
SymbolTable::getNearestSymbolTable(transformRoot),
std::move(clonedTransformModule))))
return failure();
}

// Apply the transform to the IR, do not enforce top-level constraints.
RaggedArray<MappedValue> noExtraMappings;
return applyTransforms(payload, cast<TransformOpInterface>(transformRoot),
noExtraMappings, options,
/*enforceToplevelTransformOp=*/false);
}
3 changes: 3 additions & 0 deletions mlir/unittests/Dialect/Transform/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
add_mlir_unittest(MLIRTransformDialectTests
BuildOnlyExtensionTest.cpp
Preload.cpp
)
target_link_libraries(MLIRTransformDialectTests
PRIVATE
MLIRFuncDialect
MLIRTestTransformDialect
MLIRTransformDialect
MLIRTransformDialectTransforms
)
92 changes: 92 additions & 0 deletions mlir/unittests/Dialect/Transform/Preload.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
//===- Preload.cpp - Test MlirOptMain parameterization ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/raw_ostream.h"
#include "gtest/gtest.h"

using namespace mlir;

namespace mlir {
namespace test {
std::unique_ptr<Pass> createTestTransformDialectInterpreterPass();
} // namespace test
} // namespace mlir
namespace test {
void registerTestTransformDialectExtension(DialectRegistry &registry);
} // namespace test

const static llvm::StringLiteral library = R"MLIR(
module attributes {transform.with_named_sequence} {
transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
transform.test_print_remark_at_operand %arg0, "from external symbol" : !transform.any_op
transform.yield
}
})MLIR";

const static llvm::StringLiteral input = R"MLIR(
module attributes {transform.with_named_sequence} {
transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly})
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
include @__transform_main failures(propagate) (%arg0) : (!transform.any_op) -> ()
}
})MLIR";

TEST(Preload, ContextPreloadConstructedLibrary) {
registerPassManagerCLOptions();

MLIRContext context;
auto *dialect = context.getOrLoadDialect<transform::TransformDialect>();
DialectRegistry registry;
::test::registerTestTransformDialectExtension(registry);
registry.applyExtensions(&context);
ParserConfig parserConfig(&context);

OwningOpRef<ModuleOp> inputModule =
parseSourceString<ModuleOp>(input, parserConfig, "<input>");
EXPECT_TRUE(inputModule) << "failed to parse input module";

OwningOpRef<ModuleOp> transformLibrary =
parseSourceString<ModuleOp>(library, parserConfig, "<transform-library>");
EXPECT_TRUE(transformLibrary) << "failed to parse transform module";
dialect->registerLibraryModule(std::move(transformLibrary));

ModuleOp retrievedTransformLibrary =
transform::detail::getPreloadedTransformModule(&context);
EXPECT_TRUE(retrievedTransformLibrary)
<< "failed to retrieve transform module";

transform::TransformOpInterface entryPoint =
transform::detail::findTransformEntryPoint(inputModule->getOperation(),
retrievedTransformLibrary);
EXPECT_TRUE(entryPoint) << "failed to find entry point";

OwningOpRef<Operation *> clonedTransformModule(
retrievedTransformLibrary->clone());
LogicalResult res = transform::detail::mergeSymbolsInto(
inputModule->getOperation(), std::move(clonedTransformModule));
EXPECT_TRUE(succeeded(res)) << "failed to define declared symbols";

transform::TransformOptions options;
res = transform::applyTransformNamedSequence(
inputModule->getOperation(), retrievedTransformLibrary, options);
EXPECT_TRUE(succeeded(res)) << "failed to apply named sequence";
}