Skip to content

Commit 358ea82

Browse files
committed
[HLSL] Implement WaveReadLaneAt intrinsic
- create a clang built-in in Builtins.td - add semantic checking in SemaHLSL.cpp - link the WaveReadLaneAt API in hlsl_intrinsics.h - add lowering to the SPIR-V backend op GroupNonUniformShuffle with Scope = 2 (Group) in SPIRVInstructionSelector.cpp - add tests for HLSL intrinsic lowering to the SPIR-V intrinsic in WaveReadLaneAt.hlsl - add tests for sema checks in WaveReadLaneAt-errors.hlsl - add SPIR-V backend tests in WaveReadLaneAt.ll
1 parent f3c408d commit 358ea82

File tree

10 files changed

+156
-0
lines changed

10 files changed

+156
-0
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4703,6 +4703,12 @@ def HLSLWaveIsFirstLane : LangBuiltin<"HLSL_LANG"> {
47034703
let Prototype = "bool()";
47044704
}
47054705

4706+
def HLSLWaveReadLaneAt : LangBuiltin<"HLSL_LANG"> {
4707+
let Spellings = ["__builtin_hlsl_wave_read_lane_at"];
4708+
let Attributes = [NoThrow, Const];
4709+
let Prototype = "void(...)";
4710+
}
4711+
47064712
def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
47074713
let Spellings = ["__builtin_hlsl_elementwise_clamp"];
47084714
let Attributes = [NoThrow, Const];

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18835,6 +18835,23 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1883518835
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
1883618836
return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
1883718837
}
18838+
case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
18839+
// Due to the use of variadic arguments we must explicitly retrieve them and
18840+
// create our function type.
18841+
Value *OpExpr = EmitScalarExpr(E->getArg(0));
18842+
Value *OpIndex = EmitScalarExpr(E->getArg(1));
18843+
llvm::FunctionType *FT = llvm::FunctionType::get(
18844+
OpExpr->getType(), ArrayRef{OpExpr->getType(), OpIndex->getType()},
18845+
false);
18846+
18847+
// Get overloaded name
18848+
std::string name =
18849+
Intrinsic::getName(CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
18850+
ArrayRef{OpExpr->getType()}, &CGM.getModule());
18851+
return EmitRuntimeCall(
18852+
CGM.CreateRuntimeFunction(FT, name, {}, false, true),
18853+
ArrayRef{OpExpr, OpIndex}, "hlsl.wave.read.lane.at");
18854+
}
1883818855
case Builtin::BI__builtin_hlsl_elementwise_sign: {
1883918856
Value *Op0 = EmitScalarExpr(E->getArg(0));
1884018857
llvm::Type *Xty = Op0->getType();

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class CGHLSLRuntime {
8787
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
8888
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
8989
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
90+
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_read_lane_at)
9091

9192
//===----------------------------------------------------------------------===//
9293
// End of reserved area for HLSL intrinsic getters.

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2015,6 +2015,13 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
20152015
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_is_first_lane)
20162016
__attribute__((convergent)) bool WaveIsFirstLane();
20172017

