Skip to content

[HLSL] Implement WaveReadLaneAt intrinsic #111010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Oct 16, 2024

Conversation

inbelic
Copy link
Contributor

@inbelic inbelic commented Oct 3, 2024

- create a clang built-in in Builtins.td
- add semantic checking in SemaHLSL.cpp
- link the WaveReadLaneAt api in hlsl_intrinsics.h
- add lowering to spirv backend op GroupNonUniformShuffle
  with Scope = 2 (Group) in SPIRVInstructionSelector.cpp
- add WaveReadLaneAt intrinsic to IntrinsicsDirectX.td and mapping
  to DXIL.td

- add tests for HLSL intrinsic lowering to spirv intrinsic in
  WaveReadLaneAt.hlsl
- add tests for sema checks in WaveReadLaneAt-errors.hlsl
- add spir-v backend tests in WaveReadLaneAt.ll
- add test to show scalar dxil lowering functionality

- note that this doesn't include support for the scalarizer to
  handle WaveReadLaneAt will be added in a future pr

This is the first part #70104

@inbelic inbelic changed the title [HLSL] Implement WaveReadLaneAt intrinsics [HLSL] Implement WaveReadLaneAt intrinsic for spirv backend Oct 3, 2024
@inbelic inbelic force-pushed the inbelic/wave-read-at-spirv branch from 18dd1fe to 4522f35 Compare October 3, 2024 19:41
@inbelic inbelic marked this pull request as ready for review October 3, 2024 19:57
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang:codegen IR generation bugs: mangling, exceptions, etc. HLSL HLSL Language Support backend:SPIR-V llvm:ir labels Oct 3, 2024
@llvmbot
Copy link
Member

llvmbot commented Oct 3, 2024

@llvm/pr-subscribers-backend-directx
@llvm/pr-subscribers-backend-spir-v
@llvm/pr-subscribers-hlsl

@llvm/pr-subscribers-llvm-ir

Author: Finn Plummer (inbelic)

Changes
- create a clang built-in in Builtins.td
- add semantic checking in SemaHLSL.cpp
- link the WaveReadLaneAt api in hlsl_intrinsics.h
- add lowering to spirv backend op GroupNonUniformShuffle
  with Scope = 2 (Group) in SPIRVInstructionSelector.cpp

- add tests for HLSL intrinsic lowering to spirv intrinsic in
  WaveReadLaneAt.hlsl
- add tests for sema checks in WaveReadLaneAt-errors.hlsl
- add spir-v backend tests in WaveReadLaneAt.ll

This is part 1 of 3 addressing TODO: add issue.


Full diff: https://github.com/llvm/llvm-project/pull/111010.diff

10 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+16)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+7)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+20)
  • (added) clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl (+40)
  • (added) clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl (+21)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+15)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll (+28)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 8090119e512fbb..eec9acd4d27d7d 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4703,6 +4703,12 @@ def HLSLWaveIsFirstLane : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "bool()";
 }
 
+def HLSLWaveReadLaneAt : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_wave_read_lane_at"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_clamp"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index da3eca73bfb575..dff56af9282e9d 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18835,6 +18835,22 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
     Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
     return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
   }
+  case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
+    // Due to the use of variadic arguments we must explicitly retreive them and
+    // create our function type.
+    Value *OpExpr = EmitScalarExpr(E->getArg(0));
+    Value *OpIndex = EmitScalarExpr(E->getArg(1));
+    llvm::FunctionType *FT = llvm::FunctionType::get(
+        OpExpr->getType(), ArrayRef{OpExpr->getType(), OpIndex->getType()},
+        false);
+
+    // Get overloaded name
+    std::string name =
+        Intrinsic::getName(CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
+                           ArrayRef{OpExpr->getType()}, &CGM.getModule());
+    return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, name, {}, false, true),
+                           ArrayRef{OpExpr, OpIndex}, "hlsl.wave.read.lane.at");
+  }
   case Builtin::BI__builtin_hlsl_elementwise_sign: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
     llvm::Type *Xty = Op0->getType();
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index a8aabca7348ffb..a639ce2d784f4a 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -87,6 +87,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_read_lane_at)
 
   //===----------------------------------------------------------------------===//
   // End of reserved area for HLSL intrinsic getters.
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 810a16d75f0228..a7bdc353ae71bf 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -2015,6 +2015,13 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_is_first_lane)
 __attribute__((convergent)) bool WaveIsFirstLane();
 
+// \brief Returns the value of the expression for the given lane index within
+// the specified wave.
+template <typename T>
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
+    __attribute__((convergent)) T WaveReadLaneAt(T, int32_t);
+
 //===----------------------------------------------------------------------===//
 // sign builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 43cc6c81ae5cb0..d54da3fd8375ed 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1956,6 +1956,26 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
