Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clang/include/clang/Frontend/FrontendOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ class FrontendOptions {
std::string ClangIRIdiomRecognizerOpts;
std::string ClangIRLibOptOpts;

frontend::MLIRDialectKind MLIRTargetDialect;
frontend::MLIRDialectKind MLIRTargetDialect = frontend::MLIR_CORE;

/// The input kind, either specified via -x argument or deduced from the input
/// file name.
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CIR/Dialect/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_clang_library(MLIRCIR
CIRMemorySlot.cpp
CIRTypes.cpp
FPEnv.cpp
CIRLinkerInterface.cpp

DEPENDS
MLIRBuiltinLocationAttributesIncGen
Expand Down
1 change: 0 additions & 1 deletion clang/lib/CIR/Interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
add_clang_library(MLIRCIRInterfaces
ASTAttrInterfaces.cpp
CIROpInterfaces.cpp
CIRLinkerInterface.cpp
CIRLoopOpInterface.cpp
CIRFPTypeInterface.cpp

Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ SmallVector<IntT> convertArrayToIndices(ArrayAttr attrs) {
return convertArrayToIndices<IntT>(attrs.getValue());
}


/// Register the `LLVMLinkerInterface` implementation of `LinkerInterface`
/// within the LLVM dialect.
void registerLinkerInterface(DialectRegistry &registry);
Expand Down
1 change: 0 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#define MLIR_DIALECT_LLVMIR_LLVMINTERFACES_H_

#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"

namespace mlir {
namespace LLVM {
namespace detail {
Expand Down
26 changes: 26 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMLinkerInterface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef MLIR_DIALECT_LLVMIR_LLVMLINKERINTERFACE_H
#define MLIR_DIALECT_LLVMIR_LLVMLINKERINTERFACE_H

#include "mlir/Linker/LLVMLinkerMixin.h"
namespace mlir {
namespace LLVM {

class LLVMSymbolLinkerInterface
: public link::SymbolAttrLLVMLinkerInterface<LLVMSymbolLinkerInterface> {
public:
LLVMSymbolLinkerInterface(Dialect *dialect);

bool canBeLinked(Operation *op) const override;
static Linkage getLinkage(Operation *op);
static Visibility getVisibility(Operation *op);
static void setVisibility(Operation *op, Visibility visibility);
static bool isDeclaration(Operation *op);
static unsigned getBitWidth(Operation *op);
static UnnamedAddr getUnnamedAddr(Operation *op);
static void setUnnamedAddr(Operation *op, UnnamedAddr val);
};

} // namespace LLVM
} // namespace mlir

#endif // MLIR_DIALECT_LLVMIR_LLVMLINKERINTERFACE_H
37 changes: 8 additions & 29 deletions mlir/include/mlir/Linker/LLVMLinkerMixin.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,6 @@ static UnnamedAddr getMinUnnamedAddr(UnnamedAddr lhs, UnnamedAddr rhs) {
// LLVMLinkerMixin
//===----------------------------------------------------------------------===//

enum class ConflictResolution {
LinkFromSrc,
LinkFromDst,
LinkFromBothAndRenameDst,
LinkFromBothAndRenameSrc,
};

template <typename DerivedLinkerInterface>
class LLVMLinkerMixin {
const DerivedLinkerInterface &getDerived() const {
Expand Down Expand Up @@ -198,7 +191,7 @@ class LLVMLinkerMixin {
isAvailableExternallyLinkage(srcLinkage));
}

LogicalResult verifyLinkageCompatibility(Conflict pair) {
LogicalResult verifyLinkageCompatibility(Conflict pair) const{
const DerivedLinkerInterface &derived = getDerived();
assert(derived.canBeLinked(pair.src) && "expected linkable operation");
assert(derived.canBeLinked(pair.dst) && "expected linkable operation");
Expand All @@ -221,7 +214,7 @@ class LLVMLinkerMixin {
return success();
}

ConflictResolution resolveConflict(Conflict pair) {
ConflictResolution getConflictResolution(Conflict pair) const {
const DerivedLinkerInterface &derived = getDerived();
assert(derived.canBeLinked(pair.src) && "expected linkable operation");
assert(derived.canBeLinked(pair.dst) && "expected linkable operation");
Expand Down Expand Up @@ -318,28 +311,14 @@ class SymbolAttrLLVMLinkerInterface
return LinkerMixin::isLinkNeeded(pair, forDependency);
}

LogicalResult resolveConflict(Conflict pair) override {
if (failed(LinkerMixin::verifyLinkageCompatibility(pair)))
return failure();
ConflictResolution resolution = LinkerMixin::resolveConflict(pair);

switch (resolution) {
case ConflictResolution::LinkFromSrc:
registerForLink(pair.src);
return success();
case ConflictResolution::LinkFromDst:
return success();
case ConflictResolution::LinkFromBothAndRenameDst:
uniqued.insert(pair.dst);
registerForLink(pair.src);
return success();
case ConflictResolution::LinkFromBothAndRenameSrc:
uniqued.insert(pair.src);
return success();
}
LogicalResult verifyLinkageCompatibility(Conflict pair) const override {
return LinkerMixin::verifyLinkageCompatibility(pair);
}

llvm_unreachable("unimplemented conflict resolution");
ConflictResolution getConflictResolution(Conflict pair) const override {
return LinkerMixin::getConflictResolution(pair);
}

};

} // namespace mlir::link
Expand Down
21 changes: 21 additions & 0 deletions mlir/include/mlir/Linker/LinkerInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/IRMapping.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Error.h"

namespace mlir::link {

Expand All @@ -43,6 +44,10 @@ class LinkState {

Operation *remapped(Operation *src) const;

LinkState nest(ModuleOp submod) const;

void updateState(const LinkState &substate);

private:
IRMapping mapping;
OpBuilder builder;
Expand Down Expand Up @@ -136,6 +141,12 @@ class SymbolLinkerInterface : public LinkerInterface<SymbolLinkerInterface> {
//===----------------------------------------------------------------------===//
// SymbolAttrLinkerInterface
//===----------------------------------------------------------------------===//
enum class ConflictResolution {
LinkFromSrc,
LinkFromDst,
LinkFromBothAndRenameDst,
LinkFromBothAndRenameSrc,
};

class SymbolAttrLinkerInterface : public SymbolLinkerInterface {
public:
Expand All @@ -153,6 +164,16 @@ class SymbolAttrLinkerInterface : public SymbolLinkerInterface {
/// Records a non-conflicting operation for linking.
void registerForLink(Operation *op) override;

/// Resolves a conflict between an existing operation and a new one.
LogicalResult resolveConflict(Conflict pair) override;

virtual LogicalResult resolveConflict(Conflict pair, ConflictResolution resolution);

/// Gets the conflict resolution for a given conflict
virtual ConflictResolution getConflictResolution(Conflict pair) const = 0;

virtual LogicalResult verifyLinkageCompatibility(Conflict pair) const = 0;

/// Dependencies of the given operation required to be linked.
SmallVector<Operation *> dependencies(Operation *op) const override;

Expand Down
124 changes: 62 additions & 62 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMLinkerInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,88 +13,88 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Linker/LLVMLinkerMixin.h"
#include "mlir/Linker/LinkerInterface.h"

#include "mlir/Dialect/LLVMIR/LLVMLinkerInterface.h"
using namespace mlir;
using namespace mlir::link;

//===----------------------------------------------------------------------===//
// LLVMSymbolLinkerInterface
//===----------------------------------------------------------------------===//

class LLVMSymbolLinkerInterface
: public SymbolAttrLLVMLinkerInterface<LLVMSymbolLinkerInterface> {
public:
LLVMSymbolLinkerInterface(Dialect *dialect)
: SymbolAttrLLVMLinkerInterface(dialect) {}

bool canBeLinked(Operation *op) const override {
return isa<LLVM::GlobalOp>(op) || isa<LLVM::LLVMFuncOp>(op);
}


mlir::LLVM::LLVMSymbolLinkerInterface::LLVMSymbolLinkerInterface(Dialect *dialect)
: SymbolAttrLLVMLinkerInterface(dialect) {}

bool mlir::LLVM::LLVMSymbolLinkerInterface::canBeLinked(Operation *op) const {
return isa<LLVM::GlobalOp>(op) || isa<LLVM::LLVMFuncOp>(op);
}

//===--------------------------------------------------------------------===//
// LLVMLinkerMixin required methods from derived linker interface
//===--------------------------------------------------------------------===//

static Linkage getLinkage(Operation *op) {
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
return gv.getLinkage();
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
return fn.getLinkage();
llvm_unreachable("unexpected operation");
}
Linkage mlir::LLVM::LLVMSymbolLinkerInterface::getLinkage(Operation *op) {
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
return gv.getLinkage();
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
return fn.getLinkage();
llvm_unreachable("unexpected operation");
}

static Visibility getVisibility(Operation *op) {
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
return gv.getVisibility_();
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
return fn.getVisibility_();
llvm_unreachable("unexpected operation");
}
Visibility mlir::LLVM::LLVMSymbolLinkerInterface::getVisibility(Operation *op) {
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
return gv.getVisibility_();
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
return fn.getVisibility_();
llvm_unreachable("unexpected operation");
}

static void setVisibility(Operation *op, Visibility visibility) {
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
return gv.setVisibility_(visibility);
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
return fn.setVisibility_(visibility);
llvm_unreachable("unexpected operation");
}
void mlir::LLVM::LLVMSymbolLinkerInterface::setVisibility(Operation *op, Visibility visibility) {
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
return gv.setVisibility_(visibility);
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
return fn.setVisibility_(visibility);
llvm_unreachable("unexpected operation");
}

// Return true if the primary definition of this global value is outside of
// the current translation unit.
static bool isDeclaration(Operation *op) {
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
return gv.getInitializerRegion().empty() && !gv.getValue();
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
return fn.getBody().empty();
llvm_unreachable("unexpected operation");
}
// Return true if the primary definition of this global value is outside of
// the current translation unit.
bool mlir::LLVM::LLVMSymbolLinkerInterface::isDeclaration(Operation *op) {
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
return gv.getInitializerRegion().empty() && !gv.getValue();
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
return fn.getBody().empty();
llvm_unreachable("unexpected operation");
}

static unsigned getBitWidth(Operation *op) {
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
return gv.getType().getIntOrFloatBitWidth();
llvm_unreachable("unexpected operation");
}
unsigned mlir::LLVM::LLVMSymbolLinkerInterface::getBitWidth(Operation *op) {
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
return gv.getType().getIntOrFloatBitWidth();
llvm_unreachable("unexpected operation");
}

static UnnamedAddr getUnnamedAddr(Operation *op) {
if (auto gv = dyn_cast<LLVM::GlobalOp>(op)) {
auto addr = gv.getUnnamedAddr();
return addr ? *addr : UnnamedAddr::Global;
}
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op)) {
auto addr = fn.getUnnamedAddr();
return addr ? *addr : UnnamedAddr::Global;
}
llvm_unreachable("unexpected operation");
UnnamedAddr mlir::LLVM::LLVMSymbolLinkerInterface::getUnnamedAddr(Operation *op) {
if (auto gv = dyn_cast<LLVM::GlobalOp>(op)) {
auto addr = gv.getUnnamedAddr();
return addr ? *addr : UnnamedAddr::Global;
}

static void setUnnamedAddr(Operation *op, UnnamedAddr val) {
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
return gv.setUnnamedAddr(val);
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
return fn.setUnnamedAddr(val);
llvm_unreachable("unexpected operation");
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op)) {
auto addr = fn.getUnnamedAddr();
return addr ? *addr : UnnamedAddr::Global;
}
};
llvm_unreachable("unexpected operation");
}

void mlir::LLVM::LLVMSymbolLinkerInterface::setUnnamedAddr(Operation *op, UnnamedAddr val) {
if (auto gv = dyn_cast<LLVM::GlobalOp>(op))
return gv.setUnnamedAddr(val);
if (auto fn = dyn_cast<LLVM::LLVMFuncOp>(op))
return fn.setUnnamedAddr(val);
llvm_unreachable("unexpected operation");
}


//===----------------------------------------------------------------------===//
// registerLinkerInterface
Expand Down
42 changes: 42 additions & 0 deletions mlir/lib/Linker/LinkerInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,19 @@ Operation *LinkState::remapped(Operation *src) const {
return mapping.lookupOrNull(src);
}

LinkState LinkState::nest(ModuleOp submod) const {
assert(submod->getParentOfType<mlir::ModuleOp>().getOperation() ==
getDestinationOp() &&
"Submodule should be directly nested in the current state");
LinkState submodState(submod);
submodState.mapping = mapping;
return submodState;
}

void LinkState::updateState(const LinkState &substate) {
mapping = substate.mapping;
}

//===----------------------------------------------------------------------===//
// SymbolAttrLinkerInterface
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -167,3 +180,32 @@ SymbolAttrLinkerInterface::dependencies(Operation *op) const {

return result;
}

LogicalResult
SymbolAttrLinkerInterface::resolveConflict(Conflict pair,
ConflictResolution resolution) {

switch (resolution) {
case ConflictResolution::LinkFromSrc:
registerForLink(pair.src);
return success();
case ConflictResolution::LinkFromDst:
return success();
case ConflictResolution::LinkFromBothAndRenameDst:
uniqued.insert(pair.dst);
registerForLink(pair.src);
return success();
case ConflictResolution::LinkFromBothAndRenameSrc:
uniqued.insert(pair.src);
return success();
}

llvm_unreachable("unimplemented conflict resolution");
}

LogicalResult SymbolAttrLinkerInterface::resolveConflict(Conflict pair) {
if (failed(this->verifyLinkageCompatibility(pair)))
return failure();
ConflictResolution resolution = this->getConflictResolution(pair);
return resolveConflict(pair, resolution);
}
Loading