From 1cd2b5ec97a290d8fc1e363fd1777f706d2773c5 Mon Sep 17 00:00:00 2001 From: AdUhTkJm <2292398666@qq.com> Date: Sun, 9 Mar 2025 13:55:31 +0000 Subject: [PATCH] [CIR][CUDA] Miscellaneous bugfixes --- clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp | 21 +++++++---- .../Transforms/TargetLowering/Targets/X86.cpp | 2 +- clang/test/CIR/CodeGen/CUDA/simple.cu | 37 +++++++++++-------- 3 files changed, 35 insertions(+), 25 deletions(-) diff --git a/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp b/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp index fdac639ab35a..2695e731b061 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp @@ -69,11 +69,16 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf, loc, cir::PointerType::get(voidPtrArrayTy), voidPtrArrayTy, "kernel_args", CharUnits::fromQuantity(16)); + mlir::Value kernelArgsDecayed = + builder.createCast(cir::CastKind::array_to_ptrdecay, kernelArgs, + cir::PointerType::get(cgm.VoidPtrTy)); + // Store arguments into kernelArgs for (auto [i, arg] : llvm::enumerate(args)) { mlir::Value index = builder.getConstInt(loc, llvm::APInt(/*numBits=*/32, i)); - mlir::Value storePos = builder.createPtrStride(loc, kernelArgs, index); + mlir::Value storePos = + builder.createPtrStride(loc, kernelArgsDecayed, index); builder.CIRBaseBuilderTy::createStore( loc, cgf.GetAddrOfLocalVar(arg).getPointer(), storePos); } @@ -166,10 +171,6 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf, // mlir::Value func = builder.createBitcast(kernel, cgm.VoidPtrTy); CallArgList launchArgs; - mlir::Value kernelArgsDecayed = - builder.createCast(cir::CastKind::array_to_ptrdecay, kernelArgs, - cir::PointerType::get(cgm.VoidPtrTy)); - launchArgs.add(RValue::get(kernel), launchFD->getParamDecl(0)->getType()); launchArgs.add( RValue::getAggregate(Address(gridDim, CharUnits::fromQuantity(8))), @@ -182,7 +183,8 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf, 
launchArgs.add( RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, sharedMem)), launchFD->getParamDecl(4)->getType()); - launchArgs.add(RValue::get(stream), launchFD->getParamDecl(5)->getType()); + launchArgs.add(RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, stream)), + launchFD->getParamDecl(5)->getType()); mlir::Type launchTy = cgm.getTypes().convertType(launchFD->getType()); mlir::Operation *launchFn = @@ -219,13 +221,16 @@ RValue CIRGenCUDARuntime::emitCUDAKernelCallExpr(CIRGenFunction &cgf, cgf.emitIfOnBoolExpr( expr->getConfig(), + [&](mlir::OpBuilder &b, mlir::Location l) { + b.create(loc); + }, + loc, [&](mlir::OpBuilder &b, mlir::Location l) { CIRGenCallee callee = cgf.emitCallee(expr->getCallee()); cgf.emitCall(expr->getCallee()->getType(), callee, expr, retValue); b.create(loc); }, - loc, [](mlir::OpBuilder &b, mlir::Location l) {}, - std::optional()); + loc); return RValue::get(nullptr); } diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/Targets/X86.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/Targets/X86.cpp index 8d769808de70..0a3bbc53ca7c 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/Targets/X86.cpp +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/Targets/X86.cpp @@ -751,7 +751,7 @@ void X86_64ABIInfo::computeInfo(LowerFunctionInfo &FI) const { if (cir::MissingFeatures::vectorType()) cir_cconv_unreachable("NYI"); } else { - cir_cconv_unreachable("Indirect results are NYI"); + it->info = getIndirectResult(it->type, FreeIntRegs); } } } diff --git a/clang/test/CIR/CodeGen/CUDA/simple.cu b/clang/test/CIR/CodeGen/CUDA/simple.cu index 023089c1eb2d..d951118cbfad 100644 --- a/clang/test/CIR/CodeGen/CUDA/simple.cu +++ b/clang/test/CIR/CodeGen/CUDA/simple.cu @@ -27,14 +27,18 @@ __global__ void global_fn(int a) {} // Check for device stub emission. 
// CIR-HOST: @_Z24__device_stub__global_fni{{.*}}extra([[Kernel]]) -// CIR-HOST: cir.alloca {{.*}}"kernel_args" +// CIR-HOST: %[[#CIRKernelArgs:]] = cir.alloca {{.*}}"kernel_args" +// CIR-HOST: %[[#Decayed:]] = cir.cast(array_to_ptrdecay, %[[#CIRKernelArgs]] // CIR-HOST: cir.call @__cudaPopCallConfiguration // CIR-HOST: cir.get_global @_Z24__device_stub__global_fni // CIR-HOST: cir.call @cudaLaunchKernel -// COM: LLVM-HOST: void @_Z24__device_stub__global_fni -// COM: LLVM-HOST: call i32 @__cudaPopCallConfiguration -// COM: LLVM-HOST: call i32 @cudaLaunchKernel(ptr @_Z24__device_stub__global_fni +// LLVM-HOST: void @_Z24__device_stub__global_fni +// LLVM-HOST: %[[#KernelArgs:]] = alloca [1 x ptr], i64 1, align 16 +// LLVM-HOST: %[[#GEP1:]] = getelementptr ptr, ptr %[[#KernelArgs]], i32 0 +// LLVM-HOST: %[[#GEP2:]] = getelementptr ptr, ptr %[[#GEP1]], i64 0 +// LLVM-HOST: call i32 @__cudaPopCallConfiguration +// LLVM-HOST: call i32 @cudaLaunchKernel(ptr @_Z24__device_stub__global_fni int main() { global_fn<<<1, 1>>>(1); @@ -47,19 +51,20 @@ int main() { // CIR-HOST: [[Push:%[0-9]+]] = cir.call @__cudaPushCallConfiguration // CIR-HOST: [[ConfigOK:%[0-9]+]] = cir.cast(int_to_bool, [[Push]] // CIR-HOST: cir.if [[ConfigOK]] { +// CIR-HOST: } else { // CIR-HOST: [[Arg:%[0-9]+]] = cir.const #cir.int<1> // CIR-HOST: cir.call @_Z24__device_stub__global_fni([[Arg]]) // CIR-HOST: } -// COM: LLVM-HOST: define dso_local i32 @main -// COM: LLVM-HOST: alloca %struct.dim3 -// COM: LLVM-HOST: alloca %struct.dim3 -// COM: LLVM-HOST: call void @_ZN4dim3C1Ejjj -// COM: LLVM-HOST: call void @_ZN4dim3C1Ejjj -// COM: LLVM-HOST: [[LLVMConfigOK:%[0-9]+]] = call i32 @__cudaPushCallConfiguration -// COM: LLVM-HOST: br [[LLVMConfigOK]], label %[[Good:[0-9]+]], label [[Bad:[0-9]+]] -// COM: LLVM-HOST: [[Good]]: -// COM: LLVM-HOST: call void @_Z24__device_stub__global_fni -// COM: LLVM-HOST: br label [[Bad]] -// COM: LLVM-HOST: [[Bad]]: -// COM: LLVM-HOST: ret i32 +// LLVM-HOST: define dso_local 
i32 @main +// LLVM-HOST: alloca %struct.dim3 +// LLVM-HOST: alloca %struct.dim3 +// LLVM-HOST: call void @_ZN4dim3C1Ejjj +// LLVM-HOST: call void @_ZN4dim3C1Ejjj +// LLVM-HOST: [[LLVMConfigOK:%[0-9]+]] = call i32 @__cudaPushCallConfiguration +// LLVM-HOST: br [[LLVMConfigOK]], label %[[#Good:]], label [[#Bad:]] +// LLVM-HOST: [[#Good]]: +// LLVM-HOST: br label [[#End:]] +// LLVM-HOST: [[#Bad]]: +// LLVM-HOST: call void @_Z24__device_stub__global_fni +// LLVM-HOST: br label [[#End]]