2018+
// \brief Returns the value of the expression for the given lane index within
2019+
// the specified wave.
2020+
template <typename T>
2021+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2022+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2023+
__attribute__((convergent)) T WaveReadLaneAt(T, int32_t);
2024+
20182025
//===----------------------------------------------------------------------===//
20192026
// sign builtins
20202027
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1956,6 +1956,26 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
19561956
return true;
19571957
break;
19581958
}
1959+
case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
1960+
if (SemaRef.checkArgCount(TheCall, 2))
1961+
return true;
1962+
1963+
// Ensure index parameter type can be interpreted as a uint
1964+
ExprResult Index = TheCall->getArg(1);
1965+
QualType ArgTyIndex = Index.get()->getType();
1966+
if (!ArgTyIndex->hasIntegerRepresentation()) {
1967+
SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
1968+
diag::err_typecheck_convert_incompatible)
1969+
<< ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
1970+
return true;
1971+
}
1972+
1973+
// Ensure return type is the same as the input expr type
1974+
ExprResult Expr = TheCall->getArg(0);
1975+
QualType ArgTyExpr = Expr.get()->getType();
1976+
TheCall->setType(ArgTyExpr);
1977+
break;
1978+
}
19591979
case Builtin::BI__builtin_elementwise_acos:
19601980
case Builtin::BI__builtin_elementwise_asin:
19611981
case Builtin::BI__builtin_elementwise_atan:
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
2+
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
3+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
4+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
5+
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
6+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
7+
8+
// Test basic lowering to runtime function call.
9+
10+
// CHECK-LABEL: test_int
11+
int test_int(int expr, uint idx) {
12+
// CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
13+
14+
// CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
15+
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
16+
17+
// CHECK: ret [[TY]] %[[RET]]
18+
return WaveReadLaneAt(expr, idx);
19+
}
20+
21+
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
22+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
23+
24+
// Test basic lowering to runtime function call with a float vector value.
25+
26+
// CHECK-LABEL: test_floatv4
27+
float4 test_floatv4(float4 expr, uint idx) {
28+
// CHECK-SPIRV: %[[#entry_tok1:]] = call token @llvm.experimental.convergence.entry()
29+
30+
// CHECK-SPIRV: %[[RET1:.*]] = call [[TY1:.*]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
31+
// CHECK-DXIL: %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
32+
33+
// CHECK: ret [[TY1]] %[[RET1]]
34+
return WaveReadLaneAt(expr, idx);
35+
}
36+
37+
// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
38+
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
39+
40+
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
2+
3+
bool test_too_few_arg() {
4+
return __builtin_hlsl_wave_read_lane_at();
5+
// expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
6+
}
7+
8+
float2 test_too_few_arg_1(float2 p0) {
9+
return __builtin_hlsl_wave_read_lane_at(p0);
10+
// expected-error@-1 {{too few arguments to function call, expected 2, have 1}}
11+
}
12+
13+
float2 test_too_many_arg(float2 p0) {
14+
return __builtin_hlsl_wave_read_lane_at(p0, p0, p0);
15+
// expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
16+
}
17+
18+
float3 test_index_type_check(float3 p0, double idx) {
19+
return __builtin_hlsl_wave_read_lane_at(p0, idx);
20+
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'unsigned int'}}
21+
}

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,5 +82,6 @@ let TargetPrefix = "spv" in {
8282
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
8383
[IntrNoMem, Commutative] >;
8484
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
85+
def int_spv_wave_read_lane_at : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent]>;
8586
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty]>;
8687
}

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2653,6 +2653,21 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
26532653
.addUse(GR.getSPIRVTypeID(ResType))
26542654
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
26552655
}
2656+
case Intrinsic::spv_wave_read_lane_at: {
2657+
assert(I.getNumOperands() == 4);
2658+
assert(I.getOperand(2).isReg());
2659+
assert(I.getOperand(3).isReg());
2660+
2661+
// Defines the execution scope, currently 2 (group); see the SPIR-V scope table.
2662+
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
2663+
return BuildMI(BB, I, I.getDebugLoc(),
2664+
TII.get(SPIRV::OpGroupNonUniformShuffle))
2665+
.addDef(ResVReg)
2666+
.addUse(GR.getSPIRVTypeID(ResType))
2667+
.addUse(I.getOperand(2).getReg())
2668+
.addUse(I.getOperand(3).getReg())
2669+
.addUse(GR.getOrCreateConstInt(2, I, IntTy, TII));
2670+
}
26562671
case Intrinsic::spv_step:
26572672
return selectStep(ResVReg, ResType, I);
26582673
// Discard intrinsics which we do not expect to actually represent code after
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; Test lowering to spir-v backend
5+
6+
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
7+
; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 2
8+
; CHECK-DAG: %[[#f32:]] = OpTypeFloat 32
9+
; CHECK-DAG: %[[#expr:]] = OpFunctionParameter %[[#f32]]
10+
; CHECK-DAG: %[[#idx:]] = OpFunctionParameter %[[#uint]]
11+
12+
define spir_func void @test_1(float %expr, i32 %idx) #0 {
13+
entry:
14+
%0 = call token @llvm.experimental.convergence.entry()
15+
; CHECK: %[[#ret:]] = OpGroupNonUniformShuffle %[[#f32]] %[[#expr]] %[[#idx]] %[[#scope]]
16+
%1 = call float @llvm.spv.wave.read.lane.at(float %expr, i32 %idx) [ "convergencectrl"(token %0) ]
17+
ret void
18+
}
19+
20+
declare i32 @__hlsl_wave_get_lane_index() #1
21+
22+
attributes #0 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
23+
attributes #1 = { convergent }
24+
25+
!llvm.module.flags = !{!0, !1}
26+
27+
!0 = !{i32 1, !"wchar_size", i32 4}
28+
!1 = !{i32 4, !"dx.disable_optimizations", i32 1}

0 commit comments

Comments
 (0)