Skip to content

[MLIR][LaunchFuncToVulkan] Remove typed pointer support #70865

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -598,12 +598,6 @@ def ConvertVulkanLaunchFuncToVulkanCallsPass
This pass is only intended for the mlir-vulkan-runner.
}];

let options = [
Option<"useOpaquePointers", "use-opaque-pointers", "bool",
/*default=*/"true", "Generate LLVM IR using opaque pointers "
"instead of typed pointers">
];

let dependentDialects = ["LLVM::LLVMDialect"];
}

Expand Down
41 changes: 8 additions & 33 deletions mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,7 @@ class VulkanLaunchFuncToVulkanCallsPass
void initializeCachedTypes() {
llvmFloatType = Float32Type::get(&getContext());
llvmVoidType = LLVM::LLVMVoidType::get(&getContext());
if (useOpaquePointers)
llvmPointerType = LLVM::LLVMPointerType::get(&getContext());
else
llvmPointerType =
LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8));
llvmPointerType = LLVM::LLVMPointerType::get(&getContext());
llvmInt32Type = IntegerType::get(&getContext(), 32);
llvmInt64Type = IntegerType::get(&getContext(), 64);
}
Expand All @@ -85,9 +81,6 @@ class VulkanLaunchFuncToVulkanCallsPass
// int64_t sizes[Rank]; // omitted when rank == 0
// int64_t strides[Rank]; // omitted when rank == 0
// };
auto llvmPtrToElementType = useOpaquePointers
? llvmPointerType
: LLVM::LLVMPointerType::get(elemenType);
auto llvmArrayRankElementSizeType =
LLVM::LLVMArrayType::get(getInt64Type(), rank);

Expand All @@ -96,7 +89,7 @@ class VulkanLaunchFuncToVulkanCallsPass
// [`rank` x i64], [`rank` x i64]}">`.
return LLVM::LLVMStructType::getLiteral(
&getContext(),
{llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
{llvmPointerType, llvmPointerType, getInt64Type(),
llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
}

Expand Down Expand Up @@ -280,13 +273,6 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(

auto symbolName =
llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
// Special case for fp16 type. Since it is not a supported type in C we use
// int16_t and bitcast the descriptor.
if (!useOpaquePointers && isa<Float16Type>(type)) {
auto memRefTy = getMemRefType(rank, IntegerType::get(&getContext(), 16));
ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor);
}
// Create call to `bindMemRef`.
builder.create<LLVM::CallOp>(
loc, TypeRange(), StringRef(symbolName.data(), symbolName.size()),
Expand All @@ -303,16 +289,9 @@ VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value launchCallArg,
if (!alloca)
return failure();

LLVM::LLVMStructType llvmDescriptorTy;
if (std::optional<Type> elementType = alloca.getElemType()) {
llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(*elementType);
} else {
// This case is only possible if we are not using opaque pointers
// since opaque pointer producing allocas require an element type.
llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(
alloca.getRes().getType().getElementType());
}

std::optional<Type> elementType = alloca.getElemType();
assert(elementType && "expected to work with opaque pointers");
auto llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(*elementType);
// template <typename Elem, size_t Rank>
// struct {
// Elem *allocated;
Expand Down Expand Up @@ -379,10 +358,7 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
if (!module.lookupSymbol(fnName)) {
auto fnType = LLVM::LLVMFunctionType::get(
getVoidType(),
{getPointerType(), getInt32Type(), getInt32Type(),
useOpaquePointers
? llvmPointerType
: LLVM::LLVMPointerType::get(getMemRefType(i, type))},
{llvmPointerType, getInt32Type(), getInt32Type(), llvmPointerType},
/*isVarArg=*/false);
builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
}
Expand Down Expand Up @@ -410,8 +386,7 @@ Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(

std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
shaderName, LLVM::Linkage::Internal,
useOpaquePointers);
shaderName, LLVM::Linkage::Internal);
}

void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
Expand All @@ -429,7 +404,7 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
// that data to runtime call.
Value ptrToSPIRVBinary = LLVM::createGlobalString(
loc, builder, kSPIRVBinary, spirvAttributes.blob.getValue(),
LLVM::Linkage::Internal, useOpaquePointers);
LLVM::Linkage::Internal);

// Create LLVM constant for the size of SPIR-V binary shader.
Value binarySize = builder.create<LLVM::ConstantOp>(
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -launch-func-to-vulkan='use-opaque-pointers=1' | FileCheck %s
// RUN: mlir-opt %s -launch-func-to-vulkan | FileCheck %s

// CHECK: llvm.mlir.global internal constant @kernel_spv_entry_point_name
// CHECK: llvm.mlir.global internal constant @SPIRV_BIN
Expand Down
63 changes: 0 additions & 63 deletions mlir/test/Conversion/GPUToVulkan/typed-pointers.mlir

This file was deleted.