Skip to content

Add normalize builtins and normalize HLSL function to DirectX and SPIR-V backend #102683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4725,6 +4725,12 @@ def HLSLMad : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}

def HLSLNormalize : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_normalize"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_rcp"];
let Attributes = [NoThrow, Const];
Expand Down
11 changes: 11 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18584,6 +18584,17 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
CGM.getHLSLRuntime().getLengthIntrinsic(), ArrayRef<Value *>{X},
nullptr, "hlsl.length");
}
case Builtin::BI__builtin_hlsl_normalize: {
Value *X = EmitScalarExpr(E->getArg(0));

assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
"normalize operand must have a float representation");

return Builder.CreateIntrinsic(
/*ReturnType=*/X->getType(),
CGM.getHLSLRuntime().getNormalizeIntrinsic(), ArrayRef<Value *>{X},
nullptr, "hlsl.normalize");
}
case Builtin::BI__builtin_hlsl_elementwise_frac: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
if (!E->getArg(0)->getType()->hasFloatingRepresentation())
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length)
GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
GENERATE_HLSL_INTRINSIC_FUNCTION(Normalize, normalize)
GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)

Expand Down
32 changes: 32 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,38 @@ double3 min(double3, double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
double4 min(double4, double4);

//===----------------------------------------------------------------------===//
// normalize builtins
//===----------------------------------------------------------------------===//

/// \fn T normalize(T x)
/// \brief Returns the normalized unit vector of the specified floating-point
/// vector. \param x [in] The vector of floats.
///
/// Normalize is based on the following formula: x / length(x).

_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
half normalize(half);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
half2 normalize(half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
half3 normalize(half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
half4 normalize(half4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
float normalize(float);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
float2 normalize(float2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
float3 normalize(float3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
float4 normalize(float4);

//===----------------------------------------------------------------------===//
// pow builtins
//===----------------------------------------------------------------------===//
Expand Down
12 changes: 12 additions & 0 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,18 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
case Builtin::BI__builtin_hlsl_normalize: {
if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
return true;
if (SemaRef.checkArgCount(TheCall, 1))
return true;

ExprResult A = TheCall->getArg(0);
QualType ArgTyA = A.get()->getType();
// return type is the same as the input type
TheCall->setType(ArgTyA);
break;
}
// Note these are llvm builtins that we want to catch invalid intrinsic
// generation. Normal handling of these builitns will occur elsewhere.
case Builtin::BI__builtin_elementwise_bitreverse: {
Expand Down
100 changes: 100 additions & 0 deletions clang/test/CodeGenHLSL/builtins/normalize.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
// RUN: --check-prefixes=CHECK,DXIL_CHECK,DXIL_NATIVE_HALF,NATIVE_HALF
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,DXIL_CHECK,NO_HALF,DXIL_NO_HALF
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: spirv-unknown-vulkan-compute %s -fnative-half-type \
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
// RUN: --check-prefixes=CHECK,NATIVE_HALF,SPIR_NATIVE_HALF,SPIR_CHECK
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF,SPIR_NO_HALF,SPIR_CHECK

// DXIL_NATIVE_HALF: define noundef half @
// SPIR_NATIVE_HALF: define spir_func noundef half @
// DXIL_NATIVE_HALF: call half @llvm.dx.normalize.f16(half
// SPIR_NATIVE_HALF: call half @llvm.spv.normalize.f16(half
// DXIL_NO_HALF: call float @llvm.dx.normalize.f32(float
// SPIR_NO_HALF: call float @llvm.spv.normalize.f32(float
// NATIVE_HALF: ret half
// NO_HALF: ret float
half test_normalize_half(half p0)
{
return normalize(p0);
}
// DXIL_NATIVE_HALF: define noundef <2 x half> @
// SPIR_NATIVE_HALF: define spir_func noundef <2 x half> @
// DXIL_NATIVE_HALF: call <2 x half> @llvm.dx.normalize.v2f16(<2 x half>
// SPIR_NATIVE_HALF: call <2 x half> @llvm.spv.normalize.v2f16(<2 x half>
// DXIL_NO_HALF: call <2 x float> @llvm.dx.normalize.v2f32(<2 x float>
// SPIR_NO_HALF: call <2 x float> @llvm.spv.normalize.v2f32(<2 x float>
// NATIVE_HALF: ret <2 x half> %hlsl.normalize
// NO_HALF: ret <2 x float> %hlsl.normalize
half2 test_normalize_half2(half2 p0)
{
return normalize(p0);
}
// DXIL_NATIVE_HALF: define noundef <3 x half> @
// SPIR_NATIVE_HALF: define spir_func noundef <3 x half> @
// DXIL_NATIVE_HALF: call <3 x half> @llvm.dx.normalize.v3f16(<3 x half>
// SPIR_NATIVE_HALF: call <3 x half> @llvm.spv.normalize.v3f16(<3 x half>
// DXIL_NO_HALF: call <3 x float> @llvm.dx.normalize.v3f32(<3 x float>
// SPIR_NO_HALF: call <3 x float> @llvm.spv.normalize.v3f32(<3 x float>
// NATIVE_HALF: ret <3 x half> %hlsl.normalize
// NO_HALF: ret <3 x float> %hlsl.normalize
half3 test_normalize_half3(half3 p0)
{
return normalize(p0);
}
// DXIL_NATIVE_HALF: define noundef <4 x half> @
// SPIR_NATIVE_HALF: define spir_func noundef <4 x half> @
// DXIL_NATIVE_HALF: call <4 x half> @llvm.dx.normalize.v4f16(<4 x half>
// SPIR_NATIVE_HALF: call <4 x half> @llvm.spv.normalize.v4f16(<4 x half>
// DXIL_NO_HALF: call <4 x float> @llvm.dx.normalize.v4f32(<4 x float>
// SPIR_NO_HALF: call <4 x float> @llvm.spv.normalize.v4f32(<4 x float>
// NATIVE_HALF: ret <4 x half> %hlsl.normalize
// NO_HALF: ret <4 x float> %hlsl.normalize
half4 test_normalize_half4(half4 p0)
{
return normalize(p0);
}

// DXIL_CHECK: define noundef float @
// SPIR_CHECK: define spir_func noundef float @
// DXIL_CHECK: call float @llvm.dx.normalize.f32(float
// SPIR_CHECK: call float @llvm.spv.normalize.f32(float
// CHECK: ret float
float test_normalize_float(float p0)
{
return normalize(p0);
}
// DXIL_CHECK: define noundef <2 x float> @
// SPIR_CHECK: define spir_func noundef <2 x float> @
// DXIL_CHECK: %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(
// SPIR_CHECK: %hlsl.normalize = call <2 x float> @llvm.spv.normalize.v2f32(<2 x float>
// CHECK: ret <2 x float> %hlsl.normalize
float2 test_normalize_float2(float2 p0)
{
return normalize(p0);
}
// DXIL_CHECK: define noundef <3 x float> @
// SPIR_CHECK: define spir_func noundef <3 x float> @
// DXIL_CHECK: %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(
// SPIR_CHECK: %hlsl.normalize = call <3 x float> @llvm.spv.normalize.v3f32(<3 x float>
// CHECK: ret <3 x float> %hlsl.normalize
float3 test_normalize_float3(float3 p0)
{
return normalize(p0);
}
// DXIL_CHECK: define noundef <4 x float> @
// SPIR_CHECK: define spir_func noundef <4 x float> @
// DXIL_CHECK: %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(
// SPIR_CHECK: %hlsl.normalize = call <4 x float> @llvm.spv.normalize.v4f32(
// CHECK: ret <4 x float> %hlsl.normalize
float4 test_length_float4(float4 p0)
{
return normalize(p0);
}
31 changes: 31 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected

void test_too_few_arg()
{
return __builtin_hlsl_normalize();
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
}

void test_too_many_arg(float2 p0)
{
return __builtin_hlsl_normalize(p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
}

bool builtin_bool_to_float_type_promotion(bool p1)
{
return __builtin_hlsl_normalize(p1);
// expected-error@-1 {passing 'bool' to parameter of incompatible type 'float'}}
}

bool builtin_normalize_int_to_float_promotion(int p1)
{
return __builtin_hlsl_normalize(p1);
// expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
}

bool2 builtin_normalize_int2_to_float2_promotion(int2 p1)
{
return __builtin_hlsl_normalize(p1);
// expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType
def int_dx_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_dx_rcp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,6 @@ let TargetPrefix = "spv" in {
def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;
def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
}
72 changes: 72 additions & 0 deletions llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::dx_uclamp:
case Intrinsic::dx_lerp:
case Intrinsic::dx_length:
case Intrinsic::dx_normalize:
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return true;
Expand Down Expand Up @@ -229,6 +230,75 @@ static bool expandLog10Intrinsic(CallInst *Orig) {
return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
}

static bool expandNormalizeIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Type *Ty = Orig->getType();
Type *EltTy = Ty->getScalarType();
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);

auto *XVec = dyn_cast<FixedVectorType>(Ty);
if (!XVec) {
if (auto *constantFP = dyn_cast<ConstantFP>(X)) {
const APFloat &fpVal = constantFP->getValueAPF();
if (fpVal.isZero())
report_fatal_error(Twine("Invalid input scalar: length is zero"),
/* gen_crash_diag=*/false);
}
Value *Result = Builder.CreateFDiv(X, X);

Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
return true;
}

Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
unsigned XVecSize = XVec->getNumElements();
Value *DotProduct = nullptr;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For lines 256 to 279 is something that will have to be cleaned up into a helper function. I think we are going to have some code duplication here @pow2clk is moving getDotProductIntrinsic here as an expandFdot() ideally we would have one function for this. If we move forward with your PR before Greg's someone is going to need to clean this up.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Understood, I'll make the change if Greg's comes in first.

// use the dot intrinsic corresponding to the vector size
switch (XVecSize) {
case 1:
report_fatal_error(Twine("Invalid input vector: length is zero"),
/* gen_crash_diag=*/false);
break;
case 2:
DotProduct = Builder.CreateIntrinsic(
EltTy, Intrinsic::dx_dot2, ArrayRef<Value *>{X, X}, nullptr, "dx.dot2");
break;
case 3:
DotProduct = Builder.CreateIntrinsic(
EltTy, Intrinsic::dx_dot3, ArrayRef<Value *>{X, X}, nullptr, "dx.dot3");
break;
case 4:
DotProduct = Builder.CreateIntrinsic(
EltTy, Intrinsic::dx_dot4, ArrayRef<Value *>{X, X}, nullptr, "dx.dot4");
break;
default:
report_fatal_error(Twine("Invalid input vector: vector size is invalid."),
/* gen_crash_diag=*/false);
}

// verify that the length is non-zero
// (if the dot product is non-zero, then the length is non-zero)
if (auto *constantFP = dyn_cast<ConstantFP>(DotProduct)) {
const APFloat &fpVal = constantFP->getValueAPF();
if (fpVal.isZero())
report_fatal_error(Twine("Invalid input vector: length is zero"),
/* gen_crash_diag=*/false);
}

Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
ArrayRef<Value *>{DotProduct},
nullptr, "dx.rsqrt");

Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
Value *Result = Builder.CreateFMul(X, MultiplicandVec);

Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
return true;
}

static bool expandPowIntrinsic(CallInst *Orig) {

Value *X = Orig->getOperand(0);
Expand Down Expand Up @@ -314,6 +384,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
return expandLerpIntrinsic(Orig);
case Intrinsic::dx_length:
return expandLengthIntrinsic(Orig);
case Intrinsic::dx_normalize:
return expandNormalizeIntrinsic(Orig);
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return expandIntegerDot(Orig, F.getIntrinsicID());
Expand Down
22 changes: 22 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectLog10(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectNormalize(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

Expand Down Expand Up @@ -1349,6 +1352,23 @@ bool SPIRVInstructionSelector::selectFrac(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}

bool SPIRVInstructionSelector::selectNormalize(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {

assert(I.getNumOperands() == 3);
assert(I.getOperand(2).isReg());
MachineBasicBlock &BB = *I.getParent();

return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
.addImm(GL::Normalize)
.addUse(I.getOperand(2).getReg())
.constrainAllUses(TII, TRI, RBI);
}

bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
Expand Down Expand Up @@ -2080,6 +2100,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectFmix(ResVReg, ResType, I);
case Intrinsic::spv_frac:
return selectFrac(ResVReg, ResType, I);
case Intrinsic::spv_normalize:
return selectNormalize(ResVReg, ResType, I);
case Intrinsic::spv_rsqrt:
return selectRsqrt(ResVReg, ResType, I);
case Intrinsic::spv_lifetime_start:
Expand Down
Loading