Skip to content

Commit 0320a5a

Browse files
committed
[HLSL] Implement WaveReadLaneAt intrinsics
- create a clang built-in - add mapping to dxil opcode - add lowering to SPIR-V GroupNonUniformShuffle with Scope = 2 (Group) - add sema checks - add related tests
1 parent f3c408d commit 0320a5a

File tree

16 files changed

+308
-7
lines changed

16 files changed

+308
-7
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4703,6 +4703,12 @@ def HLSLWaveIsFirstLane : LangBuiltin<"HLSL_LANG"> {
47034703
let Prototype = "bool()";
47044704
}
47054705

4706+
def HLSLWaveReadLaneAt : LangBuiltin<"HLSL_LANG"> {
4707+
let Spellings = ["__builtin_hlsl_wave_read_lane_at"];
4708+
let Attributes = [NoThrow, Const];
4709+
let Prototype = "void(...)";
4710+
}
4711+
47064712
def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
47074713
let Spellings = ["__builtin_hlsl_elementwise_clamp"];
47084714
let Attributes = [NoThrow, Const];

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18835,6 +18835,23 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1883518835
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
1883618836
return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
1883718837
}
18838+
case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
18839+
// Due to the use of variadic arguments we must explicitly retreive them and
18840+
// create our function type.
18841+
Value *OpExpr = EmitScalarExpr(E->getArg(0));
18842+
Value *OpIndex = EmitScalarExpr(E->getArg(1));
18843+
llvm::FunctionType *FT = llvm::FunctionType::get(
18844+
OpExpr->getType(), ArrayRef{OpExpr->getType(), OpIndex->getType()},
18845+
false);
18846+
18847+
// Get overloaded name
18848+
std::string name =
18849+
Intrinsic::getName(CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
18850+
ArrayRef{OpExpr->getType()}, &CGM.getModule());
18851+
18852+
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, name, {}, false, true),
18853+
ArrayRef{OpExpr, OpIndex}, "hlsl.wave.read.lane.at");
18854+
}
1883818855
case Builtin::BI__builtin_hlsl_elementwise_sign: {
1883918856
Value *Op0 = EmitScalarExpr(E->getArg(0));
1884018857
llvm::Type *Xty = Op0->getType();

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class CGHLSLRuntime {
8787
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
8888
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
8989
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
90+
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_read_lane_at)
9091

9192
//===----------------------------------------------------------------------===//
9293
// End of reserved area for HLSL intrinsic getters.

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2015,6 +2015,13 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
20152015
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_is_first_lane)
20162016
__attribute__((convergent)) bool WaveIsFirstLane();
20172017

