-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[HLSL] Implement WaveReadLaneAt
intrinsic
#111010
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
WaveReadLaneAt
intrinsicsWaveReadLaneAt
intrinsic for spirv backend
18dd1fe
to
4522f35
Compare
@llvm/pr-subscribers-backend-directx @llvm/pr-subscribers-llvm-ir Author: Finn Plummer (inbelic) Changes
This is part 1 of 3 addressing TODO: add issue. Full diff: https://github.com/llvm/llvm-project/pull/111010.diff 10 Files Affected:
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 8090119e512fbb..eec9acd4d27d7d 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4703,6 +4703,12 @@ def HLSLWaveIsFirstLane : LangBuiltin<"HLSL_LANG"> {
let Prototype = "bool()";
}
+def HLSLWaveReadLaneAt : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_wave_read_lane_at"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_clamp"];
let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index da3eca73bfb575..dff56af9282e9d 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18835,6 +18835,22 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
}
+ case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
+ // Due to the use of variadic arguments we must explicitly retreive them and
+ // create our function type.
+ Value *OpExpr = EmitScalarExpr(E->getArg(0));
+ Value *OpIndex = EmitScalarExpr(E->getArg(1));
+ llvm::FunctionType *FT = llvm::FunctionType::get(
+ OpExpr->getType(), ArrayRef{OpExpr->getType(), OpIndex->getType()},
+ false);
+
+ // Get overloaded name
+ std::string name =
+ Intrinsic::getName(CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
+ ArrayRef{OpExpr->getType()}, &CGM.getModule());
+ return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, name, {}, false, true),
+ ArrayRef{OpExpr, OpIndex}, "hlsl.wave.read.lane.at");
+ }
case Builtin::BI__builtin_hlsl_elementwise_sign: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
llvm::Type *Xty = Op0->getType();
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index a8aabca7348ffb..a639ce2d784f4a 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -87,6 +87,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_read_lane_at)
//===----------------------------------------------------------------------===//
// End of reserved area for HLSL intrinsic getters.
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 810a16d75f0228..a7bdc353ae71bf 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -2015,6 +2015,13 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_is_first_lane)
__attribute__((convergent)) bool WaveIsFirstLane();
+// \brief Returns the value of the expression for the given lane index within
+// the specified wave.
+template <typename T>
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
+ __attribute__((convergent)) T WaveReadLaneAt(T, int32_t);
+
//===----------------------------------------------------------------------===//
// sign builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 43cc6c81ae5cb0..d54da3fd8375ed 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1956,6 +1956,26 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
+ if (SemaRef.checkArgCount(TheCall, 2))
+ return true;
+
+ // Ensure index parameter type can be interpreted as a uint
+ ExprResult Index = TheCall->getArg(1);
+ QualType ArgTyIndex = Index.get()->getType();
+ if (!ArgTyIndex->hasIntegerRepresentation()) {
+ SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
+ diag::err_typecheck_convert_incompatible)
+ << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
+ return true;
+ }
+
+ // Ensure return type is the same as the input expr type
+ ExprResult Expr = TheCall->getArg(0);
+ QualType ArgTyExpr = Expr.get()->getType();
+ TheCall->setType(ArgTyExpr);
+ break;
+ }
case Builtin::BI__builtin_elementwise_acos:
case Builtin::BI__builtin_elementwise_asin:
case Builtin::BI__builtin_elementwise_atan:
diff --git a/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl b/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl
new file mode 100644
index 00000000000000..62319ebc04e2db
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl
@@ -0,0 +1,40 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_int
+int test_int(int expr, uint idx) {
+ // CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
+
+ // CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
+ // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
+
+ // CHECK: ret [[TY]] %[[RET]]
+ return WaveReadLaneAt(expr, idx);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
+// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
+
+// Test basic lowering to runtime function call with array and float value.
+
+// CHECK-LABEL: test_floatv4
+float4 test_floatv4(float4 expr, uint idx) {
+ // CHECK-SPIRV: %[[#entry_tok1:]] = call token @llvm.experimental.convergence.entry()
+
+ // CHECK-SPIRV: %[[RET1:.*]] = call [[TY1:.*]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
+ // CHECK-DXIL: %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
+
+ // CHECK: ret [[TY1]] %[[RET1]]
+ return WaveReadLaneAt(expr, idx);
+}
+
+// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
+// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
+
+// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl
new file mode 100644
index 00000000000000..451f2d3a563287
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl
@@ -0,0 +1,21 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
+
+bool test_too_few_arg() {
+ return __builtin_hlsl_wave_read_lane_at();
+ // expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
+}
+
+float2 test_too_few_arg_1(float2 p0) {
+ return __builtin_hlsl_wave_read_lane_at(p0);
+ // expected-error@-1 {{too few arguments to function call, expected 2, have 1}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+ return __builtin_hlsl_wave_read_lane_at(p0, p0, p0);
+ // expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
+}
+
+float3 test_index_type_check(float3 p0, double idx) {
+ return __builtin_hlsl_wave_read_lane_at(p0, idx);
+ // expected-error@-1 {{passing 'double' to parameter of incompatible type 'unsigned int'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 7ff3d58690ba75..b6ea9ce9b1411e 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -82,5 +82,6 @@ let TargetPrefix = "spv" in {
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, Commutative] >;
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
+ def int_spv_wave_read_lane_at : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent]>;
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty]>;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7a565249a342d1..a7279193764fa4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2653,6 +2653,21 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
}
+ case Intrinsic::spv_wave_read_lane_at: {
+ assert(I.getNumOperands() == 4);
+ assert(I.getOperand(2).isReg());
+ assert(I.getOperand(3).isReg());
+
+ // Defines the execution scope currently 2 for group, see scope table
+ SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+ return BuildMI(BB, I, I.getDebugLoc(),
+ TII.get(SPIRV::OpGroupNonUniformShuffle))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(I.getOperand(2).getReg())
+ .addUse(I.getOperand(3).getReg())
+ .addUse(GR.getOrCreateConstInt(2, I, IntTy, TII));
+ }
case Intrinsic::spv_step:
return selectStep(ResVReg, ResType, I);
// Discard intrinsics which we do not expect to actually represent code after
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll
new file mode 100644
index 00000000000000..e02a2907ee28a0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll
@@ -0,0 +1,28 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Test lowering to spir-v backend
+
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 2
+; CHECK-DAG: %[[#f32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#expr:]] = OpFunctionParameter %[[#f32]]
+; CHECK-DAG: %[[#idx:]] = OpFunctionParameter %[[#uint]]
+
+define spir_func void @test_1(float %expr, i32 %idx) #0 {
+entry:
+ %0 = call token @llvm.experimental.convergence.entry()
+; CHECK: %[[#ret:]] = OpGroupNonUniformShuffle %[[#f32]] %[[#expr]] %[[#idx]] %[[#scope]]
+ %1 = call float @llvm.spv.wave.read.lane.at(float %expr, i32 %idx) [ "convergencectrl"(token %0) ]
+ ret void
+}
+
+declare i32 @__hlsl_wave_get_lane_index() #1
+
+attributes #0 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
|
@llvm/pr-subscribers-clang Author: Finn Plummer (inbelic) Changes
This is part 1 of 3 addressing TODO: add issue. Full diff: https://github.com/llvm/llvm-project/pull/111010.diff 10 Files Affected:
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 8090119e512fbb..eec9acd4d27d7d 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4703,6 +4703,12 @@ def HLSLWaveIsFirstLane : LangBuiltin<"HLSL_LANG"> {
let Prototype = "bool()";
}
+def HLSLWaveReadLaneAt : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_wave_read_lane_at"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_clamp"];
let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index da3eca73bfb575..dff56af9282e9d 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18835,6 +18835,22 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
}
+ case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
+ // Due to the use of variadic arguments we must explicitly retreive them and
+ // create our function type.
+ Value *OpExpr = EmitScalarExpr(E->getArg(0));
+ Value *OpIndex = EmitScalarExpr(E->getArg(1));
+ llvm::FunctionType *FT = llvm::FunctionType::get(
+ OpExpr->getType(), ArrayRef{OpExpr->getType(), OpIndex->getType()},
+ false);
+
+ // Get overloaded name
+ std::string name =
+ Intrinsic::getName(CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
+ ArrayRef{OpExpr->getType()}, &CGM.getModule());
+ return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, name, {}, false, true),
+ ArrayRef{OpExpr, OpIndex}, "hlsl.wave.read.lane.at");
+ }
case Builtin::BI__builtin_hlsl_elementwise_sign: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
llvm::Type *Xty = Op0->getType();
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index a8aabca7348ffb..a639ce2d784f4a 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -87,6 +87,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_read_lane_at)
//===----------------------------------------------------------------------===//
// End of reserved area for HLSL intrinsic getters.
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 810a16d75f0228..a7bdc353ae71bf 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -2015,6 +2015,13 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_is_first_lane)
__attribute__((convergent)) bool WaveIsFirstLane();
+// \brief Returns the value of the expression for the given lane index within
+// the specified wave.
+template <typename T>
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
+ __attribute__((convergent)) T WaveReadLaneAt(T, int32_t);
+
//===----------------------------------------------------------------------===//
// sign builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 43cc6c81ae5cb0..d54da3fd8375ed 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1956,6 +1956,26 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
+ if (SemaRef.checkArgCount(TheCall, 2))
+ return true;
+
+ // Ensure index parameter type can be interpreted as a uint
+ ExprResult Index = TheCall->getArg(1);
+ QualType ArgTyIndex = Index.get()->getType();
+ if (!ArgTyIndex->hasIntegerRepresentation()) {
+ SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
+ diag::err_typecheck_convert_incompatible)
+ << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
+ return true;
+ }
+
+ // Ensure return type is the same as the input expr type
+ ExprResult Expr = TheCall->getArg(0);
+ QualType ArgTyExpr = Expr.get()->getType();
+ TheCall->setType(ArgTyExpr);
+ break;
+ }
case Builtin::BI__builtin_elementwise_acos:
case Builtin::BI__builtin_elementwise_asin:
case Builtin::BI__builtin_elementwise_atan:
diff --git a/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl b/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl
new file mode 100644
index 00000000000000..62319ebc04e2db
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl
@@ -0,0 +1,40 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_int
+int test_int(int expr, uint idx) {
+ // CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
+
+ // CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
+ // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
+
+ // CHECK: ret [[TY]] %[[RET]]
+ return WaveReadLaneAt(expr, idx);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
+// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
+
+// Test basic lowering to runtime function call with array and float value.
+
+// CHECK-LABEL: test_floatv4
+float4 test_floatv4(float4 expr, uint idx) {
+ // CHECK-SPIRV: %[[#entry_tok1:]] = call token @llvm.experimental.convergence.entry()
+
+ // CHECK-SPIRV: %[[RET1:.*]] = call [[TY1:.*]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
+ // CHECK-DXIL: %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
+
+ // CHECK: ret [[TY1]] %[[RET1]]
+ return WaveReadLaneAt(expr, idx);
+}
+
+// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
+// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
+
+// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl
new file mode 100644
index 00000000000000..451f2d3a563287
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl
@@ -0,0 +1,21 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
+
+bool test_too_few_arg() {
+ return __builtin_hlsl_wave_read_lane_at();
+ // expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
+}
+
+float2 test_too_few_arg_1(float2 p0) {
+ return __builtin_hlsl_wave_read_lane_at(p0);
+ // expected-error@-1 {{too few arguments to function call, expected 2, have 1}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+ return __builtin_hlsl_wave_read_lane_at(p0, p0, p0);
+ // expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
+}
+
+float3 test_index_type_check(float3 p0, double idx) {
+ return __builtin_hlsl_wave_read_lane_at(p0, idx);
+ // expected-error@-1 {{passing 'double' to parameter of incompatible type 'unsigned int'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 7ff3d58690ba75..b6ea9ce9b1411e 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -82,5 +82,6 @@ let TargetPrefix = "spv" in {
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, Commutative] >;
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
+ def int_spv_wave_read_lane_at : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent]>;
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty]>;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7a565249a342d1..a7279193764fa4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2653,6 +2653,21 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
}
+ case Intrinsic::spv_wave_read_lane_at: {
+ assert(I.getNumOperands() == 4);
+ assert(I.getOperand(2).isReg());
+ assert(I.getOperand(3).isReg());
+
+ // Defines the execution scope currently 2 for group, see scope table
+ SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+ return BuildMI(BB, I, I.getDebugLoc(),
+ TII.get(SPIRV::OpGroupNonUniformShuffle))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(I.getOperand(2).getReg())
+ .addUse(I.getOperand(3).getReg())
+ .addUse(GR.getOrCreateConstInt(2, I, IntTy, TII));
+ }
case Intrinsic::spv_step:
return selectStep(ResVReg, ResType, I);
// Discard intrinsics which we do not expect to actually represent code after
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll
new file mode 100644
index 00000000000000..e02a2907ee28a0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll
@@ -0,0 +1,28 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Test lowering to spir-v backend
+
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 2
+; CHECK-DAG: %[[#f32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#expr:]] = OpFunctionParameter %[[#f32]]
+; CHECK-DAG: %[[#idx:]] = OpFunctionParameter %[[#uint]]
+
+define spir_func void @test_1(float %expr, i32 %idx) #0 {
+entry:
+ %0 = call token @llvm.experimental.convergence.entry()
+; CHECK: %[[#ret:]] = OpGroupNonUniformShuffle %[[#f32]] %[[#expr]] %[[#idx]] %[[#scope]]
+ %1 = call float @llvm.spv.wave.read.lane.at(float %expr, i32 %idx) [ "convergencectrl"(token %0) ]
+ ret void
+}
+
+declare i32 @__hlsl_wave_get_lane_index() #1
+
+attributes #0 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
|
@llvm/pr-subscribers-backend-x86 Author: Finn Plummer (inbelic) Changes
This is part 1 of 3 addressing TODO: add issue. Full diff: https://github.com/llvm/llvm-project/pull/111010.diff 10 Files Affected:
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 8090119e512fbb..eec9acd4d27d7d 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4703,6 +4703,12 @@ def HLSLWaveIsFirstLane : LangBuiltin<"HLSL_LANG"> {
let Prototype = "bool()";
}
+def HLSLWaveReadLaneAt : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_wave_read_lane_at"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_clamp"];
let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index da3eca73bfb575..dff56af9282e9d 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18835,6 +18835,22 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
}
+ case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
+ // Due to the use of variadic arguments we must explicitly retreive them and
+ // create our function type.
+ Value *OpExpr = EmitScalarExpr(E->getArg(0));
+ Value *OpIndex = EmitScalarExpr(E->getArg(1));
+ llvm::FunctionType *FT = llvm::FunctionType::get(
+ OpExpr->getType(), ArrayRef{OpExpr->getType(), OpIndex->getType()},
+ false);
+
+ // Get overloaded name
+ std::string name =
+ Intrinsic::getName(CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
+ ArrayRef{OpExpr->getType()}, &CGM.getModule());
+ return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, name, {}, false, true),
+ ArrayRef{OpExpr, OpIndex}, "hlsl.wave.read.lane.at");
+ }
case Builtin::BI__builtin_hlsl_elementwise_sign: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
llvm::Type *Xty = Op0->getType();
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index a8aabca7348ffb..a639ce2d784f4a 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -87,6 +87,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_read_lane_at)
//===----------------------------------------------------------------------===//
// End of reserved area for HLSL intrinsic getters.
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 810a16d75f0228..a7bdc353ae71bf 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -2015,6 +2015,13 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_is_first_lane)
__attribute__((convergent)) bool WaveIsFirstLane();
+// \brief Returns the value of the expression for the given lane index within
+// the specified wave.
+template <typename T>
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
+ __attribute__((convergent)) T WaveReadLaneAt(T, int32_t);
+
//===----------------------------------------------------------------------===//
// sign builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 43cc6c81ae5cb0..d54da3fd8375ed 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1956,6 +1956,26 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
+ if (SemaRef.checkArgCount(TheCall, 2))
+ return true;
+
+ // Ensure index parameter type can be interpreted as a uint
+ ExprResult Index = TheCall->getArg(1);
+ QualType ArgTyIndex = Index.get()->getType();
+ if (!ArgTyIndex->hasIntegerRepresentation()) {
+ SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
+ diag::err_typecheck_convert_incompatible)
+ << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
+ return true;
+ }
+
+ // Ensure return type is the same as the input expr type
+ ExprResult Expr = TheCall->getArg(0);
+ QualType ArgTyExpr = Expr.get()->getType();
+ TheCall->setType(ArgTyExpr);
+ break;
+ }
case Builtin::BI__builtin_elementwise_acos:
case Builtin::BI__builtin_elementwise_asin:
case Builtin::BI__builtin_elementwise_atan:
diff --git a/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl b/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl
new file mode 100644
index 00000000000000..62319ebc04e2db
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl
@@ -0,0 +1,40 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_int
+int test_int(int expr, uint idx) {
+ // CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
+
+ // CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
+ // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
+
+ // CHECK: ret [[TY]] %[[RET]]
+ return WaveReadLaneAt(expr, idx);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
+// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
+
+// Test basic lowering to runtime function call with array and float value.
+
+// CHECK-LABEL: test_floatv4
+float4 test_floatv4(float4 expr, uint idx) {
+ // CHECK-SPIRV: %[[#entry_tok1:]] = call token @llvm.experimental.convergence.entry()
+
+ // CHECK-SPIRV: %[[RET1:.*]] = call [[TY1:.*]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
+ // CHECK-DXIL: %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
+
+ // CHECK: ret [[TY1]] %[[RET1]]
+ return WaveReadLaneAt(expr, idx);
+}
+
+// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
+// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
+
+// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl
new file mode 100644
index 00000000000000..451f2d3a563287
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl
@@ -0,0 +1,21 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
+
+bool test_too_few_arg() {
+ return __builtin_hlsl_wave_read_lane_at();
+ // expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
+}
+
+float2 test_too_few_arg_1(float2 p0) {
+ return __builtin_hlsl_wave_read_lane_at(p0);
+ // expected-error@-1 {{too few arguments to function call, expected 2, have 1}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+ return __builtin_hlsl_wave_read_lane_at(p0, p0, p0);
+ // expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
+}
+
+float3 test_index_type_check(float3 p0, double idx) {
+ return __builtin_hlsl_wave_read_lane_at(p0, idx);
+ // expected-error@-1 {{passing 'double' to parameter of incompatible type 'unsigned int'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 7ff3d58690ba75..b6ea9ce9b1411e 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -82,5 +82,6 @@ let TargetPrefix = "spv" in {
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, Commutative] >;
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
+ def int_spv_wave_read_lane_at : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent]>;
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty]>;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7a565249a342d1..a7279193764fa4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2653,6 +2653,21 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
}
+ case Intrinsic::spv_wave_read_lane_at: {
+ assert(I.getNumOperands() == 4);
+ assert(I.getOperand(2).isReg());
+ assert(I.getOperand(3).isReg());
+
+ // Defines the execution scope currently 2 for group, see scope table
+ SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+ return BuildMI(BB, I, I.getDebugLoc(),
+ TII.get(SPIRV::OpGroupNonUniformShuffle))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(I.getOperand(2).getReg())
+ .addUse(I.getOperand(3).getReg())
+ .addUse(GR.getOrCreateConstInt(2, I, IntTy, TII));
+ }
case Intrinsic::spv_step:
return selectStep(ResVReg, ResType, I);
// Discard intrinsics which we do not expect to actually represent code after
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll
new file mode 100644
index 00000000000000..e02a2907ee28a0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll
@@ -0,0 +1,28 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Test lowering to spir-v backend
+
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 2
+; CHECK-DAG: %[[#f32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#expr:]] = OpFunctionParameter %[[#f32]]
+; CHECK-DAG: %[[#idx:]] = OpFunctionParameter %[[#uint]]
+
+define spir_func void @test_1(float %expr, i32 %idx) #0 {
+entry:
+ %0 = call token @llvm.experimental.convergence.entry()
+; CHECK: %[[#ret:]] = OpGroupNonUniformShuffle %[[#f32]] %[[#expr]] %[[#idx]] %[[#scope]]
+ %1 = call float @llvm.spv.wave.read.lane.at(float %expr, i32 %idx) [ "convergencectrl"(token %0) ]
+ ret void
+}
+
+declare i32 @__hlsl_wave_get_lane_index() #1
+
+attributes #0 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
|
@llvm/pr-subscribers-clang-codegen Author: Finn Plummer (inbelic) Changes
This is part 1 of 3 addressing TODO: add issue. Full diff: https://github.com/llvm/llvm-project/pull/111010.diff 10 Files Affected:
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 8090119e512fbb..eec9acd4d27d7d 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4703,6 +4703,12 @@ def HLSLWaveIsFirstLane : LangBuiltin<"HLSL_LANG"> {
let Prototype = "bool()";
}
+def HLSLWaveReadLaneAt : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_wave_read_lane_at"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_clamp"];
let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index da3eca73bfb575..dff56af9282e9d 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18835,6 +18835,22 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
}
+ case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
+ // Due to the use of variadic arguments we must explicitly retreive them and
+ // create our function type.
+ Value *OpExpr = EmitScalarExpr(E->getArg(0));
+ Value *OpIndex = EmitScalarExpr(E->getArg(1));
+ llvm::FunctionType *FT = llvm::FunctionType::get(
+ OpExpr->getType(), ArrayRef{OpExpr->getType(), OpIndex->getType()},
+ false);
+
+ // Get overloaded name
+ std::string name =
+ Intrinsic::getName(CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
+ ArrayRef{OpExpr->getType()}, &CGM.getModule());
+ return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, name, {}, false, true),
+ ArrayRef{OpExpr, OpIndex}, "hlsl.wave.read.lane.at");
+ }
case Builtin::BI__builtin_hlsl_elementwise_sign: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
llvm::Type *Xty = Op0->getType();
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index a8aabca7348ffb..a639ce2d784f4a 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -87,6 +87,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_read_lane_at)
//===----------------------------------------------------------------------===//
// End of reserved area for HLSL intrinsic getters.
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 810a16d75f0228..a7bdc353ae71bf 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -2015,6 +2015,13 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_is_first_lane)
__attribute__((convergent)) bool WaveIsFirstLane();
+// \brief Returns the value of the expression for the given lane index within
+// the specified wave.
+template <typename T>
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
+ __attribute__((convergent)) T WaveReadLaneAt(T, int32_t);
+
//===----------------------------------------------------------------------===//
// sign builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 43cc6c81ae5cb0..d54da3fd8375ed 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1956,6 +1956,26 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
+ if (SemaRef.checkArgCount(TheCall, 2))
+ return true;
+
+ // Ensure index parameter type can be interpreted as a uint
+ ExprResult Index = TheCall->getArg(1);
+ QualType ArgTyIndex = Index.get()->getType();
+ if (!ArgTyIndex->hasIntegerRepresentation()) {
+ SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
+ diag::err_typecheck_convert_incompatible)
+ << ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
+ return true;
+ }
+
+ // Ensure return type is the same as the input expr type
+ ExprResult Expr = TheCall->getArg(0);
+ QualType ArgTyExpr = Expr.get()->getType();
+ TheCall->setType(ArgTyExpr);
+ break;
+ }
case Builtin::BI__builtin_elementwise_acos:
case Builtin::BI__builtin_elementwise_asin:
case Builtin::BI__builtin_elementwise_atan:
diff --git a/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl b/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl
new file mode 100644
index 00000000000000..62319ebc04e2db
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl
@@ -0,0 +1,40 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_int
+int test_int(int expr, uint idx) {
+ // CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
+
+ // CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
+ // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
+
+ // CHECK: ret [[TY]] %[[RET]]
+ return WaveReadLaneAt(expr, idx);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
+// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
+
+// Test basic lowering to runtime function call with array and float value.
+
+// CHECK-LABEL: test_floatv4
+float4 test_floatv4(float4 expr, uint idx) {
+ // CHECK-SPIRV: %[[#entry_tok1:]] = call token @llvm.experimental.convergence.entry()
+
+ // CHECK-SPIRV: %[[RET1:.*]] = call [[TY1:.*]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
+ // CHECK-DXIL: %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
+
+ // CHECK: ret [[TY1]] %[[RET1]]
+ return WaveReadLaneAt(expr, idx);
+}
+
+// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
+// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
+
+// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl
new file mode 100644
index 00000000000000..451f2d3a563287
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl
@@ -0,0 +1,21 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
+
+bool test_too_few_arg() {
+ return __builtin_hlsl_wave_read_lane_at();
+ // expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
+}
+
+float2 test_too_few_arg_1(float2 p0) {
+ return __builtin_hlsl_wave_read_lane_at(p0);
+ // expected-error@-1 {{too few arguments to function call, expected 2, have 1}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+ return __builtin_hlsl_wave_read_lane_at(p0, p0, p0);
+ // expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
+}
+
+float3 test_index_type_check(float3 p0, double idx) {
+ return __builtin_hlsl_wave_read_lane_at(p0, idx);
+ // expected-error@-1 {{passing 'double' to parameter of incompatible type 'unsigned int'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 7ff3d58690ba75..b6ea9ce9b1411e 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -82,5 +82,6 @@ let TargetPrefix = "spv" in {
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, Commutative] >;
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
+ def int_spv_wave_read_lane_at : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent]>;
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty]>;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7a565249a342d1..a7279193764fa4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2653,6 +2653,21 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
}
+ case Intrinsic::spv_wave_read_lane_at: {
+ assert(I.getNumOperands() == 4);
+ assert(I.getOperand(2).isReg());
+ assert(I.getOperand(3).isReg());
+
+ // Defines the execution scope currently 2 for group, see scope table
+ SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+ return BuildMI(BB, I, I.getDebugLoc(),
+ TII.get(SPIRV::OpGroupNonUniformShuffle))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(I.getOperand(2).getReg())
+ .addUse(I.getOperand(3).getReg())
+ .addUse(GR.getOrCreateConstInt(2, I, IntTy, TII));
+ }
case Intrinsic::spv_step:
return selectStep(ResVReg, ResType, I);
// Discard intrinsics which we do not expect to actually represent code after
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll
new file mode 100644
index 00000000000000..e02a2907ee28a0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll
@@ -0,0 +1,28 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Test lowering to spir-v backend
+
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 2
+; CHECK-DAG: %[[#f32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#expr:]] = OpFunctionParameter %[[#f32]]
+; CHECK-DAG: %[[#idx:]] = OpFunctionParameter %[[#uint]]
+
+define spir_func void @test_1(float %expr, i32 %idx) #0 {
+entry:
+ %0 = call token @llvm.experimental.convergence.entry()
+; CHECK: %[[#ret:]] = OpGroupNonUniformShuffle %[[#f32]] %[[#expr]] %[[#idx]] %[[#scope]]
+ %1 = call float @llvm.spv.wave.read.lane.at(float %expr, i32 %idx) [ "convergencectrl"(token %0) ]
+ ret void
+}
+
+declare i32 @__hlsl_wave_get_lane_index() #1
+
+attributes #0 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
|
- create a clang built-in in Builtins.td - add semantic checking in SemaHLSL.cpp - link the WaveReadLaneAt api in hlsl_intrinsics.h - add lowering to spirv backend op GroupNonUniformShuffle with Scope = 2 (Group) in SPIRVInstructionSelector.cpp - add tests for HLSL intrinsic lowering to spirv intrinsic in WaveReadLaneAt.hlsl - add tests for sema checks in WaveReadLaneAt-errors.hlsl - add spir-v backend tests in WaveReadLaneAt.ll
- add WaveReadLaneAt intrinsic to IntrinsicsDirectX.td and mapping to DXIL.td - add test to show scalar functionality - note that this doesn't include support for the scalarizer to handle this function will be added in a future pr
4522f35
to
e8c5dba
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this! Some minor changes, otherwise looks good 😊
WaveReadLaneAt
intrinsic for spirv backendWaveReadLaneAt
intrinsic
- add check for "convergencectrl" token in hlsl -> spirv intrinsic - correct the execution scope of the spirv instruction, add description - fix typo
9d7aa1d
to
1b0746f
Compare
- extend spirv testcase to include i32 and vector of bools
clang/lib/CodeGen/CGBuiltin.cpp
Outdated
Intrinsic::getName(CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(), | ||
ArrayRef{OpExpr->getType()}, &CGM.getModule()); | ||
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, name, {}, false, true), | ||
ArrayRef{OpExpr, OpIndex}, "hlsl.wave.read.lane.at"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The string we used for BI__builtin_hlsl_wave_get_lane_index
was __hlsl_wave_get_lane_index
. Why would we use periods here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think __hlsl_wave_get_lane_index
is the odd one out. The other intrinsics follow the pattern of hlsl.name
. Having changed to using one word waveReadLaneAt
I think we can keep it consistent naming with hlsl.waveReadLaneAt
.
I can change the name to hlsl.waveGetLaneIndex
in the clean-up pr.
- fix variable name to be uppercase - remove unneeded flags from SemaHLSL testcase
- switch to wave_readlaneat to follow llvm naming conventions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One small nit, but otherwise looks good.
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good once Chris's comments and the clang-format issue are resolved.
- change intrinsic names to `wave_readlane` to align with AMD implementations - update semantic check to ensure that only a scalar/vector is allowed - add test case to illustrate this
466ae52
to
1d00f95
Compare
- the execution scope must be the second operand and not the final operand of the SPIRV instruction. this was missed as I did not realize that SPIRV-TOOLS was not enabled on my local machine but luckily got caught by the spir-v github tests - missing testcase for int64 in directx lowering and caught that Int64Ty was missing from the overload types of dxilop
1d00f95
to
78f7e5d
Compare
- add tests for int16, half and double to hlsl codegen - add test for vector of floats in spirv codegen to ensure vector register allocation
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/23/builds/3882 Here is the relevant piece of the build log for the reference
|
- create a clang built-in in Builtins.td - add semantic checking in SemaHLSL.cpp - link the WaveReadLaneAt api in hlsl_intrinsics.h - add lowering to spirv backend op GroupNonUniformShuffle with Scope = 2 (Group) in SPIRVInstructionSelector.cpp - add WaveReadLaneAt intrinsic to IntrinsicsDirectX.td and mapping to DXIL.td - add tests for HLSL intrinsic lowering to spirv intrinsic in WaveReadLaneAt.hlsl - add tests for sema checks in WaveReadLaneAt-errors.hlsl - add spir-v backend tests in WaveReadLaneAt.ll - add test to show scalar dxil lowering functionality - note that this doesn't include support for the scalarizer to handle WaveReadLaneAt will be added in a future pr This is the first part llvm#70104
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s | ||
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - -filetype=obj | spirv-val %} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I missed this in my review. The is the wrong triple to use. You should be using spirv
not spirv32
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think anything else needs to be updated. I'll open a PR to fix this myself.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the first part #70104