From cdfa54029105b464d72fedd0d47dedae2fbef01d Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 5 Oct 2023 16:11:00 +0000 Subject: [PATCH] [mlir][Transform] Provide a minimal set of utils that allow implementing a simple transform dialect interpreter pass --- .../Dialect/Transform/IR/TransformDialect.td | 29 +- .../Transform/IR/TransformInterfaces.h | 5 +- .../Transforms/TransformInterpreterUtils.h | 89 +++++ .../Transform/IR/TransformInterfaces.cpp | 26 +- .../Transform/Transforms/CMakeLists.txt | 1 + .../TransformInterpreterPassBase.cpp | 284 +-------------- .../Transforms/TransformInterpreterUtils.cpp | 337 ++++++++++++++++++ .../Dialect/Transform/CMakeLists.txt | 3 + mlir/unittests/Dialect/Transform/Preload.cpp | 92 +++++ 9 files changed, 577 insertions(+), 289 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h create mode 100644 mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp create mode 100644 mlir/unittests/Dialect/Transform/Preload.cpp diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td index f28205a255070..ad6804673b770 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -20,10 +20,14 @@ def Transform_Dialect : Dialect { let hasOperationAttrVerify = 1; let extraClassDeclaration = [{ + /// Symbol name for the default entry point "named sequence". + constexpr const static ::llvm::StringLiteral + kTransformEntryPointSymbolName = "__transform_main"; + /// Name of the attribute attachable to the symbol table operation /// containing named sequences. This is used to trigger verification. 
- constexpr const static ::llvm::StringLiteral kWithNamedSequenceAttrName = - "transform.with_named_sequence"; + constexpr const static ::llvm::StringLiteral + kWithNamedSequenceAttrName = "transform.with_named_sequence"; /// Name of the attribute attachable to an operation so it can be /// identified as root by the default interpreter pass. @@ -74,6 +78,22 @@ def Transform_Dialect : Dialect { using ExtensionTypePrintingHook = std::function; + /// Appends the given module as a transform symbol library available to + /// all dialect users. + void registerLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> && + library) { + libraryModules.push_back(std::move(library)); + } + + /// Returns a range of registered library modules. + auto getLibraryModules() const { + return ::llvm::map_range( + libraryModules, + [](const ::mlir::OwningOpRef<::mlir::ModuleOp> &library) { + return library.get(); + }); + } + private: /// Registers operations specified as template parameters with this /// dialect. Checks that they implement the required interfaces. @@ -132,6 +152,11 @@ def Transform_Dialect : Dialect { /// lookups when the type is fully constructed. ::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook> typePrintingHooks; + + /// Modules containing symbols, e.g. named sequences, that will be + /// resolved by the interpreter when used. 
+ ::llvm::SmallVector<::mlir::OwningOpRef<::mlir::ModuleOp>, 2> + libraryModules; }]; } diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index 0e72a93e685e3..7b37245fc3d11 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -111,7 +111,8 @@ class TransformOptions { LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray &extraMapping = {}, - const TransformOptions &options = TransformOptions()); + const TransformOptions &options = TransformOptions(), + bool enforceToplevelTransformOp = true); /// The state maintained across applications of various ops implementing the /// TransformOpInterface. The operations implementing this interface and the @@ -193,7 +194,7 @@ class TransformState { friend LogicalResult applyTransforms(Operation *, TransformOpInterface, const RaggedArray &, - const TransformOptions &); + const TransformOptions &, bool); friend TransformState detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot); diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h new file mode 100644 index 0000000000000..36c80e6fd61d3 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h @@ -0,0 +1,89 @@ +//===- TransformInterpreterUtils.h - Transform Utils ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H +#define MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H + +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include + +namespace mlir { +struct LogicalResult; +class MLIRContext; +class ModuleOp; +class Operation; +template +class OwningOpRef; +class Region; + +namespace transform { +namespace detail { +/// Utility to parse and verify the content of a `transformFileName` MLIR file +/// containing a transform dialect specification. +LogicalResult +parseTransformModuleFromFile(MLIRContext *context, + llvm::StringRef transformFileName, + OwningOpRef &transformModule); + +/// Utility to load a transform interpreter `module` from a module that has +/// already been preloaded in the context. +/// This mode is useful in cases where explicit parsing of a transform library +/// from file is expected to be prohibitively expensive. +/// In such cases, the transform module is expected to be found in the preloaded +/// library modules of the transform dialect. +/// Returns null if the module is not found. +ModuleOp getPreloadedTransformModule(MLIRContext *context); + +/// Finds the first TransformOpInterface named `kTransformEntryPointSymbolName` +/// that is either: +/// 1. nested under `root` (takes precedence). +/// 2. nested under `module`, if not found in `root`. +/// Reports errors and returns null if no such operation found. +TransformOpInterface findTransformEntryPoint( + Operation *root, ModuleOp module, + StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName); + +/// Merge all symbols from `other` into `target`. Both ops need to implement the +/// `SymbolTable` trait. 
Operations are moved from `other`, i.e., `other` may be +/// modified by this function and might not verify after the function returns. +/// Upon merging, private symbols may be renamed in order to avoid collisions in +/// the result. Public symbols may not collide, with the exception of +/// instances of `FunctionOpInterface`, where collisions are allowed if at least +/// one of the two is external, in which case the other op is preserved (or any +/// one of the two if both are external). +// TODO: Reconsider cloning individual ops rather than forcing users of the +// function to clone (or move) `other` in order to improve efficiency. +// This might primarily make sense if we can also prune the symbols that +// are merged to a subset (such as those that are actually used). +LogicalResult mergeSymbolsInto(Operation *target, + OwningOpRef other); +} // namespace detail + +/// Standalone util to apply the named sequence `entryPoint` to the payload. +/// This is done in 3 steps: +/// 1. lookup the `entryPoint` symbol in `{payload, transformModule}` by +/// calling detail::findTransformEntryPoint. +/// 2. if the entry point is found and not nested under +/// `transformModule`, call `detail::mergeSymbolsInto` to "link" in +/// the `transformModule`. Note: this may modify the transform IR +/// embedded with the payload IR. +/// 3. apply the transform IR to the payload IR, relaxing the requirement that +/// the transform IR is a top-level transform op. We are applying a named +/// sequence anyway.
+LogicalResult applyTransformNamedSequence( + Operation *payload, ModuleOp transformModule, + const TransformOptions &options, + StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName); + +} // namespace transform +} // namespace mlir + +#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index 4a9bb2dba7d66..4f88b8522e54c 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -2079,20 +2079,20 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) { // Entry point. //===----------------------------------------------------------------------===// -LogicalResult -transform::applyTransforms(Operation *payloadRoot, - TransformOpInterface transform, - const RaggedArray &extraMapping, - const TransformOptions &options) { -#ifndef NDEBUG - if (!transform->hasTrait() || - transform->getNumOperands() != 0) { - transform->emitError() - << "expected transform to start at the top-level transform op"; - llvm::report_fatal_error("could not run transforms", - /*gen_crash_diag=*/false); +LogicalResult transform::applyTransforms( + Operation *payloadRoot, TransformOpInterface transform, + const RaggedArray &extraMapping, + const TransformOptions &options, bool enforceToplevelTransformOp) { + if (enforceToplevelTransformOp) { + if (!transform->hasTrait() || + transform->getNumOperands() != 0) { + return transform->emitError() + << "expected transform to start at the top-level transform op"; + } + } else if (failed( + detail::verifyPossibleTopLevelTransformOpTrait(transform))) { + return failure(); } -#endif // NDEBUG TransformState state(transform->getParentRegion(), payloadRoot, extraMapping, options); diff --git a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt index 
3f51ef1088f7a..8774a8b86fb0d 100644 --- a/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTransformDialectTransforms CheckUses.cpp InferEffects.cpp TransformInterpreterPassBase.cpp + TransformInterpreterUtils.cpp DEPENDS MLIRTransformDialectTransformsIncGen diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index 764d7e2585420..ebfd7269f696b 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Verifier.h" #include "mlir/IR/Visitors.h" @@ -51,34 +52,6 @@ constexpr static llvm::StringLiteral kTransformDialectTagPayloadRootValue = constexpr static llvm::StringLiteral kTransformDialectTagTransformContainerValue = "transform_container"; -/// Utility to parse the content of a `transformFileName` MLIR file containing -/// a transform dialect specification. -static LogicalResult -parseTransformModuleFromFile(MLIRContext *context, - llvm::StringRef transformFileName, - OwningOpRef &transformModule) { - if (transformFileName.empty()) { - LLVM_DEBUG( - DBGS() << "no transform file name specified, assuming the transform " - "module is embedded in the IR next to the top-level\n"); - return success(); - } - // Parse transformFileName content into a ModuleOp. 
- std::string errorMessage; - auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage); - if (!memoryBuffer) { - return emitError(FileLineColLoc::get( - StringAttr::get(context, transformFileName), 0, 0)) - << "failed to open transform file: " << errorMessage; - } - // Tell sourceMgr about this buffer, the parser will pick it up. - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc()); - transformModule = - OwningOpRef(parseSourceFile(sourceMgr, context)); - return success(); -} - /// Finds the single top-level transform operation with `root` as ancestor. /// Reports an error if there is more than one such operation and returns the /// first one found. Reports an error returns nullptr if no such operation @@ -295,239 +268,6 @@ static void performOptionalDebugActions( transform->removeAttr(kTransformDialectTagAttrName); } -/// Return whether `func1` can be merged into `func2`. For that to work `func1` -/// has to be a declaration (aka has to be external) and `func2` either has to -/// be a declaration as well, or it has to be public (otherwise, it wouldn't -/// be visible by `func1`). -static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) { - return func1.isExternal() && (func2.isPublic() || func2.isExternal()); -} - -/// Merge `func1` into `func2`. The two ops must be inside the same parent op -/// and mergable according to `canMergeInto`. The function erases `func1` such -/// that only `func2` exists when the function returns. -static LogicalResult mergeInto(FunctionOpInterface func1, - FunctionOpInterface func2) { - assert(canMergeInto(func1, func2)); - assert(func1->getParentOp() == func2->getParentOp() && - "expected func1 and func2 to be in the same parent op"); - - // Check that function signatures match. 
- if (func1.getFunctionType() != func2.getFunctionType()) { - return func1.emitError() - << "external definition has a mismatching signature (" - << func2.getFunctionType() << ")"; - } - - // Check and merge argument attributes. - MLIRContext *context = func1->getContext(); - auto *td = context->getLoadedDialect(); - StringAttr consumedName = td->getConsumedAttrName(); - StringAttr readOnlyName = td->getReadOnlyAttrName(); - for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) { - bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr; - bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr; - bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr; - bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr; - if (!isExternalConsumed && !isExternalReadonly) { - if (isConsumed) - func2.setArgAttr(i, consumedName, UnitAttr::get(context)); - else if (isReadonly) - func2.setArgAttr(i, readOnlyName, UnitAttr::get(context)); - continue; - } - - if ((isExternalConsumed && !isConsumed) || - (isExternalReadonly && !isReadonly)) { - return func1.emitError() - << "external definition has mismatching consumption " - "annotations for argument #" - << i; - } - } - - // `func1` is the external one, so we can remove it. - assert(func1.isExternal()); - func1->erase(); - - return success(); -} - -/// Merge all symbols from `other` into `target`. Both ops need to implement the -/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be -/// modified by this function and might not verify after the function returns. -/// Upon merging, private symbols may be renamed in order to avoid collisions in -/// the result. Public symbols may not collide, with the exception of -/// instances of `SymbolOpInterface`, where collisions are allowed if at least -/// one of the two is external, in which case the other op preserved (or any one -/// of the two if both are external). 
-// TODO: Reconsider cloning individual ops rather than forcing users of the -// function to clone (or move) `other` in order to improve efficiency. -// This might primarily make sense if we can also prune the symbols that -// are merged to a subset (such as those that are actually used). -static LogicalResult mergeSymbolsInto(Operation *target, - OwningOpRef other) { - assert(target->hasTrait() && - "requires target to implement the 'SymbolTable' trait"); - assert(other->hasTrait() && - "requires target to implement the 'SymbolTable' trait"); - - SymbolTable targetSymbolTable(target); - SymbolTable otherSymbolTable(*other); - - // Step 1: - // - // Rename private symbols in both ops in order to resolve conflicts that can - // be resolved that way. - LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n"); - // TODO: Do we *actually* need to test in both directions? - for (auto &&[symbolTable, otherSymbolTable] : llvm::zip( - SmallVector{&targetSymbolTable, &otherSymbolTable}, - SmallVector{&otherSymbolTable, - &targetSymbolTable})) { - Operation *symbolTableOp = symbolTable->getOp(); - for (Operation &op : symbolTableOp->getRegion(0).front()) { - auto symbolOp = dyn_cast(op); - if (!symbolOp) - continue; - StringAttr name = symbolOp.getNameAttr(); - LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n"); - - // Check if there is a colliding op in the other module. - auto collidingOp = - cast_or_null(otherSymbolTable->lookup(name)); - if (!collidingOp) - continue; - - LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue()); - - // Collisions are fine if both opt are functions and can be merged. 
- if (auto funcOp = dyn_cast(op), - collidingFuncOp = - dyn_cast(collidingOp.getOperation()); - funcOp && collidingFuncOp) { - if (canMergeInto(funcOp, collidingFuncOp) || - canMergeInto(collidingFuncOp, funcOp)) { - LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and " - "will be merged\n"); - continue; - } - - // If they can't be merged, proceed like any other collision. - LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions"); - } - - // Collision can be resolved by renaming if one of the ops is private. - auto renameToUnique = - [&](SymbolOpInterface op, SymbolOpInterface otherOp, - SymbolTable &symbolTable, - SymbolTable &otherSymbolTable) -> LogicalResult { - LLVM_DEBUG(llvm::dbgs() << ", renaming\n"); - FailureOr maybeNewName = - symbolTable.renameToUnique(op, {&otherSymbolTable}); - if (failed(maybeNewName)) { - InFlightDiagnostic diag = op->emitError("failed to rename symbol"); - diag.attachNote(otherOp->getLoc()) - << "attempted renaming due to collision with this op"; - return diag; - } - LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue() - << "\n"); - return success(); - }; - - if (symbolOp.isPrivate()) { - if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable, - *otherSymbolTable))) - return failure(); - continue; - } - if (collidingOp.isPrivate()) { - if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable, - *symbolTable))) - return failure(); - continue; - } - - LLVM_DEBUG(llvm::dbgs() << ", emitting error\n"); - InFlightDiagnostic diag = symbolOp.emitError() - << "doubly defined symbol @" << name.getValue(); - diag.attachNote(collidingOp->getLoc()) << "previously defined here"; - return diag; - } - } - - // TODO: This duplicates pass infrastructure. We should split this pass into - // several and let the pass infrastructure do the verification. 
- for (auto *op : SmallVector{target, *other}) { - if (failed(mlir::verify(op))) - return op->emitError() << "failed to verify input op after renaming"; - } - - // Step 2: - // - // Move all ops from `other` into target and merge public symbols. - LLVM_DEBUG(DBGS() << "moving all symbols into target\n"); - { - SmallVector opsToMove; - for (Operation &op : other->getRegion(0).front()) { - if (auto symbol = dyn_cast(op)) - opsToMove.push_back(symbol); - } - - for (SymbolOpInterface op : opsToMove) { - // Remember potentially colliding op in the target module. - auto collidingOp = cast_or_null( - targetSymbolTable.lookup(op.getNameAttr())); - - // Move op even if we get a collision. - LLVM_DEBUG(DBGS() << " moving @" << op.getName()); - op->moveBefore(&target->getRegion(0).front(), - target->getRegion(0).front().end()); - - // If there is no collision, we are done. - if (!collidingOp) { - LLVM_DEBUG(llvm::dbgs() << " without collision\n"); - continue; - } - - // The two colliding ops must both be functions because we have already - // emitted errors otherwise earlier. - auto funcOp = cast(op.getOperation()); - auto collidingFuncOp = - cast(collidingOp.getOperation()); - - // Both ops are in the target module now and can be treated symmetrically, - // so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`. - if (!canMergeInto(funcOp, collidingFuncOp)) { - std::swap(funcOp, collidingFuncOp); - } - assert(canMergeInto(funcOp, collidingFuncOp)); - - LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at " - << collidingFuncOp.getLoc() << ":\n" - << collidingFuncOp << "\n"); - - // Update symbol table. This works with or without the previous `swap`. - targetSymbolTable.remove(funcOp); - targetSymbolTable.insert(collidingFuncOp); - assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp); - - // Do the actual merging. 
- if (failed(mergeInto(funcOp, collidingFuncOp))) { - return failure(); - } - } - } - - if (failed(mlir::verify(target))) - return target->emitError() - << "failed to verify target op after merging symbols"; - - LLVM_DEBUG(DBGS() << "done merging ops\n"); - return success(); -} - LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( Operation *target, StringRef passName, const std::shared_ptr> &sharedTransformModule, @@ -595,9 +335,9 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( diag.attachNote(target->getLoc()) << "pass anchor op"; return diag; } - if (failed( - mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot), - transformLibraryModule->get()->clone()))) + if (failed(detail::mergeSymbolsInto( + SymbolTable::getNearestSymbolTable(transformRoot), + transformLibraryModule->get()->clone()))) return emitError(transformRoot->getLoc(), "failed to merge library symbols into transform root"); } @@ -683,8 +423,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl( OwningOpRef moduleFromFile; { auto loc = FileLineColLoc::get(context, transformFileName, 0, 0); - if (failed(parseTransformModuleFromFile(context, transformFileName, - moduleFromFile))) + if (failed(detail::parseTransformModuleFromFile(context, transformFileName, + moduleFromFile))) return emitError(loc) << "failed to parse transform module"; if (moduleFromFile && failed(mlir::verify(*moduleFromFile))) return emitError(loc) << "failed to verify transform module"; @@ -701,8 +441,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl( for (const std::string &libraryFileName : libraryFileNames) { OwningOpRef parsedLibrary; auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0); - if (failed(parseTransformModuleFromFile(context, libraryFileName, - parsedLibrary))) + if (failed(detail::parseTransformModuleFromFile(context, libraryFileName, + parsedLibrary))) return emitError(loc) << "failed to parse transform library module"; if 
(parsedLibrary && failed(mlir::verify(*parsedLibrary))) return emitError(loc) << "failed to verify transform library module"; @@ -741,8 +481,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl( IRRewriter rewriter(context); // TODO: extend `mergeSymbolsInto` to support multiple `other` modules. for (OwningOpRef &parsedLibrary : parsedLibraries) { - if (failed(mergeSymbolsInto(mergedParsedLibraries.get(), - std::move(parsedLibrary)))) + if (failed(detail::mergeSymbolsInto(mergedParsedLibraries.get(), + std::move(parsedLibrary)))) return mergedParsedLibraries->emitError() << "failed to verify merged transform module"; } @@ -751,8 +491,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl( // Use parsed libaries to resolve symbols in shared transform module or return // as separate library module. if (sharedTransformModule && *sharedTransformModule) { - if (failed(mergeSymbolsInto(sharedTransformModule->get(), - std::move(mergedParsedLibraries)))) + if (failed(detail::mergeSymbolsInto(sharedTransformModule->get(), + std::move(mergedParsedLibraries)))) return (*sharedTransformModule)->emitError() << "failed to merge symbols from library files " "into shared transform module"; diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp new file mode 100644 index 0000000000000..1a6ebdd16232e --- /dev/null +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp @@ -0,0 +1,337 @@ +//===- TransformInterpreterUtils.cpp --------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Lightweight transform dialect interpreter utilities. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Support/FileUtilities.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "transform-dialect-interpreter-utils"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+LogicalResult transform::detail::parseTransformModuleFromFile(
+    MLIRContext *context, llvm::StringRef transformFileName,
+    OwningOpRef<ModuleOp> &transformModule) {
+  if (transformFileName.empty()) {
+    LLVM_DEBUG(
+        DBGS() << "no transform file name specified, assuming the transform "
+                  "module is embedded in the IR next to the top-level\n");
+    return success();
+  }
+  // Parse transformFileName content into a ModuleOp.
+  std::string errorMessage;
+  auto memoryBuffer = mlir::openInputFile(transformFileName, &errorMessage);
+  if (!memoryBuffer) {
+    return emitError(FileLineColLoc::get(
+               StringAttr::get(context, transformFileName), 0, 0))
+           << "failed to open transform file: " << errorMessage;
+  }
+  // Tell sourceMgr about this buffer, the parser will pick it up.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(memoryBuffer), llvm::SMLoc());
+  transformModule = parseSourceFile<ModuleOp>(sourceMgr, context);
+  // Parsing may fail and yield a null module; do not dereference it then.
+  return transformModule ? mlir::verify(*transformModule) : failure();
+}
+
+ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
+  auto preloadedLibraryRange =
+      context->getOrLoadDialect<transform::TransformDialect>()
+          ->getLibraryModules();
+  if (!preloadedLibraryRange.empty())
+    return *preloadedLibraryRange.begin();
+  return ModuleOp();
+}
+
+transform::TransformOpInterface
+transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
+                                           StringRef entryPoint) {
+  SmallVector<Operation *, 2> l{root};
+  if (module)
+    l.push_back(module);
+  for (Operation *op : l) {
+    transform::TransformOpInterface transform = nullptr;
+    op->walk(
+        [&](transform::NamedSequenceOp namedSequenceOp) {
+          if (namedSequenceOp.getSymName() == entryPoint) {
+            transform = cast<transform::TransformOpInterface>(
+                namedSequenceOp.getOperation());
+            return WalkResult::interrupt();
+          }
+          return WalkResult::advance();
+        });
+    if (transform)
+      return transform;
+  }
+  auto diag = root->emitError()
+              << "could not find a nested named sequence with name: "
+              << entryPoint;
+  return nullptr;
+}
+
+/// Return whether `func1` can be merged into `func2`. For that to work `func1`
+/// has to be a declaration (aka has to be external) and `func2` either has to
+/// be a declaration as well, or it has to be public (otherwise, it wouldn't
+/// be visible by `func1`).
+static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
+  return func1.isExternal() && (func2.isPublic() || func2.isExternal());
+}
+
+/// Merge `func1` into `func2`. The two ops must be inside the same parent op
+/// and mergable according to `canMergeInto`. The function erases `func1` such
+/// that only `func2` exists when the function returns.
+static LogicalResult mergeInto(FunctionOpInterface func1,
+                               FunctionOpInterface func2) {
+  assert(canMergeInto(func1, func2));
+  assert(func1->getParentOp() == func2->getParentOp() &&
+         "expected func1 and func2 to be in the same parent op");
+
+  // Check that function signatures match.
+  if (func1.getFunctionType() != func2.getFunctionType()) {
+    return func1.emitError()
+           << "external definition has a mismatching signature ("
+           << func2.getFunctionType() << ")";
+  }
+
+  // Check and merge argument attributes.
+  MLIRContext *context = func1->getContext();
+  auto *td = context->getLoadedDialect<transform::TransformDialect>();
+  StringAttr consumedName = td->getConsumedAttrName();
+  StringAttr readOnlyName = td->getReadOnlyAttrName();
+  for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
+    bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
+    bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
+    bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
+    bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
+    // `func2` carries neither annotation for this argument: adopt the one
+    // found on `func1`, if any, so that it survives erasing `func1`.
+    if (!isExternalConsumed && !isExternalReadonly) {
+      if (isConsumed)
+        func2.setArgAttr(i, consumedName, UnitAttr::get(context));
+      else if (isReadonly)
+        func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
+      continue;
+    }
+
+    // `func2` is annotated: `func1` must carry the same annotation.
+    if ((isExternalConsumed && !isConsumed) ||
+        (isExternalReadonly && !isReadonly)) {
+      return func1.emitError()
+             << "external definition has mismatching consumption "
+                "annotations for argument #"
+             << i;
+    }
+  }
+
+  // `func1` is the external one, so we can remove it.
+  assert(func1.isExternal());
+  func1->erase();
+
+  return success();
+}
+
+/// Moves every symbol op found in the first block of `other` into `target`.
+/// Both ops must have the `SymbolTable` trait. Name collisions are resolved
+/// in two steps: first, private symbols are renamed to unique names; then,
+/// colliding functions that satisfy `canMergeInto` are merged via `mergeInto`.
+/// Any remaining collision is reported as a "doubly defined symbol" error.
+/// Both ops are verified after the renaming step and `target` is verified
+/// again after the merge. Returns failure if renaming, merging or verification
+/// fails.
+LogicalResult
+transform::detail::mergeSymbolsInto(Operation *target,
+                                    OwningOpRef<Operation *> other) {
+  assert(target->hasTrait<OpTrait::SymbolTable>() &&
+         "requires target to implement the 'SymbolTable' trait");
+  assert(other->hasTrait<OpTrait::SymbolTable>() &&
+         "requires other to implement the 'SymbolTable' trait");
+
+  SymbolTable targetSymbolTable(target);
+  SymbolTable otherSymbolTable(*other);
+
+  // Step 1:
+  //
+  // Rename private symbols in both ops in order to resolve conflicts that can
+  // be resolved that way.
+  LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
+  // TODO: Do we *actually* need to test in both directions?
+  for (auto &&[symbolTable, otherSymbolTable] : llvm::zip(
+           SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable},
+           SmallVector<SymbolTable *, 2>{&otherSymbolTable,
+                                         &targetSymbolTable})) {
+    Operation *symbolTableOp = symbolTable->getOp();
+    for (Operation &op : symbolTableOp->getRegion(0).front()) {
+      auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+      if (!symbolOp)
+        continue;
+      StringAttr name = symbolOp.getNameAttr();
+      LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");
+
+      // Check if there is a colliding op in the other module.
+      auto collidingOp =
+          cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
+      if (!collidingOp)
+        continue;
+
+      LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());
+
+      // Collisions are fine if both ops are functions and can be merged.
+      if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
+          collidingFuncOp =
+              dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
+          funcOp && collidingFuncOp) {
+        if (canMergeInto(funcOp, collidingFuncOp) ||
+            canMergeInto(collidingFuncOp, funcOp)) {
+          LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
+                                     "will be merged\n");
+          continue;
+        }
+
+        // If they can't be merged, proceed like any other collision.
+        LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
+      }
+
+      // Collision can be resolved by renaming if one of the ops is private.
+      auto renameToUnique =
+          [&](SymbolOpInterface op, SymbolOpInterface otherOp,
+              SymbolTable &symbolTable,
+              SymbolTable &otherSymbolTable) -> LogicalResult {
+        LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
+        FailureOr<StringAttr> maybeNewName =
+            symbolTable.renameToUnique(op, {&otherSymbolTable});
+        if (failed(maybeNewName)) {
+          InFlightDiagnostic diag = op->emitError("failed to rename symbol");
+          diag.attachNote(otherOp->getLoc())
+              << "attempted renaming due to collision with this op";
+          return diag;
+        }
+        LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
+                          << "\n");
+        return success();
+      };
+
+      if (symbolOp.isPrivate()) {
+        if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
+                                  *otherSymbolTable)))
+          return failure();
+        continue;
+      }
+      if (collidingOp.isPrivate()) {
+        if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
+                                  *symbolTable)))
+          return failure();
+        continue;
+      }
+      LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
+      InFlightDiagnostic diag = symbolOp.emitError()
+                                << "doubly defined symbol @" << name.getValue();
+      diag.attachNote(collidingOp->getLoc()) << "previously defined here";
+      return diag;
+    }
+  }
+
+  // TODO: This duplicates pass infrastructure. We should split this pass into
+  // several and let the pass infrastructure do the verification.
+  for (auto *op : SmallVector<Operation *>{target, *other}) {
+    if (failed(mlir::verify(op)))
+      return op->emitError() << "failed to verify input op after renaming";
+  }
+
+  // Step 2:
+  //
+  // Move all ops from `other` into target and merge public symbols.
+  LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
+  {
+    // Collect first: moving ops while iterating over the block they live in
+    // would invalidate the iteration.
+    SmallVector<SymbolOpInterface> opsToMove;
+    for (Operation &op : other->getRegion(0).front()) {
+      if (auto symbol = dyn_cast<SymbolOpInterface>(op))
+        opsToMove.push_back(symbol);
+    }
+
+    for (SymbolOpInterface op : opsToMove) {
+      // Remember potentially colliding op in the target module.
+      auto collidingOp = cast_or_null<SymbolOpInterface>(
+          targetSymbolTable.lookup(op.getNameAttr()));
+
+      // Move op even if we get a collision.
+      LLVM_DEBUG(DBGS() << " moving @" << op.getName());
+      op->moveBefore(&target->getRegion(0).front(),
+                     target->getRegion(0).front().end());
+
+      // If there is no collision, we are done.
+      if (!collidingOp) {
+        LLVM_DEBUG(llvm::dbgs() << " without collision\n");
+        continue;
+      }
+
+      // The two colliding ops must both be functions because we have already
+      // emitted errors otherwise earlier.
+      auto funcOp = cast<FunctionOpInterface>(op.getOperation());
+      auto collidingFuncOp =
+          cast<FunctionOpInterface>(collidingOp.getOperation());
+
+      // Both ops are in the target module now and can be treated symmetrically,
+      // so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`.
+      if (!canMergeInto(funcOp, collidingFuncOp)) {
+        std::swap(funcOp, collidingFuncOp);
+      }
+      assert(canMergeInto(funcOp, collidingFuncOp));
+
+      LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
+                              << collidingFuncOp.getLoc() << ":\n"
+                              << collidingFuncOp << "\n");
+
+      // Update symbol table. This works with or without the previous `swap`.
+      targetSymbolTable.remove(funcOp);
+      targetSymbolTable.insert(collidingFuncOp);
+      assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
+
+      // Do the actual merging.
+      if (failed(mergeInto(funcOp, collidingFuncOp))) {
+        return failure();
+      }
+    }
+  }
+
+  if (failed(mlir::verify(target)))
+    return target->emitError()
+           << "failed to verify target op after merging symbols";
+
+  LLVM_DEBUG(DBGS() << "done merging ops\n");
+  return success();
+}
+
+/// Looks up the transform entry point named `entryPoint` through
+/// `detail::findTransformEntryPoint` and applies it to `payload`.
+/// `transformModule` may be null. When it is non-null and does not contain the
+/// entry point, the symbols of a clone of it are merged into the symbol table
+/// nearest to the entry point before the transform is applied. Returns failure
+/// if no entry point is found, if merging fails, or if applying the transform
+/// fails.
+LogicalResult transform::applyTransformNamedSequence(
+    Operation *payload, ModuleOp transformModule,
+    const TransformOptions &options, StringRef entryPoint) {
+  Operation *transformRoot =
+      detail::findTransformEntryPoint(payload, transformModule, entryPoint);
+  if (!transformRoot)
+    return failure();
+
+  // `transformModule` may not be modified, so symbols are merged from a clone.
+  // The clone is created only after the null check, and only when a merge is
+  // actually required.
+  if (transformModule && !transformModule->isAncestor(transformRoot)) {
+    OwningOpRef<Operation *> clonedTransformModule(transformModule->clone());
+    if (failed(detail::mergeSymbolsInto(
+            SymbolTable::getNearestSymbolTable(transformRoot),
+            std::move(clonedTransformModule))))
+      return failure();
+  }
+
+  // Apply the transform to the IR, do not enforce top-level constraints.
+  RaggedArray<MappedValue> noExtraMappings;
+  return applyTransforms(payload, cast<TransformOpInterface>(transformRoot),
+                         noExtraMappings, options,
+                         /*enforceToplevelTransformOp=*/false);
+}
diff --git a/mlir/unittests/Dialect/Transform/CMakeLists.txt b/mlir/unittests/Dialect/Transform/CMakeLists.txt
index 1fecd21221c91..89238a0bdae16 100644
--- a/mlir/unittests/Dialect/Transform/CMakeLists.txt
+++ b/mlir/unittests/Dialect/Transform/CMakeLists.txt
@@ -1,8 +1,11 @@
 add_mlir_unittest(MLIRTransformDialectTests
   BuildOnlyExtensionTest.cpp
+  Preload.cpp
 )
 target_link_libraries(MLIRTransformDialectTests
   PRIVATE
   MLIRFuncDialect
+  MLIRTestTransformDialect
   MLIRTransformDialect
+  MLIRTransformDialectTransforms
 )