+    if (SemaRef.checkArgCount(TheCall, 2))
+      return true;
+
+    // Ensure index parameter type can be interpreted as a uint
+    ExprResult Index = TheCall->getArg(1);
+    QualType ArgTyIndex = Index.get()->getType();
+    if (!ArgTyIndex->hasIntegerRepresentation()) {
+      SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
+      return true;
+    }
+
+    // Ensure return type is the same as the input expr type
+    ExprResult Expr = TheCall->getArg(0);
+    QualType ArgTyExpr = Expr.get()->getType();
+    TheCall->setType(ArgTyExpr);
+    break;
+  }
   case Builtin::BI__builtin_elementwise_acos:
   case Builtin::BI__builtin_elementwise_asin:
   case Builtin::BI__builtin_elementwise_atan:
diff --git a/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl b/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl
new file mode 100644
index 00000000000000..62319ebc04e2db
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl
@@ -0,0 +1,40 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple   \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple   \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_int
+int test_int(int expr, uint idx) {
+  // CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
+
+  // CHECK-SPIRV:  %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
+  // CHECK-DXIL:  %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
+
+  // CHECK:  ret [[TY]] %[[RET]]
+  return WaveReadLaneAt(expr, idx);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
+// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
+
+// Test basic lowering to runtime function call with array and float value.
+
+// CHECK-LABEL: test_floatv4
+float4 test_floatv4(float4 expr, uint idx) {
+  // CHECK-SPIRV: %[[#entry_tok1:]] = call token @llvm.experimental.convergence.entry()
+
+  // CHECK-SPIRV:  %[[RET1:.*]] = call [[TY1:.*]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
+  // CHECK-DXIL:  %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
+
+  // CHECK:  ret [[TY1]] %[[RET1]]
+  return WaveReadLaneAt(expr, idx);
+}
+
+// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
+// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
+
+// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl
new file mode 100644
index 00000000000000..451f2d3a563287
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl
@@ -0,0 +1,21 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
+
+bool test_too_few_arg() {
+  return __builtin_hlsl_wave_read_lane_at();
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
+}
+
+float2 test_too_few_arg_1(float2 p0) {
+  return __builtin_hlsl_wave_read_lane_at(p0);
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 1}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+  return __builtin_hlsl_wave_read_lane_at(p0, p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
+}
+
+float3 test_index_type_check(float3 p0, double idx) {
+  return __builtin_hlsl_wave_read_lane_at(p0, idx);
+  // expected-error@-1 {{passing 'double' to parameter of incompatible type 'unsigned int'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 7ff3d58690ba75..b6ea9ce9b1411e 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -82,5 +82,6 @@ let TargetPrefix = "spv" in {
     [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, Commutative] >;
   def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
+  def int_spv_wave_read_lane_at : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent]>;
   def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty]>;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7a565249a342d1..a7279193764fa4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2653,6 +2653,21 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
         .addUse(GR.getSPIRVTypeID(ResType))
         .addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
   }
