Skip to content

Commit 73afef2

Browse files
committed
[HLSL] Implement WaveActiveAllTrue Intrinsic
1 parent 505e049 commit 73afef2

File tree

12 files changed

+107
-0
lines changed

12 files changed

+107
-0
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4750,6 +4750,12 @@ def HLSLAny : LangBuiltin<"HLSL_LANG"> {
47504750
let Prototype = "bool(...)";
47514751
}
47524752

4753+
def HLSLWaveActiveAllTrue : LangBuiltin<"HLSL_LANG"> {
4754+
let Spellings = ["__builtin_hlsl_wave_active_all_true"];
4755+
let Attributes = [NoThrow, Const];
4756+
let Prototype = "bool(bool)";
4757+
}
4758+
47534759
def HLSLWaveActiveAnyTrue : LangBuiltin<"HLSL_LANG"> {
47544760
let Spellings = ["__builtin_hlsl_wave_active_any_true"];
47554761
let Attributes = [NoThrow, Const];

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19282,6 +19282,16 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1928219282
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
1928319283
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
1928419284
}
19285+
case Builtin::BI__builtin_hlsl_wave_active_all_true: {
  // Lower WaveActiveAllTrue(bool) to a call of the target's wave "all"
  // intrinsic, passing the single scalar operand through unchanged.
  Value *Op = EmitScalarExpr(E->getArg(0));
  // Query the type directly inside the assert so that no local variable is
  // left unused (and warned about) when asserts are compiled out.
  assert(Op->getType()->isIntegerTy(1) &&
         "Intrinsic WaveActiveAllTrue operand must be a bool");

  // The HLSL runtime supplies the intrinsic ID to call for the current target.
  Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAllTrueIntrinsic();
  return EmitRuntimeCall(
      Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
}
1928519295
case Builtin::BI__builtin_hlsl_wave_active_any_true: {
1928619296
Value *Op = EmitScalarExpr(E->getArg(0));
1928719297
assert(Op->getType()->isIntegerTy(1) &&

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class CGHLSLRuntime {
9191
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
9292
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
9393
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
94+
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
9495
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
9596
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
9697
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2223,6 +2223,15 @@ float4 trunc(float4);
22232223
// Wave* builtins
22242224
//===----------------------------------------------------------------------===//
22252225

2226+
/// \brief Returns true if the expression is true in all active lanes in the
2227+
/// current wave.
2228+
///
2229+
/// \param Val The boolean expression to evaluate.
2230+
/// \return True if the expression is true in all active lanes of the current wave.
2231+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2232+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_true)
2233+
__attribute__((convergent)) bool WaveActiveAllTrue(bool Val);
2234+
22262235
/// \brief Returns true if the expression is true in any active lane in the
22272236
/// current wave.
22282237
///
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
2+
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
3+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
4+
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
5+
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
6+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
7+
8+
// Test basic lowering to runtime function call for bool values.
9+
10+
// CHECK-LABEL: define {{.*}}test
11+
bool test(bool p1) {
12+
// CHECK-SPIRV: %[[#entry_tok0:]] = call token @llvm.experimental.convergence.entry()
13+
// CHECK-SPIRV: %[[RET:.*]] = call spir_func i1 @llvm.spv.wave.all(i1 %{{[a-zA-Z0-9]+}}) [ "convergencectrl"(token %[[#entry_tok0]]) ]
14+
// CHECK-DXIL: %[[RET:.*]] = call i1 @llvm.dx.wave.all(i1 %{{[a-zA-Z0-9]+}})
15+
// CHECK: ret i1 %[[RET]]
16+
return WaveActiveAllTrue(p1);
17+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
2+
3+
bool test_too_few_arg() {
4+
return __builtin_hlsl_wave_active_all_true();
5+
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
6+
}
7+
8+
bool test_too_many_arg(bool p0) {
9+
return __builtin_hlsl_wave_active_all_true(p0, p0);
10+
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
11+
}
12+
13+
struct Foo
14+
{
15+
int a;
16+
};
17+
18+
bool test_type_check(Foo p0) {
19+
return __builtin_hlsl_wave_active_all_true(p0);
20+
// expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}}
21+
}

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
9494
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
9595
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
9696
def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
97+
def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
9798
def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
9899
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
99100
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ let TargetPrefix = "spv" in {
8686
def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
8787
def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
8888
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
89+
def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
8990
def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
9091
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
9192
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,14 @@ def CreateHandleFromBinding : DXILOp<217, createHandleFromBinding> {
861861
let stages = [Stages<DXIL1_6, [all_stages]>];
862862
}
863863

864+
// DXIL operation WaveAllTrue. Its opcode is 114, immediately following
// WaveAnyTrue (113).
def WaveActiveAllTrue : DXILOp<114, waveAllTrue> {
  let Doc = "returns true if the expression is true in all of the active lanes in the current wave";
  let LLVMIntrinsic = int_dx_wave_all;
  let arguments = [Int1Ty];
  let result = Int1Ty;
  let stages = [Stages<DXIL1_0, [all_stages]>];
}
871+
864872
def WaveActiveAnyTrue : DXILOp<113, waveAnyTrue> {
865873
let Doc = "returns true if the expression is true in any of the active lanes in the current wave";
866874
let LLVMIntrinsic = int_dx_wave_any;

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2830,6 +2830,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
28302830
return selectExtInst(ResVReg, ResType, I, CL::s_clamp, GL::SClamp);
28312831
case Intrinsic::spv_wave_active_countbits:
28322832
return selectWaveActiveCountBits(ResVReg, ResType, I);
2833+
case Intrinsic::spv_wave_all:
2834+
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAll);
28332835
case Intrinsic::spv_wave_any:
28342836
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny);
28352837
case Intrinsic::spv_wave_is_first_lane:
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
2+
3+
; Checks that llvm.dx.wave.all is lowered to the DXIL WaveAllTrue operation,
; whose opcode is 114.
define noundef i1 @wave_all_simple(i1 noundef %p1) {
entry:
; CHECK: call i1 @dx.op.waveAllTrue(i32 114, i1 %p1)
  %ret = call i1 @llvm.dx.wave.all(i1 %p1)
  ret i1 %ret
}
9+
10+
declare i1 @llvm.dx.wave.all(i1)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: %[[#bool:]] = OpTypeBool
5+
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
6+
; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3
7+
; CHECK-DAG: OpCapability GroupNonUniformVote
8+
9+
; CHECK-LABEL: Begin function test_wave_all
10+
define i1 @test_wave_all(i1 %p1) #0 {
11+
entry:
12+
; CHECK: %[[#param:]] = OpFunctionParameter %[[#bool]]
13+
; CHECK: %{{.+}} = OpGroupNonUniformAll %[[#bool]] %[[#scope]] %[[#param]]
14+
%0 = call token @llvm.experimental.convergence.entry()
15+
%ret = call i1 @llvm.spv.wave.all(i1 %p1) [ "convergencectrl"(token %0) ]
16+
ret i1 %ret
17+
}
18+
19+
declare i1 @llvm.spv.wave.all(i1) #0
20+
21+
attributes #0 = { convergent }

0 commit comments

Comments
 (0)