Skip to content

Commit 41a6e9c

Browse files
authored
[HLSL] Implement WaveActiveAllTrue Intrinsic (#117245)
Resolves #99161 - [x] Implement `WaveActiveAllTrue` clang builtin, - [x] Link `WaveActiveAllTrue` clang builtin with `hlsl_intrinsics.h` - [x] Add sema checks for `WaveActiveAllTrue` to `CheckHLSLBuiltinFunctionCall` in `SemaChecking.cpp` - [x] Add codegen for `WaveActiveAllTrue` to `EmitHLSLBuiltinExpr` in `CGBuiltin.cpp` - [x] Add codegen tests to `clang/test/CodeGenHLSL/builtins/WaveActiveAllTrue.hlsl` - [x] Add sema tests to `clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl` - [x] Create the `int_dx_WaveActiveAllTrue` intrinsic in `IntrinsicsDirectX.td` - [x] Create the `DXILOpMapping` of `int_dx_WaveActiveAllTrue` to `114` in `DXIL.td` - [x] Create the `WaveActiveAllTrue.ll` and `WaveActiveAllTrue_errors.ll` tests in `llvm/test/CodeGen/DirectX/` - [x] Create the `int_spv_WaveActiveAllTrue` intrinsic in `IntrinsicsSPIRV.td` - [x] In SPIRVInstructionSelector.cpp create the `WaveActiveAllTrue` lowering and map it to `int_spv_WaveActiveAllTrue` in `SPIRVInstructionSelector::selectIntrinsic`. - [x] Create SPIR-V backend test case in `llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllTrue.ll`
1 parent dda1d16 commit 41a6e9c

File tree

12 files changed

+107
-0
lines changed

12 files changed

+107
-0
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4762,6 +4762,12 @@ def HLSLAsDouble : LangBuiltin<"HLSL_LANG"> {
47624762
let Prototype = "void(...)";
47634763
}
47644764

4765+
def HLSLWaveActiveAllTrue : LangBuiltin<"HLSL_LANG"> {
4766+
let Spellings = ["__builtin_hlsl_wave_active_all_true"];
4767+
let Attributes = [NoThrow, Const];
4768+
let Prototype = "bool(bool)";
4769+
}
4770+
47654771
def HLSLWaveActiveAnyTrue : LangBuiltin<"HLSL_LANG"> {
47664772
let Spellings = ["__builtin_hlsl_wave_active_any_true"];
47674773
let Attributes = [NoThrow, Const];

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19436,6 +19436,16 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1943619436
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
1943719437
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
1943819438
}
19439+
case Builtin::BI__builtin_hlsl_wave_active_all_true: {
19440+
Value *Op = EmitScalarExpr(E->getArg(0));
19441+
llvm::Type *Ty = Op->getType();
19442+
assert(Ty->isIntegerTy(1) &&
19443+
"Intrinsic WaveActiveAllTrue operand must be a bool");
19444+
19445+
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAllTrueIntrinsic();
19446+
return EmitRuntimeCall(
19447+
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
19448+
}
1943919449
case Builtin::BI__builtin_hlsl_wave_active_any_true: {
1944019450
Value *Op = EmitScalarExpr(E->getArg(0));
1944119451
assert(Op->getType()->isIntegerTy(1) &&

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ class CGHLSLRuntime {
9292
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
9393
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
9494
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
95+
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
9596
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
9697
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
9798
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2241,6 +2241,15 @@ float4 trunc(float4);
22412241
// Wave* builtins
22422242
//===----------------------------------------------------------------------===//
22432243

2244+
/// \brief Returns true if the expression is true in all active lanes in the
2245+
/// current wave.
2246+
///
2247+
/// \param Val The boolean expression to evaluate.
2248+
/// \return True if the expression is true in all lanes.
2249+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2250+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_true)
2251+
__attribute__((convergent)) bool WaveActiveAllTrue(bool Val);
2252+
22442253
/// \brief Returns true if the expression is true in any active lane in the
22452254
/// current wave.
22462255
///
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
2+
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
3+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
4+
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
5+
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
6+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
7+
8+
// Test basic lowering to runtime function call for int values.
9+
10+
// CHECK-LABEL: define {{.*}}test
11+
bool test(bool p1) {
12+
// CHECK-SPIRV: %[[#entry_tok0:]] = call token @llvm.experimental.convergence.entry()
13+
// CHECK-SPIRV: %[[RET:.*]] = call spir_func i1 @llvm.spv.wave.all(i1 %{{[a-zA-Z0-9]+}}) [ "convergencectrl"(token %[[#entry_tok0]]) ]
14+
// CHECK-DXIL: %[[RET:.*]] = call i1 @llvm.dx.wave.all(i1 %{{[a-zA-Z0-9]+}})
15+
// CHECK: ret i1 %[[RET]]
16+
return WaveActiveAllTrue(p1);
17+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
2+
3+
bool test_too_few_arg() {
4+
return __builtin_hlsl_wave_active_all_true();
5+
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
6+
}
7+
8+
bool test_too_many_arg(bool p0) {
9+
return __builtin_hlsl_wave_active_all_true(p0, p0);
10+
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
11+
}
12+
13+
struct Foo
14+
{
15+
int a;
16+
};
17+
18+
bool test_type_check(Foo p0) {
19+
return __builtin_hlsl_wave_active_all_true(p0);
20+
// expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}}
21+
}

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
9898
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
9999
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
100100
def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
101+
def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
101102
def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
102103
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
103104
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ let TargetPrefix = "spv" in {
8787
def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
8888
def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
8989
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
90+
def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
9091
def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
9192
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
9293
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,14 @@ def CreateHandleFromBinding : DXILOp<217, createHandleFromBinding> {
924924
let stages = [Stages<DXIL1_6, [all_stages]>];
925925
}
926926

927+
def WaveActiveAllTrue : DXILOp<114, waveAllTrue> {
928+
let Doc = "returns true if the expression is true in all of the active lanes in the current wave";
929+
let intrinsics = [ IntrinSelect<int_dx_wave_all> ];
930+
let arguments = [Int1Ty];
931+
let result = Int1Ty;
932+
let stages = [Stages<DXIL1_0, [all_stages]>];
933+
}
934+
927935
def WaveActiveAnyTrue : DXILOp<113, waveAnyTrue> {
928936
let Doc = "returns true if the expression is true in any of the active lanes in the current wave";
929937
let intrinsics = [ IntrinSelect<int_dx_wave_any> ];

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2959,6 +2959,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
29592959
return selectExtInst(ResVReg, ResType, I, CL::s_clamp, GL::SClamp);
29602960
case Intrinsic::spv_wave_active_countbits:
29612961
return selectWaveActiveCountBits(ResVReg, ResType, I);
2962+
case Intrinsic::spv_wave_all:
2963+
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAll);
29622964
case Intrinsic::spv_wave_any:
29632965
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny);
29642966
case Intrinsic::spv_wave_is_first_lane:
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
2+
3+
define noundef i1 @wave_all_simple(i1 noundef %p1) {
4+
entry:
5+
; CHECK: call i1 @dx.op.waveAllTrue(i32 114, i1 %p1)
6+
%ret = call i1 @llvm.dx.wave.all(i1 %p1)
7+
ret i1 %ret
8+
}
9+
10+
declare i1 @llvm.dx.wave.all(i1)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: %[[#bool:]] = OpTypeBool
5+
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
6+
; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3
7+
; CHECK-DAG: OpCapability GroupNonUniformVote
8+
9+
; CHECK-LABEL: Begin function test_wave_all
10+
define i1 @test_wave_all(i1 %p1) #0 {
11+
entry:
12+
; CHECK: %[[#param:]] = OpFunctionParameter %[[#bool]]
13+
; CHECK: %{{.+}} = OpGroupNonUniformAll %[[#bool]] %[[#scope]] %[[#param]]
14+
%0 = call token @llvm.experimental.convergence.entry()
15+
%ret = call i1 @llvm.spv.wave.all(i1 %p1) [ "convergencectrl"(token %0) ]
16+
ret i1 %ret
17+
}
18+
19+
declare i1 @llvm.spv.wave.all(i1) #0
20+
21+
attributes #0 = { convergent }

0 commit comments

Comments
 (0)