+  case Intrinsic::spv_wave_read_lane_at: {
+    assert(I.getNumOperands() == 4);
+    assert(I.getOperand(2).isReg());
+    assert(I.getOperand(3).isReg());
+
+    // Defines the execution scope currently 2 for group, see scope table
+    SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+    return BuildMI(BB, I, I.getDebugLoc(),
+                   TII.get(SPIRV::OpGroupNonUniformShuffle))
+        .addDef(ResVReg)
+        .addUse(GR.getSPIRVTypeID(ResType))
+        .addUse(I.getOperand(2).getReg())
+        .addUse(I.getOperand(3).getReg())
+        .addUse(GR.getOrCreateConstInt(2, I, IntTy, TII));
+  }
   case Intrinsic::spv_step:
     return selectStep(ResVReg, ResType, I);
   // Discard intrinsics which we do not expect to actually represent code after
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll
new file mode 100644
index 00000000000000..e02a2907ee28a0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll
@@ -0,0 +1,28 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Test lowering to spir-v backend
+
+; CHECK-DAG:   %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG:   %[[#scope:]] = OpConstant %[[#uint]] 2
+; CHECK-DAG:   %[[#f32:]] = OpTypeFloat 32
+; CHECK-DAG:   %[[#expr:]] = OpFunctionParameter %[[#f32]]
+; CHECK-DAG:   %[[#idx:]] = OpFunctionParameter %[[#uint]]
+
+define spir_func void @test_1(float %expr, i32 %idx) #0 {
+entry:
+  %0 = call token @llvm.experimental.convergence.entry()
+; CHECK:   %[[#ret:]] = OpGroupNonUniformShuffle %[[#f32]] %[[#expr]] %[[#idx]] %[[#scope]]
+  %1 = call float @llvm.spv.wave.read.lane.at(float %expr, i32 %idx) [ "convergencectrl"(token %0) ]
+  ret void
+}
+
+declare i32 @__hlsl_wave_get_lane_index() #1
+
+attributes #0 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}

@llvmbot
Copy link
Member

llvmbot commented Oct 3, 2024

@llvm/pr-subscribers-clang

Author: Finn Plummer (inbelic)

Changes
- create a clang built-in in Builtins.td
- add semantic checking in SemaHLSL.cpp
- link the WaveReadLaneAt api in hlsl_intrinsics.h
- add lowering to spirv backend op GroupNonUniformShuffle
  with Scope = 2 (Group) in SPIRVInstructionSelector.cpp

- add tests for HLSL intrinsic lowering to spirv intrinsic in
  WaveReadLaneAt.hlsl
- add tests for sema checks in WaveReadLaneAt-errors.hlsl
- add spir-v backend tests in WaveReadLaneAt.ll

This is part 1 of 3 addressing TODO: add issue.


Full diff: https://github.com/llvm/llvm-project/pull/111010.diff

10 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+16)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+7)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+20)
  • (added) clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl (+40)
  • (added) clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl (+21)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+15)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll (+28)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 8090119e512fbb..eec9acd4d27d7d 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4703,6 +4703,12 @@ def HLSLWaveIsFirstLane : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "bool()";
 }
 
+def HLSLWaveReadLaneAt : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_wave_read_lane_at"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_clamp"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index da3eca73bfb575..dff56af9282e9d 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18835,6 +18835,22 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
     Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
     return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
   }
+  case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
+    // Due to the use of variadic arguments we must explicitly retreive them and
+    // create our function type.
+    Value *OpExpr = EmitScalarExpr(E->getArg(0));
+    Value *OpIndex = EmitScalarExpr(E->getArg(1));
+    llvm::FunctionType *FT = llvm::FunctionType::get(
+        OpExpr->getType(), ArrayRef{OpExpr->getType(), OpIndex->getType()},
+        false);
+
+    // Get overloaded name
+    std::string name =
+        Intrinsic::getName(CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
+                           ArrayRef{OpExpr->getType()}, &CGM.getModule());
+    return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, name, {}, false, true),
+                           ArrayRef{OpExpr, OpIndex}, "hlsl.wave.read.lane.at");
+  }
   case Builtin::BI__builtin_hlsl_elementwise_sign: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
     llvm::Type *Xty = Op0->getType();
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index a8aabca7348ffb..a639ce2d784f4a 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -87,6 +87,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_read_lane_at)
 
   //===----------------------------------------------------------------------===//
   // End of reserved area for HLSL intrinsic getters.
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 810a16d75f0228..a7bdc353ae71bf 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -2015,6 +2015,13 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_is_first_lane)
 __attribute__((convergent)) bool WaveIsFirstLane();
 
+// \brief Returns the value of the expression for the given lane index within
+// the specified wave.
+template <typename T>
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
+    __attribute__((convergent)) T WaveReadLaneAt(T, int32_t);
+
 //===----------------------------------------------------------------------===//
 // sign builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 43cc6c81ae5cb0..d54da3fd8375ed 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1956,6 +1956,26 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
+    if (SemaRef.checkArgCount(TheCall, 2))
+      return true;
+
+    // Ensure index parameter type can be interpreted as a uint
+    ExprResult Index = TheCall->getArg(1);
+    QualType ArgTyIndex = Index.get()->getType();
+    if (!ArgTyIndex->hasIntegerRepresentation()) {
+      SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
+      return true;
+    }
+
+    // Ensure return type is the same as the input expr type
+    ExprResult Expr = TheCall->getArg(0);
+    QualType ArgTyExpr = Expr.get()->getType();
+    TheCall->setType(ArgTyExpr);
+    break;
+  }
   case Builtin::BI__builtin_elementwise_acos:
   case Builtin::BI__builtin_elementwise_asin:
   case Builtin::BI__builtin_elementwise_atan:
diff --git a/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl b/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl
new file mode 100644
index 00000000000000..62319ebc04e2db
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl
@@ -0,0 +1,40 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple   \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple   \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_int
+int test_int(int expr, uint idx) {
+  // CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
+
+  // CHECK-SPIRV:  %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
+  // CHECK-DXIL:  %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
+
+  // CHECK:  ret [[TY]] %[[RET]]
+  return WaveReadLaneAt(expr, idx);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
+// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
+
+// Test basic lowering to runtime function call with array and float value.
+
+// CHECK-LABEL: test_floatv4
+float4 test_floatv4(float4 expr, uint idx) {
+  // CHECK-SPIRV: %[[#entry_tok1:]] = call token @llvm.experimental.convergence.entry()
+
+  // CHECK-SPIRV:  %[[RET1:.*]] = call [[TY1:.*]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
+  // CHECK-DXIL:  %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
+
+  // CHECK:  ret [[TY1]] %[[RET1]]
+  return WaveReadLaneAt(expr, idx);
+}
+
+// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
+// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
+
+// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl
new file mode 100644
index 00000000000000..451f2d3a563287
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl
@@ -0,0 +1,21 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
+
+bool test_too_few_arg() {
+  return __builtin_hlsl_wave_read_lane_at();
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
+}
+
+float2 test_too_few_arg_1(float2 p0) {
+  return __builtin_hlsl_wave_read_lane_at(p0);
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 1}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+  return __builtin_hlsl_wave_read_lane_at(p0, p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
+}
+
+float3 test_index_type_check(float3 p0, double idx) {
+  return __builtin_hlsl_wave_read_lane_at(p0, idx);
+  // expected-error@-1 {{passing 'double' to parameter of incompatible type 'unsigned int'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 7ff3d58690ba75..b6ea9ce9b1411e 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -82,5 +82,6 @@ let TargetPrefix = "spv" in {
     [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, Commutative] >;
   def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
+  def int_spv_wave_read_lane_at : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent]>;
   def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty]>;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7a565249a342d1..a7279193764fa4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2653,6 +2653,21 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
         .addUse(GR.getSPIRVTypeID(ResType))
         .addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
   }
