From 610e82aa5e53880b832c9f7307d02eb8dd94d499 Mon Sep 17 00:00:00 2001 From: AdUhTkJm <2292398666@qq.com> Date: Sat, 8 Mar 2025 17:39:19 +0000 Subject: [PATCH] [CIR][CUDA] Generate CUDA destructor --- .../Dialect/Transforms/LoweringPrepare.cpp | 67 +++++++++++++++++-- clang/test/CIR/CodeGen/CUDA/registration.cu | 28 ++++++-- 2 files changed, 86 insertions(+), 9 deletions(-) diff --git a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp index 2cac74fb9308..030327867f2b 100644 --- a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp +++ b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp @@ -127,7 +127,7 @@ struct LoweringPreparePass : public LoweringPrepareBase { llvm::StringMap cudaKernelMap; void buildCUDAModuleCtor(); - void buildCUDAModuleDtor(); + std::optional buildCUDAModuleDtor(); std::optional buildCUDARegisterGlobals(); void buildCUDARegisterGlobalFunctions(cir::CIRBaseBuilderTy &builder, @@ -1153,6 +1153,23 @@ void LoweringPreparePass::buildCUDAModuleCtor() { builder.createCallOp(loc, endFunc, gpuBinaryHandle); } + // Create destructor and register it with atexit() the way NVCC does it. Doing + // it during regular destructor phase worked in CUDA before 9.2 but results in + // double-free in 9.2. + if (auto dtor = buildCUDAModuleDtor()) { + // extern "C" int atexit(void (*f)(void)); + cir::CIRBaseBuilderTy globalBuilder(getContext()); + globalBuilder.setInsertionPointToStart(theModule.getBody()); + FuncOp atexit = buildRuntimeFunction( + globalBuilder, "atexit", loc, + FuncType::get(PointerType::get(dtor->getFunctionType()), intTy)); + + mlir::Value dtorFunc = builder.create( + loc, PointerType::get(dtor->getFunctionType()), + mlir::FlatSymbolRefAttr::get(dtor->getSymNameAttr())); + builder.createCallOp(loc, atexit, dtorFunc); + } + builder.create(loc); } @@ -1256,6 +1273,51 @@ void LoweringPreparePass::buildCUDARegisterGlobalFunctions( } } +std::optional LoweringPreparePass::buildCUDAModuleDtor() { + if (!theModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName())) + return {}; + + std::string prefix = getCUDAPrefix(astCtx); + + auto voidTy = VoidType::get(&getContext()); + auto voidPtrPtrTy = PointerType::get(PointerType::get(voidTy)); + + auto loc = theModule.getLoc(); + + cir::CIRBaseBuilderTy builder(getContext()); + builder.setInsertionPointToStart(theModule.getBody()); + + // void __cudaUnregisterFatBinary(void ** handle); + std::string unregisterFuncName = + addUnderscoredPrefix(prefix, "UnregisterFatBinary"); + FuncOp unregisterFunc = buildRuntimeFunction( + builder, unregisterFuncName, loc, FuncType::get({voidPtrPtrTy}, voidTy)); + + // void __cuda_module_dtor(); + // Despite the name, OG doesn't treat it as a destructor, so it shouldn't be + // put into globalDtorList. If it were a real dtor, then it would cause double + // free above CUDA 9.2. The way to use it is to manually call atexit() at end + // of module ctor. + std::string dtorName = addUnderscoredPrefix(prefix, "_module_dtor"); + FuncOp dtor = + buildRuntimeFunction(builder, dtorName, loc, FuncType::get({}, voidTy), + GlobalLinkageKind::InternalLinkage); + + builder.setInsertionPointToStart(dtor.addEntryBlock()); + + // For dtor, we only need to call: + // __cudaUnregisterFatBinary(__cuda_gpubin_handle); + + std::string gpubinName = addUnderscoredPrefix(prefix, "_gpubin_handle"); + auto gpubinGlobal = cast(theModule.lookupSymbol(gpubinName)); + mlir::Value gpubinAddress = builder.createGetGlobal(gpubinGlobal); + mlir::Value gpubin = builder.createLoad(loc, gpubinAddress); + builder.createCallOp(loc, unregisterFunc, gpubin); + builder.create(loc); + + return dtor; +} + void LoweringPreparePass::lowerDynamicCastOp(DynamicCastOp op) { CIRBaseBuilderTy builder(getContext()); builder.setInsertionPointAfter(op); @@ -1537,9 +1599,6 @@ void LoweringPreparePass::runOnOperation() { datalayout.emplace(theModule); } - auto typeSizeInfo = cast( - theModule->getAttr(CIRDialect::getTypeSizeInfoAttrName())); - llvm::SmallVector opsToTransform; op->walk([&](Operation *op) { diff --git a/clang/test/CIR/CodeGen/CUDA/registration.cu b/clang/test/CIR/CodeGen/CUDA/registration.cu index 4c80958efb0d..9446ce04ae4d 100644 --- a/clang/test/CIR/CodeGen/CUDA/registration.cu +++ b/clang/test/CIR/CodeGen/CUDA/registration.cu @@ -18,6 +18,15 @@ // CIR-HOST: cir.global_ctors = [#cir.global_ctor<"__cuda_module_ctor", {{[0-9]+}}>] // CIR-HOST: } +// Module destructor goes here. +// This is not a real destructor, as explained in LoweringPrepare. + +// CIR-HOST: cir.func internal private @__cuda_module_dtor() { +// CIR-HOST: %[[#HandleGlobal:]] = cir.get_global @__cuda_gpubin_handle +// CIR-HOST: %[[#Handle:]] = cir.load %0 +// CIR-HOST: cir.call @__cudaUnregisterFatBinary(%[[#Handle]]) +// CIR-HOST: } + // CIR-HOST: cir.global "private" constant cir_private @".str_Z2fnv" = // CIR-HOST-SAME: #cir.const_array<"_Z2fnv", trailing_zeros> @@ -33,6 +42,12 @@ // LLVM-HOST: } // LLVM-HOST: @llvm.global_ctors = {{.*}}ptr @__cuda_module_ctor +// LLVM-HOST: define internal void @__cuda_module_dtor() { +// LLVM-HOST: %[[#LLVMHandleVar:]] = load ptr, ptr @__cuda_gpubin_handle, align 8 +// LLVM-HOST: call void @__cudaUnregisterFatBinary(ptr %[[#LLVMHandleVar]]) +// LLVM-HOST: ret void +// LLVM-HOST: } + __global__ void fn() {} // CIR-HOST: cir.func internal private @__cuda_register_globals(%[[FatbinHandle:[a-zA-Z0-9]+]]{{.*}}) { @@ -83,12 +98,15 @@ __global__ void fn() {} // CIR-HOST: %[[#FatbinGlobal:]] = cir.get_global @__cuda_gpubin_handle // CIR-HOST: cir.store %[[#Fatbin]], %[[#FatbinGlobal]] // CIR-HOST: cir.call @__cuda_register_globals -// CIR-HOTS: cir.call @__cudaRegisterFatBinaryEnd +// CIR-HOST: cir.call @__cudaRegisterFatBinaryEnd +// CIR-HOST: %[[#ModuleDtor:]] = cir.get_global @__cuda_module_dtor +// CIR-HOST: cir.call @atexit(%[[#ModuleDtor]]) // CIR-HOST: } // LLVM-HOST: define internal void @__cuda_module_ctor() { -// LLVM-HOST: %[[#LLVMFatbin:]] = call ptr @__cudaRegisterFatBinary(ptr @__cuda_fatbin_wrapper) -// LLVM-HOST: store ptr %[[#LLVMFatbin]], ptr @__cuda_gpubin_handle -// LLVM-HOST: call void @__cuda_register_globals -// LLVM-HOST: call void @__cudaRegisterFatBinaryEnd +// LLVM-HOST: %[[#LLVMFatbin:]] = call ptr @__cudaRegisterFatBinary(ptr @__cuda_fatbin_wrapper) +// LLVM-HOST: store ptr %[[#LLVMFatbin]], ptr @__cuda_gpubin_handle +// LLVM-HOST: call void @__cuda_register_globals +// LLVM-HOST: call void @__cudaRegisterFatBinaryEnd +// LLVM-HOST: call i32 @atexit(ptr @__cuda_module_dtor) // LLVM-HOST: }