
[MLIR] Pass count of parameters & gpu binary size to runtime wrappers #66154


Merged (1 commit) on Sep 26, 2023
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (21 additions & 4 deletions)
@@ -101,7 +101,7 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
FunctionCallBuilder moduleLoadCallBuilder = {
"mgpuModuleLoad",
llvmPointerType /* void *module */,
{llvmPointerType /* void *cubin */}};
{llvmPointerType /* void *cubin */, llvmInt64Type /* size_t size */}};
FunctionCallBuilder moduleUnloadCallBuilder = {
"mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
FunctionCallBuilder moduleGetFunctionCallBuilder = {
@@ -125,7 +125,8 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
llvmInt32Type, /* unsigned int sharedMemBytes */
llvmPointerType, /* void *hstream */
llvmPointerPointerType, /* void **kernelParams */
llvmPointerPointerType /* void **extra */
llvmPointerPointerType, /* void **extra */
llvmInt64Type /* size_t paramsCount */
}};
FunctionCallBuilder streamCreateCallBuilder = {
"mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
@@ -1134,7 +1135,23 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
loc, rewriter, nameBuffer.str(), binaryAttr.getValue(),
LLVM::Linkage::Internal, getTypeConverter()->useOpaquePointers());

auto module = moduleLoadCallBuilder.create(loc, rewriter, data);
// Pass the binary size. SPIRV requires binary size.
auto gpuBlob = binaryAttr.getValue();
auto gpuBlobSize = rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmInt64Type,
mlir::IntegerAttr::get(llvmInt64Type,
static_cast<int64_t>(gpuBlob.size())));

auto module =
moduleLoadCallBuilder.create(loc, rewriter, {data, gpuBlobSize});

// Pass the count of the parameters to runtime wrappers
auto paramsCount = rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmInt64Type,
mlir::IntegerAttr::get(
llvmInt64Type,
static_cast<int64_t>(launchOp.getNumKernelOperands())));

// Get the function from the module. The name corresponds to the name of
// the kernel function.
auto kernelName = generateKernelNameConstant(
@@ -1158,7 +1175,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
{function.getResult(), adaptor.getGridSizeX(), adaptor.getGridSizeY(),
adaptor.getGridSizeZ(), adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
adaptor.getBlockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
/*extra=*/nullpointer});
/*extra=*/nullpointer, paramsCount});

if (launchOp.getAsyncToken()) {
// Async launch: make dependent ops use the same stream.
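Background for the size argument (not part of the patch itself): cuModuleLoadData and hipModuleLoadData can parse their input without an explicit length, but a runtime that consumes raw SPIR-V (for example a Level Zero based backend) takes the blob as a pointer plus an explicit byte count. A minimal sketch of a wrapper that would actually use the new argument follows; BackendModule and loadSpirvModule are hypothetical placeholders, and only the exported name and signature follow what the lowering above now emits.

// Illustrative sketch only, not part of this patch: a module-load wrapper for
// a runtime that needs an explicit byte count (the CUDA/ROCm wrappers below
// ignore it). `BackendModule` and `loadSpirvModule` are hypothetical.
#include <cstddef>
#include <cstdio>

struct BackendModule {};

// Placeholder standing in for a real SPIR-V / Level Zero style loader.
static BackendModule *loadSpirvModule(const void *blob, size_t sizeInBytes) {
  (void)blob;
  (void)sizeInBytes;
  return new BackendModule();
}

extern "C" BackendModule *mgpuModuleLoad(void *data, size_t gpuBlobSize) {
  if (data == nullptr || gpuBlobSize == 0) {
    std::fprintf(stderr, "mgpuModuleLoad: empty GPU binary\n");
    return nullptr;
  }
  // The explicit size added by this patch is what such a backend requires.
  return loadSpirvModule(data, gpuBlobSize);
}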
mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp (3 additions & 2 deletions)
@@ -119,7 +119,8 @@ static bool cusparseLt_initiated = false;
#endif // MLIR_ENABLE_CUDA_CUSPARSELT
#endif // MLIR_ENABLE_CUDA_CUSPARSE

extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoad(void *data) {
extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule
mgpuModuleLoad(void *data, size_t /*gpuBlobSize*/) {
ScopedContext scopedContext;
CUmodule module = nullptr;
CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
@@ -144,7 +145,7 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY,
intptr_t gridZ, intptr_t blockX, intptr_t blockY,
intptr_t blockZ, int32_t smem, CUstream stream, void **params,
void **extra) {
void **extra, size_t /*paramsCount*/) {
ScopedContext scopedContext;
int32_t maxShmem = 0;
CUdevice device = getDefaultCuDevice();
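The CUDA wrapper ignores both new arguments on purpose: cuModuleLoadData reads the size from the self-describing cubin/fatbin image, and cuLaunchKernel retrieves the number of kernel parameters from the function itself. A stricter wrapper could still use paramsCount as a cheap sanity check. The sketch below is hypothetical (KernelHandle, StreamHandle, and expectedParamCount are invented for illustration); only the mgpuLaunchKernel parameter list mirrors the one in this patch.

// Illustrative sketch only: using the new paramsCount to fail fast on an
// arity mismatch. All handle types and expectedParamCount() are hypothetical;
// the in-tree CUDA/ROCm wrappers simply ignore the count.
#include <cstddef>
#include <cstdint>
#include <cstdio>

struct KernelHandle { size_t arity; };  // hypothetical kernel handle
struct StreamHandle {};                 // hypothetical stream handle

// Placeholder for querying how many arguments the kernel expects.
static size_t expectedParamCount(const KernelHandle *fn) { return fn->arity; }

extern "C" void mgpuLaunchKernel(KernelHandle *function, intptr_t gridX,
                                 intptr_t gridY, intptr_t gridZ,
                                 intptr_t blockX, intptr_t blockY,
                                 intptr_t blockZ, int32_t smem,
                                 StreamHandle *stream, void **params,
                                 void **extra, size_t paramsCount) {
  if (paramsCount != expectedParamCount(function)) {
    std::fprintf(stderr, "mgpuLaunchKernel: expected %zu params, got %zu\n",
                 expectedParamCount(function), paramsCount);
    return;
  }
  // ... hand off gridX..blockZ, smem, stream, params and extra to the real
  // runtime's launch entry point here ...
  (void)gridX; (void)gridY; (void)gridZ; (void)blockX; (void)blockY;
  (void)blockZ; (void)smem; (void)stream; (void)params; (void)extra;
}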
mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp (2 additions & 2 deletions)
@@ -32,7 +32,7 @@

thread_local static int32_t defaultDevice = 0;

extern "C" hipModule_t mgpuModuleLoad(void *data) {
extern "C" hipModule_t mgpuModuleLoad(void *data, size_t /*gpuBlobSize*/) {
hipModule_t module = nullptr;
HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data));
return module;
@@ -57,7 +57,7 @@ extern "C" void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX,
intptr_t blockX, intptr_t blockY,
intptr_t blockZ, int32_t smem,
hipStream_t stream, void **params,
void **extra) {
void **extra, size_t /*paramsCount*/) {
HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ,
blockX, blockY, blockZ, smem,
stream, params, extra));
@@ -34,8 +34,9 @@ module attributes {gpu.container_module} {
// CHECK: [[ADDRESSOF:%.*]] = llvm.mlir.addressof @[[GLOBAL]]
// CHECK: [[BINARY:%.*]] = llvm.getelementptr [[ADDRESSOF]]{{\[}}0, 0]
// CHECK-SAME: -> !llvm.ptr

// CHECK: [[MODULE:%.*]] = llvm.call @mgpuModuleLoad([[BINARY]])
// CHECK: [[BINARYSIZE:%.*]] = llvm.mlir.constant
// CHECK: [[MODULE:%.*]] = llvm.call @mgpuModuleLoad([[BINARY]], [[BINARYSIZE]])
// CHECK: [[PARAMSCOUNT:%.*]] = llvm.mlir.constant
// CHECK: [[FUNC:%.*]] = llvm.call @mgpuModuleGetFunction([[MODULE]], {{.*}})

// CHECK: [[STREAM:%.*]] = llvm.call @mgpuStreamCreate
@@ -56,7 +57,7 @@ module attributes {gpu.container_module} {

// CHECK: llvm.call @mgpuLaunchKernel([[FUNC]], [[C8]], [[C8]], [[C8]],
// CHECK-SAME: [[C8]], [[C8]], [[C8]], [[C256]], [[STREAM]],
// CHECK-SAME: [[PARAMS]], [[EXTRA_PARAMS]])
// CHECK-SAME: [[PARAMS]], [[EXTRA_PARAMS]], [[PARAMSCOUNT]])
// CHECK: llvm.call @mgpuStreamSynchronize
// CHECK: llvm.call @mgpuStreamDestroy
// CHECK: llvm.call @mgpuModuleUnload