+  case Intrinsic::spv_wave_read_lane_at: {
+    assert(I.getNumOperands() == 4);
+    assert(I.getOperand(2).isReg());
+    assert(I.getOperand(3).isReg());
+
+    // Defines the execution scope currently 2 for group, see scope table
+    SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+    return BuildMI(BB, I, I.getDebugLoc(),
+                   TII.get(SPIRV::OpGroupNonUniformShuffle))
+        .addDef(ResVReg)
+        .addUse(GR.getSPIRVTypeID(ResType))
+        .addUse(I.getOperand(2).getReg())
+        .addUse(I.getOperand(3).getReg())
+        .addUse(GR.getOrCreateConstInt(2, I, IntTy, TII));
+  }
   case Intrinsic::spv_step:
     return selectStep(ResVReg, ResType, I);
   // Discard intrinsics which we do not expect to actually represent code after
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll
new file mode 100644
index 00000000000000..e02a2907ee28a0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll
@@ -0,0 +1,28 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Test lowering to spir-v backend
+
+; CHECK-DAG:   %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG:   %[[#scope:]] = OpConstant %[[#uint]] 2
+; CHECK-DAG:   %[[#f32:]] = OpTypeFloat 32
+; CHECK-DAG:   %[[#expr:]] = OpFunctionParameter %[[#f32]]
+; CHECK-DAG:   %[[#idx:]] = OpFunctionParameter %[[#uint]]
+
+define spir_func void @test_1(float %expr, i32 %idx) #0 {
+entry:
+  %0 = call token @llvm.experimental.convergence.entry()
+; CHECK:   %[[#ret:]] = OpGroupNonUniformShuffle %[[#f32]] %[[#expr]] %[[#idx]] %[[#scope]]
+  %1 = call float @llvm.spv.wave.read.lane.at(float %expr, i32 %idx) [ "convergencectrl"(token %0) ]
+  ret void
+}
+
+declare i32 @__hlsl_wave_get_lane_index() #1
+
+attributes #0 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}

@llvmbot
Copy link
Member

llvmbot commented Oct 3, 2024

@llvm/pr-subscribers-backend-x86

Author: Finn Plummer (inbelic)

Changes
- create a clang built-in in Builtins.td
- add semantic checking in SemaHLSL.cpp
- link the WaveReadLaneAt api in hlsl_intrinsics.h
- add lowering to spirv backend op GroupNonUniformShuffle
  with Scope = 2 (Group) in SPIRVInstructionSelector.cpp

- add tests for HLSL intrinsic lowering to spirv intrinsic in
  WaveReadLaneAt.hlsl
- add tests for sema checks in WaveReadLaneAt-errors.hlsl
- add spir-v backend tests in WaveReadLaneAt.ll

This is part 1 of 3 addressing TODO: add issue.


Full diff: https://github.com/llvm/llvm-project/pull/111010.diff

10 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+16)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+7)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+20)
  • (added) clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl (+40)
  • (added) clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl (+21)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+15)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll (+28)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 8090119e512fbb..eec9acd4d27d7d 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4703,6 +4703,12 @@ def HLSLWaveIsFirstLane : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "bool()";
 }
 
+def HLSLWaveReadLaneAt : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_wave_read_lane_at"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_clamp"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index da3eca73bfb575..dff56af9282e9d 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18835,6 +18835,22 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
     Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
     return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
   }
+  case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
+    // Due to the use of variadic arguments we must explicitly retreive them and
+    // create our function type.
+    Value *OpExpr = EmitScalarExpr(E->getArg(0));
+    Value *OpIndex = EmitScalarExpr(E->getArg(1));
+    llvm::FunctionType *FT = llvm::FunctionType::get(
+        OpExpr->getType(), ArrayRef{OpExpr->getType(), OpIndex->getType()},
+        false);
+
+    // Get overloaded name
+    std::string name =
+        Intrinsic::getName(CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
+                           ArrayRef{OpExpr->getType()}, &CGM.getModule());
+    return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, name, {}, false, true),
+                           ArrayRef{OpExpr, OpIndex}, "hlsl.wave.read.lane.at");
+  }
   case Builtin::BI__builtin_hlsl_elementwise_sign: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
     llvm::Type *Xty = Op0->getType();
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index a8aabca7348ffb..a639ce2d784f4a 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -87,6 +87,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_read_lane_at)
 
   //===----------------------------------------------------------------------===//
   // End of reserved area for HLSL intrinsic getters.
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 810a16d75f0228..a7bdc353ae71bf 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -2015,6 +2015,13 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_is_first_lane)
 __attribute__((convergent)) bool WaveIsFirstLane();
 
