From 388fe163fdfbd4eff20040995101dac6af64ab07 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 14 Apr 2024 14:49:28 -0400 Subject: [PATCH 1/2] fix fwd test case --- compiler/rustc_codegen_llvm/src/back/write.rs | 4 +++- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 1 + compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp | 4 ++++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index a362f1640c2e0..26cea03820444 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -860,9 +860,11 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, LLVMRustEraseInstBefore(bb, last_inst); let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src)); + let t_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(tgt)); + let t_is_struct: bool = llvm::LLVMRustIsStructType(t_return_type); let void_type = LLVMVoidTypeInContext(llcx); // Now unwrap the struct_ret if it's actually a struct - if f_return_type != void_type { + if t_is_struct && f_return_type != void_type { let num_elem_in_ret_struct = LLVMCountStructElementTypes(f_return_type); if num_elem_in_ret_struct == 1 { let inner_grad_name = "foo".to_string(); diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 69049ca752d26..5b94b80502109 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1035,6 +1035,7 @@ extern "C" { pub fn LLVMRustEraseInstFromParent(V: &Value); pub fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value; pub fn LLVMGetReturnType(T: &Type) -> &Type; + pub fn LLVMRustIsStructType(T: &Type) -> bool; pub fn LLVMDumpModule(M: &Module); pub fn LLVMCountStructElementTypes(T: &Type) -> c_uint; pub fn LLVMDeleteFunction(V: &Value); diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 548040579b392..078c8918939b0 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -300,6 +300,10 @@ extern "C" void LLVMRustAddFunctionAttributes(LLVMValueRef Fn, unsigned Index, AddAttributes(F, Index, Attrs, AttrsLen); } +extern "C" bool LLVMRustIsStructType(LLVMTypeRef Ty) { + return unwrap(Ty)->isStructTy(); +} + extern "C" void LLVMRustAddCallSiteAttributes(LLVMValueRef Instr, unsigned Index, LLVMAttributeRef *Attrs, From 84a180b653157fb5c015988b523a34b8af26936d Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 14 Apr 2024 15:04:37 -0400 Subject: [PATCH 2/2] simplify --- compiler/rustc_codegen_llvm/src/back/write.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 26cea03820444..657db58831c1b 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -860,11 +860,10 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, LLVMRustEraseInstBefore(bb, last_inst); let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src)); - let t_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(tgt)); - let t_is_struct: bool = llvm::LLVMRustIsStructType(t_return_type); + let f_is_struct = llvm::LLVMRustIsStructType(f_return_type); let void_type = LLVMVoidTypeInContext(llcx); // Now unwrap the struct_ret if it's actually a struct - if t_is_struct && f_return_type != void_type { + if f_is_struct { let num_elem_in_ret_struct = LLVMCountStructElementTypes(f_return_type); if num_elem_in_ret_struct == 1 { let inner_grad_name = "foo".to_string();