diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index 3b164927d41fd..74cf53c24c4a3 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -358,9 +358,11 @@ class ModuleImport { /// Converts the callee's function type. For direct calls, it converts the /// actual function type, which may differ from the called operand type in /// variadic functions. For indirect calls, it converts the function type - /// associated with the call instruction. Returns failure when the call and - /// the callee are not compatible or when nested type conversions failed. - FailureOr convertFunctionType(llvm::CallBase *callInst); + /// associated with the call instruction. When the call and the callee are not + /// compatible (or when nested type conversions failed), emit a warning and + /// update `isIncompatibleCall` to indicate it. + FailureOr convertFunctionType(llvm::CallBase *callInst, + bool &isIncompatibleCall); /// Returns the callee name, or an empty symbol if the call is not direct. FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst); /// Converts the parameter and result attributes attached to `func` and adds diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index a07189ae1323c..c15405a3a4650 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -1612,8 +1612,8 @@ ModuleImport::convertCallOperands(llvm::CallBase *callInst, /// Checks if `callType` and `calleeType` are compatible and can be represented /// in MLIR. static LogicalResult -verifyFunctionTypeCompatibility(LLVMFunctionType callType, - LLVMFunctionType calleeType) { +checkFunctionTypeCompatibility(LLVMFunctionType callType, + LLVMFunctionType calleeType) { if (callType.getReturnType() != calleeType.getReturnType()) return failure(); @@ -1639,7 +1639,9 @@ verifyFunctionTypeCompatibility(LLVMFunctionType callType, } FailureOr -ModuleImport::convertFunctionType(llvm::CallBase *callInst) { +ModuleImport::convertFunctionType(llvm::CallBase *callInst, + bool &isIncompatibleCall) { + isIncompatibleCall = false; auto castOrFailure = [](Type convertedType) -> FailureOr { auto funcTy = dyn_cast_or_null(convertedType); if (!funcTy) @@ -1662,11 +1664,14 @@ ModuleImport::convertFunctionType(llvm::CallBase *callInst) { if (failed(calleeType)) return failure(); - // Compare the types to avoid constructing illegal call/invoke operations. - if (failed(verifyFunctionTypeCompatibility(*callType, *calleeType))) { + // Compare the types and notify users via `isIncompatibleCall` if they are not + // compatible. + if (failed(checkFunctionTypeCompatibility(*callType, *calleeType))) { + isIncompatibleCall = true; Location loc = translateLoc(callInst->getDebugLoc()); - return emitError(loc) << "incompatible call and callee types: " << *callType - << " and " << *calleeType; + emitWarning(loc) << "incompatible call and callee types: " << *callType + << " and " << *calleeType; + return callType; } return calleeType; @@ -1783,16 +1788,34 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { /*operand_attrs=*/nullptr) .getOperation(); } - FailureOr funcTy = convertFunctionType(callInst); + bool isIncompatibleCall; + FailureOr funcTy = + convertFunctionType(callInst, isIncompatibleCall); if (failed(funcTy)) return failure(); - FlatSymbolRefAttr callee = convertCalleeName(callInst); - auto callOp = builder.create(loc, *funcTy, callee, *operands); + FlatSymbolRefAttr callee = nullptr; + if (isIncompatibleCall) { + // Use an indirect call (in order to represent valid and verifiable LLVM + // IR). Build the indirect call by passing an empty `callee` operand and + // insert into `operands` to include the indirect call target. + FlatSymbolRefAttr calleeSym = convertCalleeName(callInst); + Value indirectCallVal = builder.create( + loc, LLVM::LLVMPointerType::get(context), calleeSym); + operands->insert(operands->begin(), indirectCallVal); + } else { + // Regular direct call using callee name. + callee = convertCalleeName(callInst); + } + CallOp callOp = builder.create(loc, *funcTy, callee, *operands); + if (failed(convertCallAttributes(callInst, callOp))) return failure(); - // Handle parameter and result attributes. - convertParameterAttributes(callInst, callOp, builder); + + // Handle parameter and result attributes unless it's an incompatible + // call. + if (!isIncompatibleCall) + convertParameterAttributes(callInst, callOp, builder); return callOp.getOperation(); }(); @@ -1857,12 +1880,25 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { unwindArgs))) return failure(); - FailureOr funcTy = convertFunctionType(invokeInst); + bool isIncompatibleInvoke; + FailureOr funcTy = + convertFunctionType(invokeInst, isIncompatibleInvoke); if (failed(funcTy)) return failure(); - FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst); - + FlatSymbolRefAttr calleeName = nullptr; + if (isIncompatibleInvoke) { + // Use an indirect invoke (in order to represent valid and verifiable LLVM + // IR). Build the indirect invoke by passing an empty `callee` operand and + // insert into `operands` to include the indirect invoke target. + FlatSymbolRefAttr calleeSym = convertCalleeName(invokeInst); + Value indirectInvokeVal = builder.create( + loc, LLVM::LLVMPointerType::get(context), calleeSym); + operands->insert(operands->begin(), indirectInvokeVal); + } else { + // Regular direct invoke using callee name. + calleeName = convertCalleeName(invokeInst); + } // Create the invoke operation. Normal destination block arguments will be // added later on to handle the case in which the operation result is // included in this list. @@ -1873,8 +1909,10 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { if (failed(convertInvokeAttributes(invokeInst, invokeOp))) return failure(); - // Handle parameter and result attributes. - convertParameterAttributes(invokeInst, invokeOp, builder); + // Handle parameter and result attributes unless it's an incompatible + // invoke. + if (!isIncompatibleInvoke) + convertParameterAttributes(invokeInst, invokeOp, builder); if (!invokeInst->getType()->isVoidTy()) mapValue(inst, invokeOp.getResults().front()); diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll index 4ef67b7190aab..66c8470b8782a 100644 --- a/mlir/test/Target/LLVMIR/Import/import-failure.ll +++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll @@ -362,7 +362,7 @@ target datalayout = "e-m-i64:64" ; // ----- ; CHECK: -; CHECK-SAME: incompatible call and callee types: '!llvm.func' and '!llvm.func' +; CHECK-SAME: warning: incompatible call and callee types: '!llvm.func' and '!llvm.func' define void @incompatible_call_and_callee_types() { call void @callee(i64 0) ret void @@ -373,7 +373,7 @@ declare void @callee(ptr) ; // ----- ; CHECK: -; CHECK-SAME: incompatible call and callee types: '!llvm.func' and '!llvm.func' +; CHECK-SAME: warning: incompatible call and callee types: '!llvm.func' and '!llvm.func' define void @f() personality ptr @__gxx_personality_v0 { entry: invoke void @g() to label %bb1 unwind label %bb2 diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll index c294e1b34f9bb..ef3c6430b7152 100644 --- a/mlir/test/Target/LLVMIR/Import/instructions.ll +++ b/mlir/test/Target/LLVMIR/Import/instructions.ll @@ -720,3 +720,19 @@ bb2: declare void @g(...) declare i32 @__gxx_personality_v0(...) + +; // ----- + +; CHECK-LABEL: llvm.func @incompatible_call_and_callee_types +define void @incompatible_call_and_callee_types() { + ; CHECK: %[[CST:.*]] = llvm.mlir.constant(0 : i64) : i64 + ; CHECK: %[[TARGET:.*]] = llvm.mlir.addressof @callee : !llvm.ptr + ; CHECK: llvm.call %[[TARGET]](%[[CST]]) : !llvm.ptr, (i64) -> () + call void @callee(i64 0) + ; CHECK: llvm.return + ret void +} + +define void @callee({ptr, i64}, i32) { + ret void +}