+// \brief Returns the value of the expression for the given lane index within
+// the specified wave.
+template <typename T>
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
+    __attribute__((convergent)) T WaveReadLaneAt(T, int32_t);
+
 //===----------------------------------------------------------------------===//
 // sign builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 43cc6c81ae5cb0..d54da3fd8375ed 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1956,6 +1956,26 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
+    if (SemaRef.checkArgCount(TheCall, 2))
+      return true;
+
+    // Ensure index parameter type can be interpreted as a uint
+    ExprResult Index = TheCall->getArg(1);
+    QualType ArgTyIndex = Index.get()->getType();
+    if (!ArgTyIndex->hasIntegerRepresentation()) {
+      SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
+      return true;
+    }
+
+    // Ensure return type is the same as the input expr type
+    ExprResult Expr = TheCall->getArg(0);
+    QualType ArgTyExpr = Expr.get()->getType();
+    TheCall->setType(ArgTyExpr);
+    break;
+  }
   case Builtin::BI__builtin_elementwise_acos:
   case Builtin::BI__builtin_elementwise_asin:
   case Builtin::BI__builtin_elementwise_atan:
diff --git a/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl b/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl
new file mode 100644
index 00000000000000..62319ebc04e2db
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl
@@ -0,0 +1,40 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple   \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple   \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_int
+int test_int(int expr, uint idx) {
+  // CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
+
+  // CHECK-SPIRV:  %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
+  // CHECK-DXIL:  %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
+
+  // CHECK:  ret [[TY]] %[[RET]]
+  return WaveReadLaneAt(expr, idx);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
+// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
+
+// Test basic lowering to runtime function call with array and float value.
+
+// CHECK-LABEL: test_floatv4
+float4 test_floatv4(float4 expr, uint idx) {
+  // CHECK-SPIRV: %[[#entry_tok1:]] = call token @llvm.experimental.convergence.entry()
+
+  // CHECK-SPIRV:  %[[RET1:.*]] = call [[TY1:.*]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
+  // CHECK-DXIL:  %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
+
+  // CHECK:  ret [[TY1]] %[[RET1]]
+  return WaveReadLaneAt(expr, idx);
+}
+
+// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
+// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
+
+// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl
new file mode 100644
index 00000000000000..451f2d3a563287
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl
@@ -0,0 +1,21 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
+
+bool test_too_few_arg() {
+  return __builtin_hlsl_wave_read_lane_at();
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
+}
+
+float2 test_too_few_arg_1(float2 p0) {
+  return __builtin_hlsl_wave_read_lane_at(p0);
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 1}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+  return __builtin_hlsl_wave_read_lane_at(p0, p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
+}
+
+float3 test_index_type_check(float3 p0, double idx) {
+  return __builtin_hlsl_wave_read_lane_at(p0, idx);
+  // expected-error@-1 {{passing 'double' to parameter of incompatible type 'unsigned int'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 7ff3d58690ba75..b6ea9ce9b1411e 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -82,5 +82,6 @@ let TargetPrefix = "spv" in {
     [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, Commutative] >;
   def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
+  def int_spv_wave_read_lane_at : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent]>;
   def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty]>;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7a565249a342d1..a7279193764fa4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2653,6 +2653,21 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
         .addUse(GR.getSPIRVTypeID(ResType))
         .addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
   }