diff --git a/mlir/unittests/Dialect/Transform/Preload.cpp b/mlir/unittests/Dialect/Transform/Preload.cpp
new file mode 100644
index 0000000000000..d3c3044e0e0f7
--- /dev/null
+++ b/mlir/unittests/Dialect/Transform/Preload.cpp
@@ -0,0 +1,92 @@
+//===- Preload.cpp - Test transform library preloading --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/TypeID.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace test {
+std::unique_ptr<Pass> createTestTransformDialectInterpreterPass();
+} // namespace test
+} // namespace mlir
+namespace test {
+void registerTestTransformDialectExtension(DialectRegistry &registry);
+} // namespace test
+
+// Transform library: provides the definition of `@__transform_main`.
+const static llvm::StringLiteral library = R"MLIR(
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "from external symbol" : !transform.any_op
+    transform.yield
+  }
+})MLIR";
+
+// Input module: only declares `@__transform_main` and includes it from a
+// top-level sequence.
+const static llvm::StringLiteral input = R"MLIR(
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly})
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    include @__transform_main failures(propagate) (%arg0) : (!transform.any_op) -> ()
+  }
+})MLIR";
+
+// Registers a parsed library module with the transform dialect, retrieves it
+// through `getPreloadedTransformModule`, then checks that the entry point can
+// be found, that the library symbols can be merged into the input module, and
+// that the named sequence can be applied.
+TEST(Preload, ContextPreloadConstructedLibrary) {
+  registerPassManagerCLOptions();
+
+  MLIRContext context;
+  auto *dialect = context.getOrLoadDialect<transform::TransformDialect>();
+  DialectRegistry registry;
+  ::test::registerTestTransformDialectExtension(registry);
+  registry.applyExtensions(&context);
+  ParserConfig parserConfig(&context);
+
+  // ASSERT rather than EXPECT wherever the checked value is used afterwards:
+  // continuing with a null module would crash instead of failing the test.
+  OwningOpRef<ModuleOp> inputModule =
+      parseSourceString<ModuleOp>(input, parserConfig, "<input>");
+  ASSERT_TRUE(inputModule) << "failed to parse input module";
+
+  OwningOpRef<ModuleOp> transformLibrary =
+      parseSourceString<ModuleOp>(library, parserConfig, "<transform-library>");
+  ASSERT_TRUE(transformLibrary) << "failed to parse transform module";
+  dialect->registerLibraryModule(std::move(transformLibrary));
+
+  ModuleOp retrievedTransformLibrary =
+      transform::detail::getPreloadedTransformModule(&context);
+  ASSERT_TRUE(retrievedTransformLibrary)
+      << "failed to retrieve transform module";
+
+  transform::TransformOpInterface entryPoint =
+      transform::detail::findTransformEntryPoint(inputModule->getOperation(),
+                                                 retrievedTransformLibrary);
+  EXPECT_TRUE(entryPoint) << "failed to find entry point";
+
+  // Merge from a clone so that the registered library itself is left intact.
+  OwningOpRef<Operation *> clonedTransformModule(
+      retrievedTransformLibrary->clone());
+  LogicalResult res = transform::detail::mergeSymbolsInto(
+      inputModule->getOperation(), std::move(clonedTransformModule));
+  EXPECT_TRUE(succeeded(res)) << "failed to define declared symbols";
+
+  transform::TransformOptions options;
+  res = transform::applyTransformNamedSequence(
+      inputModule->getOperation(), retrievedTransformLibrary, options);
+  EXPECT_TRUE(succeeded(res)) << "failed to apply named sequence";
+}