Skip to content

Commit 0832402

Browse files
committed
[CIR][CUDA] Generate device stubs
1 parent 637f2f3 commit 0832402

File tree

7 files changed

+227
-15
lines changed

7 files changed

+227
-15
lines changed
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
//===----- CIRGenCUDARuntime.cpp - Interface to CUDA Runtimes -----*- C++
2+
//-*-==//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This provides an abstract class for CUDA CIR generation. Concrete
// subclasses of this implement code generation for specific CUDA
// runtime libraries.
13+
//
14+
//===----------------------------------------------------------------------===//
15+
16+
#include "CIRGenCUDARuntime.h"
17+
#include "CIRGenFunction.h"
18+
#include "clang/Basic/Cuda.h"
19+
#include "clang/CIR/Dialect/IR/CIRTypes.h"
20+
21+
using namespace clang;
22+
using namespace clang::CIRGen;
23+
24+
// Out-of-line definition anchors the vtable for this polymorphic base.
CIRGenCUDARuntime::~CIRGenCUDARuntime() = default;
25+
26+
/// Emit the body of a host-side device stub for kernel `fn`.
///
/// The stub packs the kernel arguments into a `void *args[N]` array,
/// retrieves the launch configuration recorded by the `<<<...>>>` syntax via
/// `__cudaPopCallConfiguration`, and finally calls `cudaLaunchKernel`.
/// Pre-CUDA-9.0 launch, HIP, OffloadViaLLVM, and per-thread default
/// streams are not yet implemented (NYI).
void CIRGenCUDARuntime::emitDeviceStubBody(CIRGenFunction &cgf, cir::FuncOp fn,
                                           FunctionArgList &args) {
  // CUDA 9.0 changed the way to launch kernels.
  if (!CudaFeatureEnabled(cgm.getTarget().getSDKVersion(),
                          CudaFeature::CUDA_USES_NEW_LAUNCH))
    llvm_unreachable("NYI");

  // This requires arguments to be sent to kernels in a different way.
  if (cgm.getLangOpts().OffloadViaLLVM)
    llvm_unreachable("NYI");

  // HIP uses hipLaunchKernel and a different configuration scheme.
  if (cgm.getLangOpts().HIP)
    llvm_unreachable("NYI");

  auto &builder = cgm.getBuilder();

  // For cudaLaunchKernel, we must add another layer of indirection
  // to arguments. For example, for function `add(int a, float b)`,
  // we need to pass it as `void *args[2] = { &a, &b }`.

  auto loc = fn.getLoc();
  auto voidPtrArrayTy =
      cir::ArrayType::get(&cgm.getMLIRContext(), cgm.VoidPtrTy, args.size());
  mlir::Value kernelArgs = builder.createAlloca(
      loc, cir::PointerType::get(voidPtrArrayTy), voidPtrArrayTy, "kernel_args",
      CharUnits::fromQuantity(16));

  // Store the address of each local parameter into kernel_args[i].
  for (auto [i, arg] : llvm::enumerate(args)) {
    mlir::Value index =
        builder.getConstInt(loc, llvm::APInt(/*numBits=*/32, i));
    mlir::Value storePos = builder.createPtrStride(loc, kernelArgs, index);
    builder.CIRBaseBuilderTy::createStore(
        loc, cgf.GetAddrOfLocalVar(arg).getPointer(), storePos);
  }

  // We retrieve dim3 type by looking into the second argument of
  // cudaLaunchKernel, as is done in OG (original clang CodeGen).
  TranslationUnitDecl *tuDecl = cgm.getASTContext().getTranslationUnitDecl();
  DeclContext *dc = TranslationUnitDecl::castToDeclContext(tuDecl);

  // The default stream is usually stream 0 (the legacy default stream).
  // For per-thread default stream, we need a different LaunchKernel function.
  if (cgm.getLangOpts().GPUDefaultStream ==
      LangOptions::GPUDefaultStreamKind::PerThread)
    llvm_unreachable("NYI");

  // Look up the declaration of cudaLaunchKernel at translation-unit scope;
  // it must have been declared (e.g. by including the CUDA headers).
  std::string launchAPI = "cudaLaunchKernel";
  const IdentifierInfo &launchII = cgm.getASTContext().Idents.get(launchAPI);
  FunctionDecl *launchFD = nullptr;
  for (auto *result : dc->lookup(&launchII)) {
    if (FunctionDecl *fd = dyn_cast<FunctionDecl>(result))
      launchFD = fd;
  }

  // Without the declaration we cannot derive parameter types; report and bail.
  if (launchFD == nullptr) {
    cgm.Error(cgf.CurFuncDecl->getLocation(),
              "Can't find declaration for " + launchAPI);
    return;
  }

  // Use this function to retrieve arguments for cudaLaunchKernel:
  // int __cudaPopCallConfiguration(dim3 *gridDim, dim3 *blockDim, size_t
  // *sharedMem, cudaStream_t *stream)
  //
  // Here cudaStream_t, while also being the 6th argument of cudaLaunchKernel,
  // is a pointer to some opaque struct.

  // Parameter indices: 1 = gridDim (dim3), 5 = stream (cudaStream_t).
  mlir::Type dim3Ty =
      cgf.getTypes().convertType(launchFD->getParamDecl(1)->getType());
  mlir::Type streamTy =
      cgf.getTypes().convertType(launchFD->getParamDecl(5)->getType());

  // Stack slots that __cudaPopCallConfiguration will fill in.
  mlir::Value gridDim =
      builder.createAlloca(loc, cir::PointerType::get(dim3Ty), dim3Ty,
                           "grid_dim", CharUnits::fromQuantity(8));
  mlir::Value blockDim =
      builder.createAlloca(loc, cir::PointerType::get(dim3Ty), dim3Ty,
                           "block_dim", CharUnits::fromQuantity(8));
  mlir::Value sharedMem =
      builder.createAlloca(loc, cir::PointerType::get(cgm.SizeTy), cgm.SizeTy,
                           "shared_mem", cgm.getSizeAlign());
  mlir::Value stream =
      builder.createAlloca(loc, cir::PointerType::get(streamTy), streamTy,
                           "stream", cgm.getPointerAlign());

  cir::FuncOp popConfig = cgm.createRuntimeFunction(
      cir::FuncType::get({gridDim.getType(), blockDim.getType(),
                          sharedMem.getType(), stream.getType()},
                         cgm.SInt32Ty),
      "__cudaPopCallConfiguration");
  cgf.emitRuntimeCall(loc, popConfig, {gridDim, blockDim, sharedMem, stream});

  // Now emit the call to cudaLaunchKernel
  // cudaError_t cudaLaunchKernel(const void *func, dim3 gridDim, dim3 blockDim,
  //                              void **args, size_t sharedMem,
  //                              cudaStream_t stream);
  auto kernelTy =
      cir::PointerType::get(&cgm.getMLIRContext(), fn.getFunctionType());

  // The stub itself is registered as the kernel handle; cast it to void*.
  mlir::Value kernel =
      builder.create<cir::GetGlobalOp>(loc, kernelTy, fn.getSymName());
  mlir::Value func = builder.createBitcast(kernel, cgm.VoidPtrTy);
  CallArgList launchArgs;

  // cudaLaunchKernel takes `void **args`, so decay the array to a pointer.
  mlir::Value kernelArgsDecayed =
      builder.createCast(cir::CastKind::array_to_ptrdecay, kernelArgs,
                         cir::PointerType::get(cgm.VoidPtrTy));

  // Assemble the six arguments, typed from the cudaLaunchKernel declaration.
  // gridDim/blockDim are passed as dim3 aggregates; sharedMem is loaded by
  // value; stream is passed as the pointer slot itself.
  launchArgs.add(RValue::get(func), launchFD->getParamDecl(0)->getType());
  launchArgs.add(
      RValue::getAggregate(Address(gridDim, CharUnits::fromQuantity(8))),
      launchFD->getParamDecl(1)->getType());
  launchArgs.add(
      RValue::getAggregate(Address(blockDim, CharUnits::fromQuantity(8))),
      launchFD->getParamDecl(2)->getType());
  launchArgs.add(RValue::get(kernelArgsDecayed),
                 launchFD->getParamDecl(3)->getType());
  launchArgs.add(
      RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, sharedMem)),
      launchFD->getParamDecl(4)->getType());
  launchArgs.add(RValue::get(stream), launchFD->getParamDecl(5)->getType());

  mlir::Type launchTy = cgm.getTypes().convertType(launchFD->getType());
  mlir::Operation *launchFn =
      cgm.createRuntimeFunction(cast<cir::FuncType>(launchTy), launchAPI);
  const auto &callInfo = cgm.getTypes().arrangeFunctionDeclaration(launchFD);
  cgf.emitCall(callInfo, CIRGenCallee::forDirect(launchFn), ReturnValueSlot(),
               launchArgs);
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//===------ CIRGenCUDARuntime.h - Interface to CUDA Runtimes ------*- C++
2+
//-*-==//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This provides an abstract class for CUDA CIR generation. Concrete
// subclasses of this implement code generation for specific CUDA
// runtime libraries.
13+
//
14+
//===----------------------------------------------------------------------===//
15+
16+
#ifndef LLVM_CLANG_LIB_CIR_CIRGENCUDARUNTIME_H
17+
#define LLVM_CLANG_LIB_CIR_CIRGENCUDARUNTIME_H
18+
19+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
20+
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
21+
22+
namespace clang::CIRGen {
23+
24+
class CIRGenFunction;
25+
class CIRGenModule;
26+
class FunctionArgList;
27+
28+
/// Abstract interface for CUDA runtime code generation in CIR.
/// The base implementation emits host-side device stubs that forward
/// kernel launches to the CUDA runtime.
class CIRGenCUDARuntime {
protected:
  // Owning module; supplies the builder, type cache, and language options.
  CIRGenModule &cgm;

public:
  CIRGenCUDARuntime(CIRGenModule &cgm) : cgm(cgm) {}
  virtual ~CIRGenCUDARuntime();

