Skip to content

[MLIR][LaunchFuncToVulkan] Remove typed pointer support #70865

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 1, 2023

Conversation

Dinistro
Copy link
Contributor

This commit removes the typed pointer support from the LaunchFunc's lowering to Vukan dialect. Typed pointers have been deprecated for a while now and it's planned to soon remove them from the LLVM dialect.

Related PSA: https://discourse.llvm.org/t/psa-removal-of-typed-pointers-from-the-llvm-dialect/74502

This commit removes the typed pointer support from the LaunchFunc's
lowering to Vukan dialect. Typed pointers have been deprecated for a
while now and it's planned to soon remove them from the LLVM dialect.

Related PSA: https://discourse.llvm.org/t/psa-removal-of-typed-pointers-from-the-llvm-dialect/74502
@llvmbot
Copy link
Member

llvmbot commented Oct 31, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Christian Ulmann (Dinistro)

Changes

This commit removes the typed pointer support from the LaunchFunc's lowering to Vukan dialect. Typed pointers have been deprecated for a while now and it's planned to soon remove them from the LLVM dialect.

Related PSA: https://discourse.llvm.org/t/psa-removal-of-typed-pointers-from-the-llvm-dialect/74502


Full diff: https://github.com/llvm/llvm-project/pull/70865.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (-6)
  • (modified) mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp (+8-33)
  • (modified) mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir (+1-1)
  • (removed) mlir/test/Conversion/GPUToVulkan/typed-pointers.mlir (-63)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 036c9b0039779ab..c2d3d0e134f961e 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -598,12 +598,6 @@ def ConvertVulkanLaunchFuncToVulkanCallsPass
     This pass is only intended for the mlir-vulkan-runner.
   }];
 
-  let options = [
-     Option<"useOpaquePointers", "use-opaque-pointers", "bool",
-            /*default=*/"true", "Generate LLVM IR using opaque pointers "
-            "instead of typed pointers">
-  ];
-
   let dependentDialects = ["LLVM::LLVMDialect"];
 }
 
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index 036eb0200ca7176..938db549630680a 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -65,11 +65,7 @@ class VulkanLaunchFuncToVulkanCallsPass
   void initializeCachedTypes() {
     llvmFloatType = Float32Type::get(&getContext());
     llvmVoidType = LLVM::LLVMVoidType::get(&getContext());
-    if (useOpaquePointers)
-      llvmPointerType = LLVM::LLVMPointerType::get(&getContext());
-    else
-      llvmPointerType =
-          LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8));
+    llvmPointerType = LLVM::LLVMPointerType::get(&getContext());
     llvmInt32Type = IntegerType::get(&getContext(), 32);
     llvmInt64Type = IntegerType::get(&getContext(), 64);
   }
@@ -85,9 +81,6 @@ class VulkanLaunchFuncToVulkanCallsPass
     //   int64_t sizes[Rank]; // omitted when rank == 0
     //   int64_t strides[Rank]; // omitted when rank == 0
     // };
-    auto llvmPtrToElementType = useOpaquePointers
-                                    ? llvmPointerType
-                                    : LLVM::LLVMPointerType::get(elemenType);
     auto llvmArrayRankElementSizeType =
         LLVM::LLVMArrayType::get(getInt64Type(), rank);
 
@@ -96,7 +89,7 @@ class VulkanLaunchFuncToVulkanCallsPass
     // [`rank` x i64], [`rank` x i64]}">`.
     return LLVM::LLVMStructType::getLiteral(
         &getContext(),
-        {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
+        {llvmPointerType, llvmPointerType, getInt64Type(),
          llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
   }
 
@@ -280,13 +273,6 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
 
     auto symbolName =
         llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
-    // Special case for fp16 type. Since it is not a supported type in C we use
-    // int16_t and bitcast the descriptor.
-    if (!useOpaquePointers && isa<Float16Type>(type)) {
-      auto memRefTy = getMemRefType(rank, IntegerType::get(&getContext(), 16));
-      ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
-          loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor);
-    }
     // Create call to `bindMemRef`.
     builder.create<LLVM::CallOp>(
         loc, TypeRange(), StringRef(symbolName.data(), symbolName.size()),
@@ -303,16 +289,9 @@ VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value launchCallArg,
   if (!alloca)
     return failure();
 
-  LLVM::LLVMStructType llvmDescriptorTy;
-  if (std::optional<Type> elementType = alloca.getElemType()) {
-    llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(*elementType);
-  } else {
-    // This case is only possible if we are not using opaque pointers
-    // since opaque pointer producing allocas require an element type.
-    llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(
-        alloca.getRes().getType().getElementType());
-  }
-
+  std::optional<Type> elementType = alloca.getElemType();
+  assert(elementType && "expected to work with opaque pointers");
+  auto llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(*elementType);
   // template <typename Elem, size_t Rank>
   // struct {
   //   Elem *allocated;
@@ -379,10 +358,7 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
       if (!module.lookupSymbol(fnName)) {
         auto fnType = LLVM::LLVMFunctionType::get(
             getVoidType(),
-            {getPointerType(), getInt32Type(), getInt32Type(),
-             useOpaquePointers
-                 ? llvmPointerType
-                 : LLVM::LLVMPointerType::get(getMemRefType(i, type))},
+            {llvmPointerType, getInt32Type(), getInt32Type(), llvmPointerType},
             /*isVarArg=*/false);
         builder.create<LLVM::LLVMFuncOp>(loc, fnName, fnType);
       }
@@ -410,8 +386,7 @@ Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
 
   std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
   return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
-                                  shaderName, LLVM::Linkage::Internal,
-                                  useOpaquePointers);
+                                  shaderName, LLVM::Linkage::Internal);
 }
 
 void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