+  case Intrinsic::spv_wave_read_lane_at: {
+    assert(I.getNumOperands() == 4);
+    assert(I.getOperand(2).isReg());
+    assert(I.getOperand(3).isReg());
+
+    // Defines the execution scope currently 2 for group, see scope table
+    SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+    return BuildMI(BB, I, I.getDebugLoc(),
+                   TII.get(SPIRV::OpGroupNonUniformShuffle))
+        .addDef(ResVReg)
+        .addUse(GR.getSPIRVTypeID(ResType))
+        .addUse(I.getOperand(2).getReg())
+        .addUse(I.getOperand(3).getReg())
+        .addUse(GR.getOrCreateConstInt(2, I, IntTy, TII));
+  }
   case Intrinsic::spv_step:
     return selectStep(ResVReg, ResType, I);
   // Discard intrinsics which we do not expect to actually represent code after
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll
new file mode 100644
index 00000000000000..e02a2907ee28a0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll
@@ -0,0 +1,28 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Test lowering to spir-v backend
+
+; CHECK-DAG:   %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG:   %[[#scope:]] = OpConstant %[[#uint]] 2
+; CHECK-DAG:   %[[#f32:]] = OpTypeFloat 32
+; CHECK-DAG:   %[[#expr:]] = OpFunctionParameter %[[#f32]]
+; CHECK-DAG:   %[[#idx:]] = OpFunctionParameter %[[#uint]]
+
+define spir_func void @test_1(float %expr, i32 %idx) #0 {
+entry:
+  %0 = call token @llvm.experimental.convergence.entry()
+; CHECK:   %[[#ret:]] = OpGroupNonUniformShuffle %[[#f32]] %[[#expr]] %[[#idx]] %[[#scope]]
+  %1 = call float @llvm.spv.wave.read.lane.at(float %expr, i32 %idx) [ "convergencectrl"(token %0) ]
+  ret void
+}
+
+declare i32 @__hlsl_wave_get_lane_index() #1
+
+attributes #0 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}

@llvmbot
Copy link
Member

llvmbot commented Oct 3, 2024

@llvm/pr-subscribers-clang-codegen

Author: Finn Plummer (inbelic)

Changes
- create a clang built-in in Builtins.td
- add semantic checking in SemaHLSL.cpp
- link the WaveReadLaneAt api in hlsl_intrinsics.h
- add lowering to spirv backend op GroupNonUniformShuffle
  with Scope = 2 (Group) in SPIRVInstructionSelector.cpp

- add tests for HLSL intrinsic lowering to spirv intrinsic in
  WaveReadLaneAt.hlsl
- add tests for sema checks in WaveReadLaneAt-errors.hlsl
- add spir-v backend tests in WaveReadLaneAt.ll

This is part 1 of 3 addressing TODO: add issue.


Full diff: https://github.com/llvm/llvm-project/pull/111010.diff

10 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+16)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+7)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+20)
  • (added) clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl (+40)
  • (added) clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl (+21)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+15)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll (+28)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 8090119e512fbb..eec9acd4d27d7d 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4703,6 +4703,12 @@ def HLSLWaveIsFirstLane : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "bool()";
 }
 
+def HLSLWaveReadLaneAt : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_wave_read_lane_at"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_clamp"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index da3eca73bfb575..dff56af9282e9d 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18835,6 +18835,22 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
     Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
     return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
   }
+  case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
+    // Due to the use of variadic arguments we must explicitly retreive them and
+    // create our function type.
+    Value *OpExpr = EmitScalarExpr(E->getArg(0));
+    Value *OpIndex = EmitScalarExpr(E->getArg(1));
+    llvm::FunctionType *FT = llvm::FunctionType::get(
+        OpExpr->getType(), ArrayRef{OpExpr->getType(), OpIndex->getType()},
+        false);
+
+    // Get overloaded name
+    std::string name =
+        Intrinsic::getName(CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
+                           ArrayRef{OpExpr->getType()}, &CGM.getModule());
+    return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, name, {}, false, true),
+                           ArrayRef{OpExpr, OpIndex}, "hlsl.wave.read.lane.at");
+  }
   case Builtin::BI__builtin_hlsl_elementwise_sign: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
     llvm::Type *Xty = Op0->getType();
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index a8aabca7348ffb..a639ce2d784f4a 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -87,6 +87,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_read_lane_at)
 
   //===----------------------------------------------------------------------===//
   // End of reserved area for HLSL intrinsic getters.
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 810a16d75f0228..a7bdc353ae71bf 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -2015,6 +2015,13 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_is_first_lane)
 __attribute__((convergent)) bool WaveIsFirstLane();
 
+// \brief Returns the value of the expression for the given lane index within
+// the specified wave.
+template <typename T>
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
+    __attribute__((convergent)) T WaveReadLaneAt(T, int32_t);
+
 //===----------------------------------------------------------------------===//
 // sign builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 43cc6c81ae5cb0..d54da3fd8375ed 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1956,6 +1956,26 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
+    if (SemaRef.checkArgCount(TheCall, 2))
+      return true;
+
+    // Ensure index parameter type can be interpreted as a uint
+    ExprResult Index = TheCall->getArg(1);
+    QualType ArgTyIndex = Index.get()->getType();
+    if (!ArgTyIndex->hasIntegerRepresentation()) {
+      SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
+      return true;
+    }
+
+    // Ensure return type is the same as the input expr type
+    ExprResult Expr = TheCall->getArg(0);
+    QualType ArgTyExpr = Expr.get()->getType();
+    TheCall->setType(ArgTyExpr);
+    break;
+  }
   case Builtin::BI__builtin_elementwise_acos:
   case Builtin::BI__builtin_elementwise_asin:
   case Builtin::BI__builtin_elementwise_atan:
