Skip to content

[HLSL] Implement WaveActiveAnyTrue intrinsic #115902

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4750,6 +4750,12 @@ def HLSLAny : LangBuiltin<"HLSL_LANG"> {
let Prototype = "bool(...)";
}

def HLSLWaveActiveAnyTrue : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_any_true"];
let Attributes = [NoThrow, Const];
let Prototype = "bool(bool)";
}

def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_count_bits"];
let Attributes = [NoThrow, Const];
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19282,6 +19282,16 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
}
case Builtin::BI__builtin_hlsl_wave_active_any_true: {
Value *Op = EmitScalarExpr(E->getArg(0));
llvm::Type *Ty = Op->getType();
assert(Ty->isIntegerTy(1) &&
"Intrinsic WaveActiveAnyTrue operand must be a bool");

Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAnyTrueIntrinsic();
return EmitRuntimeCall(
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
}
case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
Value *OpExpr = EmitScalarExpr(E->getArg(0));
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
Expand Down
9 changes: 9 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -2223,6 +2223,15 @@ float4 trunc(float4);
// Wave* builtins
//===----------------------------------------------------------------------===//

/// \brief Returns true if the expression is true in any active lane in the
/// current wave.
///
/// \param Val The boolean expression to evaluate.
/// \return True if the expression is true in any lane.
_HLSL_AVAILABILITY(shadermodel, 6.0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is correct. And I can see the other wave intrinsics do this, but our default availability is 6.0. Is the thinking to do this that do the bookkeeping now so that it makes it easier to support pre 6.0 shader models i.e. FXC shaders at some point?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had propagated this for the sake of user documentation as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was just kind of following what others had done. But imo its worth keeping if for nothing else besides the documentation

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_any_true)
__attribute__((convergent)) bool WaveActiveAnyTrue(bool Val);

/// \brief Counts the number of boolean variables which evaluate to true across
/// all active lanes in the current wave.
///
Expand Down
17 changes: 17 additions & 0 deletions clang/test/CodeGenHLSL/builtins/WaveActiveAnyTrue.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV

// Test basic lowering to runtime function call for int values.

// CHECK-LABEL: define {{.*}}test
bool test(bool p1) {
// CHECK-SPIRV: %[[#entry_tok0:]] = call token @llvm.experimental.convergence.entry()
// CHECK-SPIRV: %[[RET:.*]] = call spir_func i1 @llvm.spv.wave.any(i1 %{{[a-zA-Z0-9]+}}) [ "convergencectrl"(token %[[#entry_tok0]]) ]
// CHECK-DXIL: %[[RET:.*]] = call i1 @llvm.dx.wave.any(i1 %{{[a-zA-Z0-9]+}})
// CHECK: ret i1 %[[RET]]
return WaveActiveAnyTrue(p1);
}
21 changes: 21 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/WaveActiveAnyTrue-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify

bool test_too_few_arg() {
return __builtin_hlsl_wave_active_any_true();
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
}

bool test_too_many_arg(bool p0) {
return __builtin_hlsl_wave_active_any_true(p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
}

struct Foo
{
int a;
};

bool test_type_check(Foo p0) {
return __builtin_hlsl_wave_active_any_true(p0);
// expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}}
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think @inbelic is working to remove IntrNoMem on all the wave intrinsics as part of a general intrinsics cleanup task.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not quite, we can keep the IntrNoMem on the dx intrinsics (assuming it correctly describes memory access). We [require|can use] them in certain passes, for instance the scalarizer. Then they are discarded when lowering to the DXIL op, so they don't correlate to the dxil op attributes.

I will audit that they are added correctly, in this case I think IntrNoMem is correct.

def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ let TargetPrefix = "spv" in {
def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall IntrNoMem be removed?
(readlane, active_countbits have it, but maybe that's wrong).

AFAIK is IntrNoMem is added, compiler can assume there are no other side effect. But since this depends on other lanes, isn't that equivalent to IntrReadMem?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@inbelic says its fine to keep here #115902 (comment)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to be corrected. My understanding was that IntrConvergent was used to model the cross-lane side-effect, and we can depend on that to ensure no optimizations would cause a bad reordering/folding. So we can then use IntrNoMem to model that there are no other (memory) side-effects.

cc @farzonl @bogner

Copy link
Contributor

@Keenuts Keenuts Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding was IntrConvergent prevented optimizations to move this instruction into another control-flow, but could change the order in a given BB/flow and remove dead loads/stores, but from the few tests I did, it seems those are not optimized anyway (even dead-stores).
So for now let's go with IntrConvergent and IntrNoMem, and if we have a repro/issue in the future, we'll know for sure.

def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,14 @@ def CreateHandleFromBinding : DXILOp<217, createHandleFromBinding> {
let stages = [Stages<DXIL1_6, [all_stages]>];
}

def WaveActiveAnyTrue : DXILOp<113, waveAnyTrue> {
let Doc = "returns true if the expression is true in any of the active lanes in the current wave";
let LLVMIntrinsic = int_dx_wave_any;
let arguments = [Int1Ty];
let result = Int1Ty;
let stages = [Stages<DXIL1_0, [all_stages]>];
}

def WaveIsFirstLane : DXILOp<110, waveIsFirstLane> {
let Doc = "returns 1 for the first lane in the wave";
let LLVMIntrinsic = int_dx_wave_is_first_lane;
Expand Down
75 changes: 32 additions & 43 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,12 +257,12 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectWaveOpInst(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, unsigned Opcode) const;

bool selectWaveActiveCountBits(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectUnmergeValues(MachineInstr &I) const;

bool selectHandleFromBinding(Register &ResVReg, const SPIRVType *ResType,
Expand Down Expand Up @@ -1939,24 +1939,36 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg,
return Result;
}

bool SPIRVInstructionSelector::selectWaveOpInst(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
unsigned Opcode) const {
MachineBasicBlock &BB = *I.getParent();
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);

auto BMI = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I,
IntTy, TII));

for (unsigned J = 2; J < I.getNumOperands(); J++) {
BMI.addUse(I.getOperand(J).getReg());
}

return BMI.constrainAllUses(TII, TRI, RBI);
}

bool SPIRVInstructionSelector::selectWaveActiveCountBits(
Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
assert(I.getNumOperands() == 3);
assert(I.getOperand(2).isReg());
MachineBasicBlock &BB = *I.getParent();

SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
SPIRVType *BallotType = GR.getOrCreateSPIRVVectorType(IntTy, 4, I, TII);
Register BallotReg = MRI->createVirtualRegister(GR.getRegClass(BallotType));
bool Result = selectWaveOpInst(BallotReg, BallotType, I,
SPIRV::OpGroupNonUniformBallot);

bool Result =
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpGroupNonUniformBallot))
.addDef(BallotReg)
.addUse(GR.getSPIRVTypeID(BallotType))
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
.addUse(I.getOperand(2).getReg())
.constrainAllUses(TII, TRI, RBI);

MachineBasicBlock &BB = *I.getParent();
Result &=
BuildMI(BB, I, I.getDebugLoc(),
TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
Expand All @@ -1970,26 +1982,6 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
return Result;
}

bool SPIRVInstructionSelector::selectWaveReadLaneAt(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
assert(I.getNumOperands() == 4);
assert(I.getOperand(2).isReg());
assert(I.getOperand(3).isReg());
MachineBasicBlock &BB = *I.getParent();

// IntTy is used to define the execution scope, set to 3 to denote a
// cross-lane interaction equivalent to a SPIR-V subgroup.
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
return BuildMI(BB, I, I.getDebugLoc(),
TII.get(SPIRV::OpGroupNonUniformShuffle))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII))
.addUse(I.getOperand(2).getReg())
.addUse(I.getOperand(3).getReg());
}

bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
Expand Down Expand Up @@ -2838,16 +2830,13 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectExtInst(ResVReg, ResType, I, CL::s_clamp, GL::SClamp);
case Intrinsic::spv_wave_active_countbits:
return selectWaveActiveCountBits(ResVReg, ResType, I);
case Intrinsic::spv_wave_is_first_lane: {
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
return BuildMI(BB, I, I.getDebugLoc(),
TII.get(SPIRV::OpGroupNonUniformElect))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
}
case Intrinsic::spv_wave_any:
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny);
case Intrinsic::spv_wave_is_first_lane:
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformElect);
case Intrinsic::spv_wave_readlane:
return selectWaveReadLaneAt(ResVReg, ResType, I);
return selectWaveOpInst(ResVReg, ResType, I,
SPIRV::OpGroupNonUniformShuffle);
case Intrinsic::spv_step:
return selectExtInst(ResVReg, ResType, I, CL::step, GL::Step);
case Intrinsic::spv_radians:
Expand Down
17 changes: 9 additions & 8 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,15 @@ void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
addAvailableCaps({Capability::Shader, Capability::Linkage, Capability::Int8,
Capability::Int16});

if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3)))
addAvailableCaps({Capability::GroupNonUniform,
Capability::GroupNonUniformVote,
Capability::GroupNonUniformArithmetic,
Capability::GroupNonUniformBallot,
Capability::GroupNonUniformClustered,
Capability::GroupNonUniformShuffle,
Capability::GroupNonUniformShuffleRelative});

if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
addAvailableCaps({Capability::DotProduct, Capability::DotProductInputAll,
Capability::DotProductInput4x8Bit,
Expand Down Expand Up @@ -676,14 +685,6 @@ void RequirementHandler::initAvailableCapabilitiesForOpenCL(
if (ST.isAtLeastSPIRVVer(VersionTuple(1, 1)) &&
ST.isAtLeastOpenCLVer(VersionTuple(2, 2)))
addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3)))
addAvailableCaps({Capability::GroupNonUniform,
Capability::GroupNonUniformVote,
Capability::GroupNonUniformArithmetic,
Capability::GroupNonUniformBallot,
Capability::GroupNonUniformClustered,
Capability::GroupNonUniformShuffle,
Capability::GroupNonUniformShuffleRelative});
if (ST.isAtLeastSPIRVVer(VersionTuple(1, 4)))
addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
Capability::SignedZeroInfNanPreserve,
Expand Down
10 changes: 10 additions & 0 deletions llvm/test/CodeGen/DirectX/WaveActiveAnyTrue.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s

define noundef i1 @wave_any_simple(i1 noundef %p1) {
entry:
; CHECK: call i1 @dx.op.waveAnyTrue(i32 113, i1 %p1)
%ret = call i1 @llvm.dx.wave.any(i1 %p1)
ret i1 %ret
}

declare i1 @llvm.dx.wave.any(i1)
21 changes: 21 additions & 0 deletions llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAnyTrue.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: %[[#bool:]] = OpTypeBool
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3
; CHECK-DAG: OpCapability GroupNonUniformVote

; CHECK-LABEL: Begin function test_wave_any
define i1 @test_wave_any(i1 %p1) #0 {
entry:
; CHECK: %[[#param:]] = OpFunctionParameter %[[#bool]]
; CHECK: %{{.+}} = OpGroupNonUniformAny %[[#bool]] %[[#scope]] %[[#param]]
%0 = call token @llvm.experimental.convergence.entry()
%ret = call i1 @llvm.spv.wave.any(i1 %p1) [ "convergencectrl"(token %0) ]
ret i1 %ret
}

declare i1 @llvm.spv.wave.any(i1) #0

attributes #0 = { convergent }