2018+
// \brief Returns the value of the expression for the given lane index within
2019+
// the specified wave.
2020+
template <typename T>
2021+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2022+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2023+
__attribute__((convergent)) T WaveReadLaneAt(T, int32_t);
2024+
20182025
//===----------------------------------------------------------------------===//
20192026
// sign builtins
20202027
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1956,6 +1956,25 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
19561956
return true;
19571957
break;
19581958
}
1959+
case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
1960+
if (SemaRef.checkArgCount(TheCall, 2))
1961+
return true;
1962+
1963+
ExprResult Index = TheCall->getArg(1);
1964+
QualType ArgTyIndex = Index.get()->getType();
1965+
if (!ArgTyIndex->hasIntegerRepresentation()) {
1966+
SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
1967+
diag::err_typecheck_convert_incompatible)
1968+
<< ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
1969+
return true;
1970+
}
1971+
1972+
// Ensure return type is the same as the input expr type
1973+
ExprResult Expr = TheCall->getArg(0);
1974+
QualType ArgTyExpr = Expr.get()->getType();
1975+
TheCall->setType(ArgTyExpr);
1976+
break;
1977+
}
19591978
case Builtin::BI__builtin_elementwise_acos:
19601979
case Builtin::BI__builtin_elementwise_asin:
19611980
case Builtin::BI__builtin_elementwise_atan:
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
2+
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
3+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
4+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
5+
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
6+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
7+
8+
// Test basic lowering to runtime function call.
9+
10+
// CHECK-LABEL: test_int
11+
int test_int(int expr, uint idx) {
12+
// CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
13+
14+
// CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
15+
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
16+
17+
// CHECK: ret [[TY]] %[[RET]]
18+
return WaveReadLaneAt(expr, idx);
19+
}
20+
21+
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
22+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
23+
24+
// Test basic lowering to runtime function call with array and float value.
25+
26+
// CHECK-LABEL: test_floatv4
27+
float4 test_floatv4(float4 expr, uint idx) {
28+
// CHECK-SPIRV: %[[#entry_tok1:]] = call token @llvm.experimental.convergence.entry()
29+
30+
// CHECK-SPIRV: %[[RET1:.*]] = call [[TY1:.*]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
31+
// CHECK-DXIL: %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
32+
33+
// CHECK: ret [[TY1]] %[[RET1]]
34+
return WaveReadLaneAt(expr, idx);
35+
}
36+
37+
// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
38+
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
39+
40+
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
2+
3+
bool test_too_few_arg() {
4+
return __builtin_hlsl_wave_read_lane_at();
5+
// expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
6+
}
7+
8+
float2 test_too_few_arg_1(float2 p0) {
9+
return __builtin_hlsl_wave_read_lane_at(p0);
10+
// expected-error@-1 {{too few arguments to function call, expected 2, have 1}}
11+
}
12+
13+
float2 test_too_many_arg(float2 p0) {
14+
return __builtin_hlsl_wave_read_lane_at(p0, p0, p0);
15+
// expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
16+
}
17+
18+
float3 test_index_type_check(float3 p0, double idx) {
19+
return __builtin_hlsl_wave_read_lane_at(p0, idx);
20+
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'unsigned int'}}
21+
}

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
8383
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
8484
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
8585
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
86+
def int_dx_wave_read_lane_at : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent]>;
8687
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
8788
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
8889
}

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,5 +82,6 @@ let TargetPrefix = "spv" in {
8282
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
8383
[IntrNoMem, Commutative] >;
8484
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
85+
def int_spv_wave_read_lane_at : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent]>;
8586
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty]>;
8687
}

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,9 @@ class DXILOp<int opcode, DXILOpClass opclass> {
316316
// List of valid overload types predicated by DXIL version
317317
list<Overloads> overloads = [];
318318

319+
// Denote if overloads also permit vector equivalents.
320+
bit AllowVectorOverloads = 0;
321+
319322
// List of valid shader stages predicated by DXIL version
320323
list<Stages> stages;
321324

@@ -801,3 +804,14 @@ def WaveIsFirstLane : DXILOp<110, waveIsFirstLane> {
801804
let stages = [Stages<DXIL1_0, [all_stages]>];
802805
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
803806
}
807+
808+
def WaveReadLaneAt : DXILOp<117, waveReadLaneAt> {
809+
let Doc = "returns the value from the specified lane";
810+
let LLVMIntrinsic = int_dx_wave_read_lane_at;
811+
let arguments = [OverloadTy, Int32Ty];
812+
let result = OverloadTy;
813+
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy, Int1Ty, Int16Ty, Int32Ty]>];
814+
let AllowVectorOverloads = 1;
815+
let stages = [Stages<DXIL1_0, [all_stages]>];
816+
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
817+
}

llvm/lib/Target/DirectX/DXILOpBuilder.cpp

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ static const char *getOverloadTypeName(OverloadKind Kind) {
8585
llvm_unreachable("invalid overload type for name");
8686
}
8787

88-
static OverloadKind getOverloadKind(Type *Ty) {
88+
static OverloadKind getOverloadKind(Type *Ty,
89+
bool AllowVectorOverloads = false) {
8990
if (!Ty)
9091
return OverloadKind::VOID;
9192

@@ -126,6 +127,12 @@ static OverloadKind getOverloadKind(Type *Ty) {
126127
StructType *ST = cast<StructType>(Ty);
127128
return getOverloadKind(ST->getElementType(0));
128129
}
130+
case Type::FixedVectorTyID: {
131+
if (!AllowVectorOverloads)
132+
return OverloadKind::UNDEFINED;
133+
FixedVectorType *VT = cast<FixedVectorType>(Ty);
134+
return getOverloadKind(VT->getElementType());
135+
}
129136
default:
130137
return OverloadKind::UNDEFINED;
131138
}
@@ -157,6 +164,7 @@ struct OpCodeProperty {
157164
// Offset in DXILOpCodeClassNameTable.
158165
unsigned OpCodeClassNameOffset;
159166
llvm::SmallVector<OpOverload> Overloads;
167+
bool AllowVectorOverloads;
160168
llvm::SmallVector<OpStage> Stages;
161169
llvm::SmallVector<OpAttribute> Attributes;
162170
int OverloadParamIndex; // parameter index which control the overload.
@@ -169,13 +177,25 @@ struct OpCodeProperty {
169177
#include "DXILOperation.inc"
170178
#undef DXIL_OP_OPERATION_TABLE
171179

180+
static Twine getTypePrefix(Type *Ty) {
181+
Type::TypeID T = Ty->getTypeID();
182+
switch (T) {
183+
case Type::FixedVectorTyID: {
184+
FixedVectorType *VT = cast<FixedVectorType>(Ty);
185+
return "v" + Twine(std::to_string(VT->getNumElements()));
186+
}
187+
default:
188+
return "";
189+
}
190+
}
191+
172192
static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
173193
const OpCodeProperty &Prop) {
174194
if (Kind == OverloadKind::VOID) {
175195
return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
176196
}
177197
return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
178-
getTypeName(Kind, Ty))
198+
getTypePrefix(Ty) + getTypeName(Kind, Ty))
179199
.str();
180200
}
181201

@@ -414,13 +434,15 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
414434

415435
uint16_t ValidTyMask = Prop->Overloads[*OlIndexOrErr].ValidTys;
416436

417-
OverloadKind Kind = getOverloadKind(OverloadTy);
437+
OverloadKind Kind = getOverloadKind(OverloadTy, Prop->AllowVectorOverloads);
418438

419439
// Check if the operation supports overload types and OverloadTy is valid
420440
// per the specified types for the operation
421441
if ((ValidTyMask != OverloadKind::UNDEFINED) &&
422-
(ValidTyMask & (uint16_t)Kind) == 0)
442+
(ValidTyMask & (uint16_t)Kind) == 0) {
443+
OverloadTy->print(llvm::errs());
423444
return makeOpError(OpCode, "Invalid overload type");
445+
}
424446

425447
// Perform necessary checks to ensure Opcode is valid in the targeted shader
426448
// kind

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2653,6 +2653,21 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
26532653
.addUse(GR.getSPIRVTypeID(ResType))
26542654
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
26552655
}
2656+
case Intrinsic::spv_wave_read_lane_at: {
2657+
assert(I.getNumOperands() == 4);
2658+
assert(I.getOperand(2).isReg());
2659+
assert(I.getOperand(3).isReg());
2660+
2661+
// Defines the execution scope currently 2 for group, see scope table
2662+
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
2663+
return BuildMI(BB, I, I.getDebugLoc(),
2664+
TII.get(SPIRV::OpGroupNonUniformShuffle))
2665+
.addDef(ResVReg)
2666+
.addUse(GR.getSPIRVTypeID(ResType))
2667+
.addUse(I.getOperand(2).getReg())
2668+
.addUse(I.getOperand(3).getReg())
2669+
.addUse(GR.getOrCreateConstInt(2, I, IntTy, TII));
2670+
}
26562671
case Intrinsic::spv_step:
26572672
return selectStep(ResVReg, ResType, I);
26582673
// Discard intrinsics which we do not expect to actually represent code after
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
2+
3+
define noundef <4 x half> @wave_rla_halfv4(<4 x half> noundef %expr, i32 noundef %idx) #0 {
4+
entry:
5+
; CHECK: call <4 x half> @dx.op.waveReadLaneAt.v4f16(i32 117, <4 x half> %expr, i32 %idx)
6+
%ret = call <4 x half> @llvm.dx.wave.read.lane.at.v4f16(<4 x half> %expr, i32 %idx)
7+
ret <4 x half> %ret
8+
}
9+
10+
define noundef <4 x float> @wave_rla_floatv4(<4 x float> noundef %expr, i32 noundef %idx) #0 {
11+
entry:
12+
; CHECK: call <4 x float> @dx.op.waveReadLaneAt.v4f32(i32 117, <4 x float> %expr, i32 %idx)
13+
%ret = call <4 x float> @llvm.dx.wave.read.lane.at.v4f32(<4 x float> %expr, i32 %idx)
14+
ret <4 x float> %ret
15+
}
16+
17+
define noundef <4 x double> @wave_rla_doublev4(<4 x double> noundef %expr, i32 noundef %idx) #0 {
18+
entry:
19+
; CHECK: call <4 x double> @dx.op.waveReadLaneAt.v4f64(i32 117, <4 x double> %expr, i32 %idx)
20+
%ret = call <4 x double> @llvm.dx.wave.read.lane.at.v4f64(<4 x double> %expr, i32 %idx)
21+
ret <4 x double> %ret
22+
}
23+
24+
define noundef <4 x i1> @wave_rla_v4i1(<4 x i1> noundef %expr, i32 noundef %idx) #0 {
25+
entry:
26+
; CHECK: call <4 x i1> @dx.op.waveReadLaneAt.v4i1(i32 117, <4 x i1> %expr, i32 %idx)
27+
%ret = call <4 x i1> @llvm.dx.wave.read.lane.at.v4i1(<4 x i1> %expr, i32 %idx)
28+
ret <4 x i1> %ret
29+
}
30+
31+
define noundef <4 x i16> @wave_rla_v4i16(<4 x i16> noundef %expr, i32 noundef %idx) #0 {
32+
entry:
33+
; CHECK: call <4 x i16> @dx.op.waveReadLaneAt.v4i16(i32 117, <4 x i16> %expr, i32 %idx)
34+
%ret = call <4 x i16> @llvm.dx.wave.read.lane.at.v4i16(<4 x i16> %expr, i32 %idx)
35+
ret <4 x i16> %ret
36+
}
37+
38+
define noundef <4 x i32> @wave_rla_v4i32(<4 x i32> noundef %expr, i32 noundef %idx) #0 {
39+
entry:
40+
; CHECK: call <4 x i32> @dx.op.waveReadLaneAt.v4i32(i32 117, <4 x i32> %expr, i32 %idx)
41+
%ret = call <4 x i32> @llvm.dx.wave.read.lane.at.v4i32(<4 x i32> %expr, i32 %idx)
42+
ret <4 x i32> %ret
43+
}
44+
45+
declare <4 x half> @llvm.dx.wave.read.lane.at.v4f16(<4 x half>, i32) #1
46+
declare <4 x float> @llvm.dx.wave.read.lane.at.v4f32(<4 x float>, i32) #1
47+
declare <4 x double> @llvm.dx.wave.read.lane.at.v4f64(<4 x double>, i32) #1
48+
49+
declare <4 x i1> @llvm.dx.wave.read.lane.at.v4i1(<4 x i1>, i32) #1
50+
declare <4 x i16> @llvm.dx.wave.read.lane.at.v4i16(<4 x i16>, i32) #1
51+
declare <4 x i32> @llvm.dx.wave.read.lane.at.v4i32(<4 x i32>, i32) #1
52+
53+
attributes #0 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
54+
attributes #1 = { convergent nocallback nofree nosync nounwind willreturn }
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
2+
3+
define noundef half @wave_rla_half(half noundef %expr, i32 noundef %idx) #0 {
4+
entry:
5+
; CHECK: call half @dx.op.waveReadLaneAt.f16(i32 117, half %expr, i32 %idx)
6+
%ret = call half @llvm.dx.wave.read.lane.at.f16(half %expr, i32 %idx)
7+
ret half %ret
8+
}
9+
10+
define noundef float @wave_rla_float(float noundef %expr, i32 noundef %idx) #0 {
11+
entry:
12+
; CHECK: call float @dx.op.waveReadLaneAt.f32(i32 117, float %expr, i32 %idx)
13+
%ret = call float @llvm.dx.wave.read.lane.at(float %expr, i32 %idx)
14+
ret float %ret
15+
}
16+
17+
define noundef double @wave_rla_double(double noundef %expr, i32 noundef %idx) #0 {
18+
entry:
19+
; CHECK: call double @dx.op.waveReadLaneAt.f64(i32 117, double %expr, i32 %idx)
20+
%ret = call double @llvm.dx.wave.read.lane.at(double %expr, i32 %idx)
21+
ret double %ret
22+
}
23+
24+
define noundef i1 @wave_rla_i1(i1 noundef %expr, i32 noundef %idx) #0 {
25+
entry:
26+
; CHECK: call i1 @dx.op.waveReadLaneAt.i1(i32 117, i1 %expr, i32 %idx)
27+
%ret = call i1 @llvm.dx.wave.read.lane.at.i1(i1 %expr, i32 %idx)
28+
ret i1 %ret
29+
}
30+
31+
define noundef i16 @wave_rla_i16(i16 noundef %expr, i32 noundef %idx) #0 {
32+
entry:
33+
; CHECK: call i16 @dx.op.waveReadLaneAt.i16(i32 117, i16 %expr, i32 %idx)
34+
%ret = call i16 @llvm.dx.wave.read.lane.at.i16(i16 %expr, i32 %idx)
35+
ret i16 %ret
36+
}
37+
38+
define noundef i32 @wave_rla_i32(i32 noundef %expr, i32 noundef %idx) #0 {
39+
entry:
40+
; CHECK: call i32 @dx.op.waveReadLaneAt.i32(i32 117, i32 %expr, i32 %idx)
41+
%ret = call i32 @llvm.dx.wave.read.lane.at.i32(i32 %expr, i32 %idx)
42+
ret i32 %ret
43+
}
44+
45+
declare half @llvm.dx.wave.read.lane.at.f16(half, i32) #1
46+
declare float @llvm.dx.wave.read.lane.at.f32(float, i32) #1
47+
declare double @llvm.dx.wave.read.lane.at.f64(double, i32) #1
48+
49+
declare i1 @llvm.dx.wave.read.lane.at.i1(i1, i32) #1
50+
declare i16 @llvm.dx.wave.read.lane.at.i16(i16, i32) #1
51+
declare i32 @llvm.dx.wave.read.lane.at.i32(i32, i32) #1
52+
53+
attributes #0 = { norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
54+
attributes #1 = { nocallback nofree nosync nounwind willreturn }

0 commit comments

Comments
 (0)