diff --git a/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl b/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl
new file mode 100644
index 00000000000000..62319ebc04e2db
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl
@@ -0,0 +1,40 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple   \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple   \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_int
+int test_int(int expr, uint idx) {
+  // CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
+
+  // CHECK-SPIRV:  %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
+  // CHECK-DXIL:  %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
+
+  // CHECK:  ret [[TY]] %[[RET]]
+  return WaveReadLaneAt(expr, idx);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
+// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
+
+// Test basic lowering to runtime function call with array and float value.
+
+// CHECK-LABEL: test_floatv4
+float4 test_floatv4(float4 expr, uint idx) {
+  // CHECK-SPIRV: %[[#entry_tok1:]] = call token @llvm.experimental.convergence.entry()
+
+  // CHECK-SPIRV:  %[[RET1:.*]] = call [[TY1:.*]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
+  // CHECK-DXIL:  %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
+
+  // CHECK:  ret [[TY1]] %[[RET1]]
+  return WaveReadLaneAt(expr, idx);
+}
+
+// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
+// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
+
+// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl
new file mode 100644
index 00000000000000..451f2d3a563287
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl
@@ -0,0 +1,21 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
+
+bool test_too_few_arg() {
+  return __builtin_hlsl_wave_read_lane_at();
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
+}
+
+float2 test_too_few_arg_1(float2 p0) {
+  return __builtin_hlsl_wave_read_lane_at(p0);
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 1}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+  return __builtin_hlsl_wave_read_lane_at(p0, p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
+}
+
+float3 test_index_type_check(float3 p0, double idx) {
+  return __builtin_hlsl_wave_read_lane_at(p0, idx);
+  // expected-error@-1 {{passing 'double' to parameter of incompatible type 'unsigned int'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 7ff3d58690ba75..b6ea9ce9b1411e 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -82,5 +82,6 @@ let TargetPrefix = "spv" in {
     [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, Commutative] >;
   def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
+  def int_spv_wave_read_lane_at : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent]>;
   def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty]>;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7a565249a342d1..a7279193764fa4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2653,6 +2653,21 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
         .addUse(GR.getSPIRVTypeID(ResType))
         .addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
   }