  /// Emit the body of the host stub for kernel `fn`: pack `args`,
  /// pop the call configuration, and call cudaLaunchKernel.
  virtual void emitDeviceStubBody(CIRGenFunction &cgf, cir::FuncOp fn,
                                  FunctionArgList &args);
};
39+
40+
} // namespace clang::CIRGen
41+
42+
#endif // LLVM_CLANG_LIB_CIR_CIRGENCUDARUNTIME_H

clang/lib/CIR/CodeGen/CIRGenFunction.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ cir::FuncOp CIRGenFunction::generateCode(clang::GlobalDecl GD, cir::FuncOp Fn,
753753
emitConstructorBody(Args);
754754
else if (getLangOpts().CUDA && !getLangOpts().CUDAIsDevice &&
755755
FD->hasAttr<CUDAGlobalAttr>())
756-
llvm_unreachable("NYI");
756+
CGM.getCUDARuntime().emitDeviceStubBody(*this, Fn, Args);
757757
else if (isa<CXXMethodDecl>(FD) &&
758758
cast<CXXMethodDecl>(FD)->isLambdaStaticInvoker()) {
759759
// The lambda static invoker function is special, because it forwards or

clang/lib/CIR/CodeGen/CIRGenModule.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
// This is the internal per-translation-unit state used for CIR translation.
1010
//
1111
//===----------------------------------------------------------------------===//
12+
#include "CIRGenCUDARuntime.h"
1213
#include "CIRGenCXXABI.h"
1314
#include "CIRGenCstEmitter.h"
1415
#include "CIRGenFunction.h"
@@ -108,7 +109,8 @@ CIRGenModule::CIRGenModule(mlir::MLIRContext &mlirContext,
108109
theModule{mlir::ModuleOp::create(builder.getUnknownLoc())}, Diags(Diags),
109110
target(astContext.getTargetInfo()), ABI(createCXXABI(*this)),
110111
genTypes{*this}, VTables{*this},
111-
openMPRuntime(new CIRGenOpenMPRuntime(*this)) {
112+
openMPRuntime(new CIRGenOpenMPRuntime(*this)),
113+
cudaRuntime(new CIRGenCUDARuntime(*this)) {
112114

113115
// Initialize CIR signed integer types cache.
114116
SInt8Ty = cir::IntType::get(&getMLIRContext(), 8, /*isSigned=*/true);

clang/lib/CIR/CodeGen/CIRGenModule.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "Address.h"
1717
#include "CIRGenBuilder.h"
18+
#include "CIRGenCUDARuntime.h"
1819
#include "CIRGenCall.h"
1920
#include "CIRGenOpenCLRuntime.h"
2021
#include "CIRGenTBAA.h"
@@ -113,6 +114,9 @@ class CIRGenModule : public CIRGenTypeCache {
113114
/// Holds the OpenMP runtime
114115
std::unique_ptr<CIRGenOpenMPRuntime> openMPRuntime;
115116

117+
/// Holds the CUDA runtime
118+
std::unique_ptr<CIRGenCUDARuntime> cudaRuntime;
119+
116120
/// Per-function codegen information. Updated everytime emitCIR is called
117121
/// for FunctionDecls's.
118122
CIRGenFunction *CurCGF = nullptr;
@@ -862,12 +866,18 @@ class CIRGenModule : public CIRGenTypeCache {
862866
/// Print out an error that codegen doesn't support the specified decl yet.
863867
void ErrorUnsupported(const Decl *D, const char *Type);
864868

865-
/// Return a reference to the configured OpenMP runtime.
869+
/// Return a reference to the configured OpenCL runtime.
866870
CIRGenOpenCLRuntime &getOpenCLRuntime() {
867871
assert(openCLRuntime != nullptr);
868872
return *openCLRuntime;
869873
}
870874

875+
/// Return a reference to the configured CUDA runtime.
876+
CIRGenCUDARuntime &getCUDARuntime() {
877+
assert(cudaRuntime != nullptr);
878+
return *cudaRuntime;
879+
}
880+
871881
void createOpenCLRuntime() {
872882
openCLRuntime.reset(new CIRGenOpenCLRuntime(*this));
873883
}

clang/lib/CIR/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ add_clang_library(clangCIR
1919
CIRGenClass.cpp
2020
CIRGenCleanup.cpp
2121
CIRGenCoroutine.cpp
22+
CIRGenCUDARuntime.cpp
2223
CIRGenDecl.cpp
2324
CIRGenDeclCXX.cpp
2425
CIRGenException.cpp

clang/test/CIR/CodeGen/CUDA/simple.cu

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
#include "../Inputs/cuda.h"
22

33
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir \
4-
// RUN: -x cuda -emit-cir %s -o %t.cir
4+
// RUN: -x cuda -emit-cir -target-sdk-version=12.3 \
5+
// RUN: %s -o %t.cir
56
// RUN: FileCheck --check-prefix=CIR-HOST --input-file=%t.cir %s
67

78
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir \
8-
// RUN: -fcuda-is-device -emit-cir %s -o %t.cir
9+
// RUN: -fcuda-is-device -emit-cir -target-sdk-version=12.3 \
10+
// RUN: %s -o %t.cir
911
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s
1012

1113
// Attribute for global_fn
12-
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fnv>{{.*}}
14+
// CIR-HOST: [[Kernel:#[a-zA-Z_0-9]+]] = {{.*}}#cir.cuda_kernel_name<_Z9global_fni>{{.*}}
1315

1416
__host__ void host_fn(int *a, int *b, int *c) {}
1517
// CIR-HOST: cir.func @_Z7host_fnPiS_S_
@@ -19,13 +21,13 @@ __device__ void device_fn(int* a, double b, float c) {}
1921
// CIR-HOST-NOT: cir.func @_Z9device_fnPidf
2022
// CIR-DEVICE: cir.func @_Z9device_fnPidf
2123

22-
#ifdef __CUDA_ARCH__
23-
__global__ void global_fn() {}
24-
#else
25-
__global__ void global_fn();
26-
#endif
27-
// CIR-HOST: @_Z24__device_stub__global_fnv(){{.*}}extra([[Kernel]])
28-
// CIR-DEVICE: @_Z9global_fnv
24+
__global__ void global_fn(int a) {}
25+
// CIR-DEVICE: @_Z9global_fni
2926

30-
// Make sure `global_fn` indeed gets emitted
31-
__host__ void x() { auto v = global_fn; }
27+
// Check for device stub emission.
28+
29+
// CIR-HOST: @_Z24__device_stub__global_fni{{.*}}extra([[Kernel]])
30+
// CIR-HOST: cir.alloca {{.*}}"kernel_args"
31+
// CIR-HOST: cir.call @__cudaPopCallConfiguration
32+
// CIR-HOST: cir.get_global @_Z24__device_stub__global_fni
33+
// CIR-HOST: cir.call @cudaLaunchKernel

0 commit comments

Comments
 (0)