Skip to content

[HLSL] Implement WaveActiveAllTrue Intrinsic #117245

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4762,6 +4762,12 @@ def HLSLAsDouble : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}

def HLSLWaveActiveAllTrue : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_all_true"];
let Attributes = [NoThrow, Const];
let Prototype = "bool(bool)";
}

def HLSLWaveActiveAnyTrue : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_any_true"];
let Attributes = [NoThrow, Const];
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19436,6 +19436,16 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
}
case Builtin::BI__builtin_hlsl_wave_active_all_true: {
Value *Op = EmitScalarExpr(E->getArg(0));
llvm::Type *Ty = Op->getType();
assert(Ty->isIntegerTy(1) &&
"Intrinsic WaveActiveAllTrue operand must be a bool");

Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAllTrueIntrinsic();
return EmitRuntimeCall(
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
}
case Builtin::BI__builtin_hlsl_wave_active_any_true: {
Value *Op = EmitScalarExpr(E->getArg(0));
assert(Op->getType()->isIntegerTy(1) &&
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
Expand Down
9 changes: 9 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -2241,6 +2241,15 @@ float4 trunc(float4);
// Wave* builtins
//===----------------------------------------------------------------------===//

/// \brief Returns true if the expression is true in all active lanes in the
/// current wave.
///
/// \param Val The boolean expression to evaluate.
/// \return True if the expression is true in all lanes.
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_true)
__attribute__((convergent)) bool WaveActiveAllTrue(bool Val);

/// \brief Returns true if the expression is true in any active lane in the
/// current wave.
///
Expand Down
17 changes: 17 additions & 0 deletions clang/test/CodeGenHLSL/builtins/WaveActiveAllTrue.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV

// Test basic lowering to runtime function call for int values.

// CHECK-LABEL: define {{.*}}test
bool test(bool p1) {
// CHECK-SPIRV: %[[#entry_tok0:]] = call token @llvm.experimental.convergence.entry()
// CHECK-SPIRV: %[[RET:.*]] = call spir_func i1 @llvm.spv.wave.all(i1 %{{[a-zA-Z0-9]+}}) [ "convergencectrl"(token %[[#entry_tok0]]) ]
// CHECK-DXIL: %[[RET:.*]] = call i1 @llvm.dx.wave.all(i1 %{{[a-zA-Z0-9]+}})
// CHECK: ret i1 %[[RET]]
return WaveActiveAllTrue(p1);
}
21 changes: 21 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify

bool test_too_few_arg() {
return __builtin_hlsl_wave_active_all_true();
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
}

bool test_too_many_arg(bool p0) {
return __builtin_hlsl_wave_active_all_true(p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
}

struct Foo
{
int a;
};

bool test_type_check(Foo p0) {
return __builtin_hlsl_wave_active_all_true(p0);
// expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}}
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ let TargetPrefix = "spv" in {
def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,14 @@ def CreateHandleFromBinding : DXILOp<217, createHandleFromBinding> {
let stages = [Stages<DXIL1_6, [all_stages]>];
}

def WaveActiveAllTrue : DXILOp<114, waveAllTrue> {
let Doc = "returns true if the expression is true in all of the active lanes in the current wave";
let intrinsics = [ IntrinSelect<int_dx_wave_all> ];
let arguments = [Int1Ty];
let result = Int1Ty;
let stages = [Stages<DXIL1_0, [all_stages]>];
}

def WaveActiveAnyTrue : DXILOp<113, waveAnyTrue> {
let Doc = "returns true if the expression is true in any of the active lanes in the current wave";
let intrinsics = [ IntrinSelect<int_dx_wave_any> ];
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2959,6 +2959,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectExtInst(ResVReg, ResType, I, CL::s_clamp, GL::SClamp);
case Intrinsic::spv_wave_active_countbits:
return selectWaveActiveCountBits(ResVReg, ResType, I);
case Intrinsic::spv_wave_all:
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAll);
case Intrinsic::spv_wave_any:
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny);
case Intrinsic::spv_wave_is_first_lane:
Expand Down
10 changes: 10 additions & 0 deletions llvm/test/CodeGen/DirectX/WaveActiveAllTrue.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s

define noundef i1 @wave_all_simple(i1 noundef %p1) {
entry:
; CHECK: call i1 @dx.op.waveAllTrue(i32 114, i1 %p1)
%ret = call i1 @llvm.dx.wave.all(i1 %p1)
ret i1 %ret
}

declare i1 @llvm.dx.wave.all(i1)
21 changes: 21 additions & 0 deletions llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllTrue.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: %[[#bool:]] = OpTypeBool
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3
; CHECK-DAG: OpCapability GroupNonUniformVote

; CHECK-LABEL: Begin function test_wave_all
define i1 @test_wave_all(i1 %p1) #0 {
entry:
; CHECK: %[[#param:]] = OpFunctionParameter %[[#bool]]
; CHECK: %{{.+}} = OpGroupNonUniformAll %[[#bool]] %[[#scope]] %[[#param]]
%0 = call token @llvm.experimental.convergence.entry()
%ret = call i1 @llvm.spv.wave.all(i1 %p1) [ "convergencectrl"(token %0) ]
ret i1 %ret
}

declare i1 @llvm.spv.wave.all(i1) #0

attributes #0 = { convergent }
Loading