@@ -429,7 +404,7 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
   // that data to runtime call.
   Value ptrToSPIRVBinary = LLVM::createGlobalString(
       loc, builder, kSPIRVBinary, spirvAttributes.blob.getValue(),
-      LLVM::Linkage::Internal, useOpaquePointers);
+      LLVM::Linkage::Internal);
 
   // Create LLVM constant for the size of SPIR-V binary shader.
   Value binarySize = builder.create<LLVM::ConstantOp>(
diff --git a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir b/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
index f6da6d0a3be3476..fe8a36ee29a9f08 100644
--- a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
+++ b/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -launch-func-to-vulkan='use-opaque-pointers=1' | FileCheck %s
+// RUN: mlir-opt %s -launch-func-to-vulkan | FileCheck %s
 
 // CHECK: llvm.mlir.global internal constant @kernel_spv_entry_point_name
 // CHECK: llvm.mlir.global internal constant @SPIRV_BIN
diff --git a/mlir/test/Conversion/GPUToVulkan/typed-pointers.mlir b/mlir/test/Conversion/GPUToVulkan/typed-pointers.mlir
deleted file mode 100644
index 2884b33750b32ce..000000000000000
--- a/mlir/test/Conversion/GPUToVulkan/typed-pointers.mlir
+++ /dev/null
@@ -1,63 +0,0 @@
-// RUN: mlir-opt %s -launch-func-to-vulkan='use-opaque-pointers=0' | FileCheck %s
-
-// CHECK: llvm.mlir.global internal constant @kernel_spv_entry_point_name
-// CHECK: llvm.mlir.global internal constant @SPIRV_BIN
-// CHECK: %[[Vulkan_Runtime_ptr:.*]] = llvm.call @initVulkan() : () -> !llvm.ptr<i8>
-// CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN
-// CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]]
-// CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant
-// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
-// CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, i32) -> ()
-// CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name
-// CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]]
-// CHECK: llvm.call @setEntryPoint(%[[Vulkan_Runtime_ptr]], %[[entry_point_ptr]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
-// CHECK: llvm.call @setNumWorkGroups(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, i64, i64, i64) -> ()
-// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
-// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> ()
-
-// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr<i8>, i32, i32, !llvm.ptr<struct<(ptr<i16>, ptr<i16>, i64, array<1 x i64>, array<1 x i64>)>>)
-
-module attributes {gpu.container_module} {
-  llvm.func @malloc(i64) -> !llvm.ptr<i8>
-  llvm.func @foo() {
-    %0 = llvm.mlir.constant(12 : index) : i64
-    %1 = llvm.mlir.zero : !llvm.ptr<f32>
-    %2 = llvm.mlir.constant(1 : index) : i64
-    %3 = llvm.getelementptr %1[%2] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-    %4 = llvm.ptrtoint %3 : !llvm.ptr<f32> to i64
-    %5 = llvm.mul %0, %4 : i64
-    %6 = llvm.call @malloc(%5) : (i64) -> !llvm.ptr<i8>
-    %7 = llvm.bitcast %6 : !llvm.ptr<i8> to !llvm.ptr<f32>
-    %8 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %9 = llvm.insertvalue %7, %8[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %10 = llvm.insertvalue %7, %9[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %11 = llvm.mlir.constant(0 : index) : i64
-    %12 = llvm.insertvalue %11, %10[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %13 = llvm.mlir.constant(1 : index) : i64
-    %14 = llvm.insertvalue %0, %12[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %15 = llvm.insertvalue %13, %14[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %16 = llvm.mlir.constant(1 : index) : i64
-    %17 = llvm.extractvalue %15[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %18 = llvm.extractvalue %15[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %19 = llvm.extractvalue %15[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %20 = llvm.extractvalue %15[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %21 = llvm.extractvalue %15[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    llvm.call @vulkanLaunch(%16, %16, %16, %17, %18, %19, %20, %21) {spirv_blob = "\03\02#\07\00", spirv_element_types = [f32], spirv_entry_point = "kernel"}
-    : (i64, i64, i64, !llvm.ptr<f32>, !llvm.ptr<f32>, i64, i64, i64) -> ()
-    llvm.return
-  }
-  llvm.func @vulkanLaunch(%arg0: i64, %arg1: i64, %arg2: i64, %arg6: !llvm.ptr<f32>, %arg7: !llvm.ptr<f32>, %arg8: i64, %arg9: i64, %arg10: i64) {
-    %0 = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %1 = llvm.insertvalue %arg6, %0[0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %2 = llvm.insertvalue %arg7, %1[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %3 = llvm.insertvalue %arg8, %2[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %4 = llvm.insertvalue %arg9, %3[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %5 = llvm.insertvalue %arg10, %4[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-    %6 = llvm.mlir.constant(1 : index) : i64
-    %7 = llvm.alloca %6 x !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>
-    llvm.store %5, %7 : !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>
-    llvm.call @_mlir_ciface_vulkanLaunch(%arg0, %arg1, %arg2, %7) : (i64, i64, i64, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>) -> ()
-    llvm.return
-  }
-  llvm.func @_mlir_ciface_vulkanLaunch(i64, i64, i64, !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>>)
-}

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Dinistro Dinistro merged commit fd8be1e into llvm:main Nov 1, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants