Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,11 @@ class ModuleImport {
/// Converts the callee's function type. For direct calls, it converts the
/// actual function type, which may differ from the called operand type in
/// variadic functions. For indirect calls, it converts the function type
/// associated with the call instruction. Returns failure when the call and
/// the callee are not compatible or when nested type conversions failed.
FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst);
/// associated with the call instruction. When the call and the callee are not
/// compatible (or when nested type conversions failed), emit a warning and
/// update `isIncompatibleCall` to indicate it.
FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst,
bool &isIncompatibleCall);
/// Returns the callee name, or an empty symbol if the call is not direct.
FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
/// Converts the parameter and result attributes attached to `func` and adds
Expand Down
72 changes: 55 additions & 17 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1612,8 +1612,8 @@ ModuleImport::convertCallOperands(llvm::CallBase *callInst,
/// Checks if `callType` and `calleeType` are compatible and can be represented
/// in MLIR.
static LogicalResult
verifyFunctionTypeCompatibility(LLVMFunctionType callType,
LLVMFunctionType calleeType) {
checkFunctionTypeCompatibility(LLVMFunctionType callType,
LLVMFunctionType calleeType) {
if (callType.getReturnType() != calleeType.getReturnType())
return failure();

Expand All @@ -1639,7 +1639,9 @@ verifyFunctionTypeCompatibility(LLVMFunctionType callType,
}

FailureOr<LLVMFunctionType>
ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
ModuleImport::convertFunctionType(llvm::CallBase *callInst,
bool &isIncompatibleCall) {
isIncompatibleCall = false;
auto castOrFailure = [](Type convertedType) -> FailureOr<LLVMFunctionType> {
auto funcTy = dyn_cast_or_null<LLVMFunctionType>(convertedType);
if (!funcTy)
Expand All @@ -1662,11 +1664,14 @@ ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
if (failed(calleeType))
return failure();

// Compare the types to avoid constructing illegal call/invoke operations.
if (failed(verifyFunctionTypeCompatibility(*callType, *calleeType))) {
// Compare the types and notify users via `isIncompatibleCall` if they are not
// compatible.
if (failed(checkFunctionTypeCompatibility(*callType, *calleeType))) {
isIncompatibleCall = true;
Location loc = translateLoc(callInst->getDebugLoc());
return emitError(loc) << "incompatible call and callee types: " << *callType
<< " and " << *calleeType;
emitWarning(loc) << "incompatible call and callee types: " << *callType
<< " and " << *calleeType;
return callType;
}

return calleeType;
Expand Down Expand Up @@ -1783,16 +1788,34 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
/*operand_attrs=*/nullptr)
.getOperation();
}
FailureOr<LLVMFunctionType> funcTy = convertFunctionType(callInst);
bool isIncompatibleCall;
FailureOr<LLVMFunctionType> funcTy =
convertFunctionType(callInst, isIncompatibleCall);
if (failed(funcTy))
return failure();

FlatSymbolRefAttr callee = convertCalleeName(callInst);
auto callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
FlatSymbolRefAttr callee = nullptr;
if (isIncompatibleCall) {
// Use an indirect call (in order to represent valid and verifiable LLVM
// IR). Build the indirect call by passing an empty `callee` operand and
// insert into `operands` to include the indirect call target.
FlatSymbolRefAttr calleeSym = convertCalleeName(callInst);
Value indirectCallVal = builder.create<LLVM::AddressOfOp>(
loc, LLVM::LLVMPointerType::get(context), calleeSym);
operands->insert(operands->begin(), indirectCallVal);
} else {
// Regular direct call using callee name.
callee = convertCalleeName(callInst);
}
CallOp callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);

if (failed(convertCallAttributes(callInst, callOp)))
return failure();
// Handle parameter and result attributes.
convertParameterAttributes(callInst, callOp, builder);

// Handle parameter and result attributes unless it's an incompatible
// call.
if (!isIncompatibleCall)
convertParameterAttributes(callInst, callOp, builder);
return callOp.getOperation();
}();

Expand Down Expand Up @@ -1857,12 +1880,25 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
unwindArgs)))
return failure();

FailureOr<LLVMFunctionType> funcTy = convertFunctionType(invokeInst);
bool isIncompatibleInvoke;
FailureOr<LLVMFunctionType> funcTy =
convertFunctionType(invokeInst, isIncompatibleInvoke);
if (failed(funcTy))
return failure();

FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst);

FlatSymbolRefAttr calleeName = nullptr;
if (isIncompatibleInvoke) {
// Use an indirect invoke (in order to represent valid and verifiable LLVM
// IR). Build the indirect invoke by passing an empty `callee` operand and
// insert into `operands` to include the indirect invoke target.
FlatSymbolRefAttr calleeSym = convertCalleeName(invokeInst);
Value indirectInvokeVal = builder.create<LLVM::AddressOfOp>(
loc, LLVM::LLVMPointerType::get(context), calleeSym);
operands->insert(operands->begin(), indirectInvokeVal);
} else {
// Regular direct invoke using callee name.
calleeName = convertCalleeName(invokeInst);
}
// Create the invoke operation. Normal destination block arguments will be
// added later on to handle the case in which the operation result is
// included in this list.
Expand All @@ -1873,8 +1909,10 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
return failure();

// Handle parameter and result attributes.
convertParameterAttributes(invokeInst, invokeOp, builder);
// Handle parameter and result attributes unless it's an incompatible
// invoke.
if (!isIncompatibleInvoke)
convertParameterAttributes(invokeInst, invokeOp, builder);

if (!invokeInst->getType()->isVoidTy())
mapValue(inst, invokeOp.getResults().front());
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Target/LLVMIR/Import/import-failure.ll
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ target datalayout = "e-m-i64:64"
; // -----

; CHECK: <unknown>
; CHECK-SAME: incompatible call and callee types: '!llvm.func<void (i64)>' and '!llvm.func<void (ptr)>'
; CHECK-SAME: warning: incompatible call and callee types: '!llvm.func<void (i64)>' and '!llvm.func<void (ptr)>'
define void @incompatible_call_and_callee_types() {
call void @callee(i64 0)
ret void
Expand All @@ -373,7 +373,7 @@ declare void @callee(ptr)
; // -----

; CHECK: <unknown>
; CHECK-SAME: incompatible call and callee types: '!llvm.func<void ()>' and '!llvm.func<i32 ()>'
; CHECK-SAME: warning: incompatible call and callee types: '!llvm.func<void ()>' and '!llvm.func<i32 ()>'
define void @f() personality ptr @__gxx_personality_v0 {
entry:
invoke void @g() to label %bb1 unwind label %bb2
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Target/LLVMIR/Import/instructions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -720,3 +720,19 @@ bb2:
declare void @g(...)

declare i32 @__gxx_personality_v0(...)

; // -----

; CHECK-LABEL: llvm.func @incompatible_call_and_callee_types
define void @incompatible_call_and_callee_types() {
; CHECK: %[[CST:.*]] = llvm.mlir.constant(0 : i64) : i64
; CHECK: %[[TARGET:.*]] = llvm.mlir.addressof @callee : !llvm.ptr
; CHECK: llvm.call %[[TARGET]](%[[CST]]) : !llvm.ptr, (i64) -> ()
call void @callee(i64 0)
; CHECK: llvm.return
ret void
}

define void @callee({ptr, i64}, i32) {
ret void
}