+  case Intrinsic::spv_wave_read_lane_at: {
+    assert(I.getNumOperands() == 4);
+    assert(I.getOperand(2).isReg());
+    assert(I.getOperand(3).isReg());
+
+    // Defines the execution scope currently 2 for group, see scope table
+    SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+    return BuildMI(BB, I, I.getDebugLoc(),
+                   TII.get(SPIRV::OpGroupNonUniformShuffle))
+        .addDef(ResVReg)
+        .addUse(GR.getSPIRVTypeID(ResType))
+        .addUse(I.getOperand(2).getReg())
+        .addUse(I.getOperand(3).getReg())
+        .addUse(GR.getOrCreateConstInt(2, I, IntTy, TII));
+  }
   case Intrinsic::spv_step:
     return selectStep(ResVReg, ResType, I);
   // Discard intrinsics which we do not expect to actually represent code after
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll
new file mode 100644
index 00000000000000..e02a2907ee28a0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll
@@ -0,0 +1,28 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Test lowering to spir-v backend
+
+; CHECK-DAG:   %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG:   %[[#scope:]] = OpConstant %[[#uint]] 2
+; CHECK-DAG:   %[[#f32:]] = OpTypeFloat 32
+; CHECK-DAG:   %[[#expr:]] = OpFunctionParameter %[[#f32]]
+; CHECK-DAG:   %[[#idx:]] = OpFunctionParameter %[[#uint]]
+
+define spir_func void @test_1(float %expr, i32 %idx) #0 {
+entry:
+  %0 = call token @llvm.experimental.convergence.entry()
+; CHECK:   %[[#ret:]] = OpGroupNonUniformShuffle %[[#f32]] %[[#expr]] %[[#idx]] %[[#scope]]
+  %1 = call float @llvm.spv.wave.read.lane.at(float %expr, i32 %idx) [ "convergencectrl"(token %0) ]
+  ret void
+}
+
+declare i32 @__hlsl_wave_get_lane_index() #1
+
+attributes #0 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}

    - create a clang built-in in Builtins.td
    - add semantic checking in SemaHLSL.cpp
    - link the WaveReadLaneAt api in hlsl_intrinsics.h
    - add lowering to spirv backend op GroupNonUniformShuffle
      with Scope = 2 (Group) in SPIRVInstructionSelector.cpp

    - add tests for HLSL intrinsic lowering to spirv intrinsic in
      WaveReadLaneAt.hlsl
    - add tests for sema checks in WaveReadLaneAt-errors.hlsl
    - add spir-v backend tests in WaveReadLaneAt.ll
    - add WaveReadLaneAt intrinsic to IntrinsicsDirectX.td and mapping
      to DXIL.td

    - add test to show scalar functionality

    - note that this doesn't include support for the scalarizer to
      handle this function will be added in a future pr
@inbelic inbelic force-pushed the inbelic/wave-read-at-spirv branch from 4522f35 to e8c5dba Compare October 3, 2024 20:59
Copy link
Contributor

@Keenuts Keenuts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this! Some minor changes, otherwise looks good 😊

@inbelic inbelic changed the title [HLSL] Implement WaveReadLaneAt intrinsic for spirv backend [HLSL] Implement WaveReadLaneAt intrinsic Oct 4, 2024
    - add check for "convergencectrl" token in hlsl -> spirv intrinsic
    - correct the execution scope of the spirv instruction, add
      description
    - fix typo
@inbelic inbelic force-pushed the inbelic/wave-read-at-spirv branch from 9d7aa1d to 1b0746f Compare October 4, 2024 15:36
    - extend spirv testcase to include i32 and vector of bools
Intrinsic::getName(CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
ArrayRef{OpExpr->getType()}, &CGM.getModule());
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, name, {}, false, true),
ArrayRef{OpExpr, OpIndex}, "hlsl.wave.read.lane.at");
Copy link
Member

@farzonl farzonl Oct 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The string we used for BI__builtin_hlsl_wave_get_lane_index was __hlsl_wave_get_lane_index. Why would we use periods here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think __hlsl_wave_get_lane_index is the odd one out. The other intrinsics follow the pattern of hlsl.name. Having changed to using one word waveReadLaneAt I think we can keep it consistent naming with hlsl.waveReadLaneAt.

I can change the name to hlsl.waveGetLaneIndex in the clean-up pr.

    - fix variable name to be uppercase
    - remove unneeded flags from SemaHLSL testcase
    - switch to wave_readlaneat to follow llvm naming conventions
Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One small nit, but otherwise looks good.

Copy link

github-actions bot commented Oct 10, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@bogner bogner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good once Chris's comments and the clang-format issue are resolved.

    - change intrinsic names to `wave_readlane` to align with AMD
      implementations
    - update semantic check to ensure that only a scalar/vector is
      allowed
    - add test case to illustrate this
@inbelic inbelic force-pushed the inbelic/wave-read-at-spirv branch from 466ae52 to 1d00f95 Compare October 11, 2024 22:44
@inbelic inbelic requested review from bogner and farzonl October 11, 2024 22:45
    - the execution scope must be the second operand and not the final
      operand of the SPIRV instruction. this was missed as I did not
      realize that SPIRV-TOOLS was not enabled on my local machine but
      luckily got caught by the spir-v github tests

    - missing testcase for int64 in directx lowering and caught that
      Int64Ty was missing from the overload types of dxilop
@inbelic inbelic force-pushed the inbelic/wave-read-at-spirv branch from 1d00f95 to 78f7e5d Compare October 11, 2024 22:53
    - add tests for int16, half and double to hlsl codegen
    - add test for vector of floats in spirv codegen to ensure vector
      register allocation
@inbelic inbelic merged commit 6d13cc9 into llvm:main Oct 16, 2024
9 checks passed
@llvm-ci
Copy link
Collaborator

llvm-ci commented Oct 16, 2024

LLVM Buildbot has detected a new failure on builder llvm-clang-x86_64-darwin running on doug-worker-3 while building clang,llvm at step 6 "test-build-unified-tree-check-all".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/23/builds/3882

Here is the relevant piece of the build log for the reference
Step 6 (test-build-unified-tree-check-all) failure: 1200 seconds without output running [b'ninja', b'check-all'], attempting to kill
...
UNSUPPORTED: mlgo-utils :: corpus/make_corpus_test.py (79030 of 79040)
UNSUPPORTED: mlgo-utils :: corpus/extract_ir_test.py (79031 of 79040)
UNSUPPORTED: mlgo-utils :: pytype.test (79032 of 79040)
PASS: lld :: MachO/x86-64-stubs.s (79033 of 79040)
PASS: lld :: MachO/zippered.yaml (79034 of 79040)
PASS: lld :: MinGW/lib.test (79035 of 79040)
PASS: lld :: MachO/weak-reference.s (79036 of 79040)
PASS: lld :: MachO/weak-definition-direct-fetch.s (79037 of 79040)
PASS: lld :: MinGW/driver.test (79038 of 79040)
PASS: lit :: shtest-define.py (79039 of 79040)
command timed out: 1200 seconds without output running [b'ninja', b'check-all'], attempting to kill
process killed by signal 9
program finished with exit code -1
elapsedTime=2553.266271

DanielCChen pushed a commit to DanielCChen/llvm-project that referenced this pull request Oct 16, 2024
- create a clang built-in in Builtins.td
    - add semantic checking in SemaHLSL.cpp
    - link the WaveReadLaneAt api in hlsl_intrinsics.h
    - add lowering to spirv backend op GroupNonUniformShuffle
      with Scope = 2 (Group) in SPIRVInstructionSelector.cpp
    - add WaveReadLaneAt intrinsic to IntrinsicsDirectX.td and mapping
      to DXIL.td

    - add tests for HLSL intrinsic lowering to spirv intrinsic in
      WaveReadLaneAt.hlsl
    - add tests for sema checks in WaveReadLaneAt-errors.hlsl
    - add spir-v backend tests in WaveReadLaneAt.ll
    - add test to show scalar dxil lowering functionality

    - note that this doesn't include support for the scalarizer to
      handle WaveReadLaneAt will be added in a future pr

This is the first part llvm#70104
@inbelic inbelic deleted the inbelic/wave-read-at-spirv branch October 21, 2024 19:45
Comment on lines +1 to +2
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I missed this in my review. The is the wrong triple to use. You should be using spirv not spirv32.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think anything else needs to be updated. I'll open a PR to fix this myself.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:DirectX backend:SPIR-V backend:X86 clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:ir
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

9 participants