Skip to content

Add length builtins and length HLSL function to DirectX Backend #101256

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4707,6 +4707,12 @@ def HLSLIsinf : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}

def HLSLLength : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_length"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

def HLSLLerp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_lerp"];
let Attributes = [NoThrow, Const];
Expand Down
14 changes: 14 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18460,6 +18460,20 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/X->getType(), CGM.getHLSLRuntime().getLerpIntrinsic(),
ArrayRef<Value *>{X, Y, S}, nullptr, "hlsl.lerp");
}
case Builtin::BI__builtin_hlsl_length: {
Value *X = EmitScalarExpr(E->getArg(0));

assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
"length operand must have a float representation");
// if the operand is a scalar, we can use the fabs llvm intrinsic directly
if (!E->getArg(0)->getType()->isVectorType())
return EmitFAbs(*this, X);

return Builder.CreateIntrinsic(
/*ReturnType=*/X->getType()->getScalarType(),
CGM.getHLSLRuntime().getLengthIntrinsic(), ArrayRef<Value *>{X},
nullptr, "hlsl.length");
}
case Builtin::BI__builtin_hlsl_elementwise_frac: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
if (!E->getArg(0)->getType()->hasFloatingRepresentation())
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(All, all)
GENERATE_HLSL_INTRINSIC_FUNCTION(Any, any)
GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length)
GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
Expand Down
32 changes: 32 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,38 @@ float3 lerp(float3, float3, float3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_lerp)
float4 lerp(float4, float4, float4);

//===----------------------------------------------------------------------===//
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR title looks like it may be misnamed, since this is also adding the HLSL builtin functions, it is doing more than just adding it to the backend.

// length builtins
//===----------------------------------------------------------------------===//

/// \fn T length(T x)
/// \brief Returns the length of the specified floating-point vector.
/// \param x [in] The vector of floats, or a scalar float.
///
/// Length is based on the following formula: sqrt(x[0]^2 + x[1]^2 + …).

_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
half length(half);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
half length(half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
half length(half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
half length(half4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
float length(float);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
float length(float2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
float length(float3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
float length(float4);

//===----------------------------------------------------------------------===//
// log builtins
//===----------------------------------------------------------------------===//
Expand Down
18 changes: 18 additions & 0 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,24 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
case Builtin::BI__builtin_hlsl_length: {
if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
return true;
if (SemaRef.checkArgCount(TheCall, 1))
return true;

ExprResult A = TheCall->getArg(0);
QualType ArgTyA = A.get()->getType();
QualType RetTy;

if (auto *VTy = ArgTyA->getAs<VectorType>())
RetTy = VTy->getElementType();
else
RetTy = TheCall->getArg(0)->getType();

TheCall->setType(RetTy);
break;
}
case Builtin::BI__builtin_hlsl_mad: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;
Expand Down
73 changes: 73 additions & 0 deletions clang/test/CodeGenHLSL/builtins/length.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
// RUN: --check-prefixes=CHECK,NATIVE_HALF
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF

// NATIVE_HALF: define noundef half @
// NATIVE_HALF: call half @llvm.fabs.f16(half
// NO_HALF: call float @llvm.fabs.f32(float
// NATIVE_HALF: ret half
// NO_HALF: ret float
half test_length_half(half p0)
{
return length(p0);
}
// NATIVE_HALF: define noundef half @
// NATIVE_HALF: %hlsl.length = call half @llvm.dx.length.v2f16
// NO_HALF: %hlsl.length = call float @llvm.dx.length.v2f32(
// NATIVE_HALF: ret half %hlsl.length
// NO_HALF: ret float %hlsl.length
half test_length_half2(half2 p0)
{
return length(p0);
}
// NATIVE_HALF: define noundef half @
// NATIVE_HALF: %hlsl.length = call half @llvm.dx.length.v3f16
// NO_HALF: %hlsl.length = call float @llvm.dx.length.v3f32(
// NATIVE_HALF: ret half %hlsl.length
// NO_HALF: ret float %hlsl.length
half test_length_half3(half3 p0)
{
return length(p0);
}
// NATIVE_HALF: define noundef half @
// NATIVE_HALF: %hlsl.length = call half @llvm.dx.length.v4f16
// NO_HALF: %hlsl.length = call float @llvm.dx.length.v4f32(
// NATIVE_HALF: ret half %hlsl.length
// NO_HALF: ret float %hlsl.length
half test_length_half4(half4 p0)
{
return length(p0);
}

// CHECK: define noundef float @
// CHECK: call float @llvm.fabs.f32(float
// CHECK: ret float
float test_length_float(float p0)
{
return length(p0);
}
// CHECK: define noundef float @
// CHECK: %hlsl.length = call float @llvm.dx.length.v2f32(
// CHECK: ret float %hlsl.length
float test_length_float2(float2 p0)
{
return length(p0);
}
// CHECK: define noundef float @
// CHECK: %hlsl.length = call float @llvm.dx.length.v3f32(
// CHECK: ret float %hlsl.length
float test_length_float3(float3 p0)
{
return length(p0);
}
// CHECK: define noundef float @
// CHECK: %hlsl.length = call float @llvm.dx.length.v4f32(
// CHECK: ret float %hlsl.length
float test_length_float4(float4 p0)
{
return length(p0);
}
31 changes: 31 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/length-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected

void test_too_few_arg()
{
return __builtin_hlsl_length();
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
}

void test_too_many_arg(float2 p0)
{
return __builtin_hlsl_length(p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
}

bool builtin_bool_to_float_type_promotion(bool p1)
{
return __builtin_hlsl_length(p1);
// expected-error@-1 {passing 'bool' to parameter of incompatible type 'float'}}
}

bool builtin_length_int_to_float_promotion(int p1)
{
return __builtin_hlsl_length(p1);
// expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
}

bool2 builtin_length_int2_to_float2_promotion(int2 p1)
{
return __builtin_hlsl_length(p1);
// expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def int_dx_isinf :
def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;

def int_dx_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_rcp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,6 @@ let TargetPrefix = "spv" in {
def int_spv_frac : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;
def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought this change didn't include the SPIRV parts of this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

He has to define this because of CGM.getHLSLRuntime().getLengthIntrinsic(). the alternative would be to just emit int_dx_length in CGBuiltin.cpp until there is a spirv implementation. then swap in getLengthIntrinsic

def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
}
34 changes: 34 additions & 0 deletions llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::dx_clamp:
case Intrinsic::dx_uclamp:
case Intrinsic::dx_lerp:
case Intrinsic::dx_length:
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return true;
Expand Down Expand Up @@ -157,6 +158,37 @@ static bool expandAnyIntrinsic(CallInst *Orig) {
return true;
}

static bool expandLengthIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();

// Though dx.length does work on scalar type, we can optimize it to just emit
// fabs, in CGBuiltin.cpp. We shouldn't see a scalar type here because
// CGBuiltin.cpp should have emitted a fabs call.
Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
auto *XVec = dyn_cast<FixedVectorType>(Ty);
unsigned XVecSize = XVec->getNumElements();
if (!(Ty->isVectorTy() && XVecSize > 1))
report_fatal_error(Twine("Invalid input type for length intrinsic"),
/* gen_crash_diag=*/false);

Value *Sum = Builder.CreateFMul(Elt, Elt);
for (unsigned I = 1; I < XVecSize; I++) {
Elt = Builder.CreateExtractElement(X, I);
Value *Mul = Builder.CreateFMul(Elt, Elt);
Sum = Builder.CreateFAdd(Sum, Mul);
}
Value *Result = Builder.CreateIntrinsic(
EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum}, nullptr, "elt.sqrt");

Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
return true;
}

static bool expandLerpIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Value *Y = Orig->getOperand(1);
Expand Down Expand Up @@ -280,6 +312,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
return expandClampIntrinsic(Orig, F.getIntrinsicID());
case Intrinsic::dx_lerp:
return expandLerpIntrinsic(Orig);
case Intrinsic::dx_length:
return expandLengthIntrinsic(Orig);
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return expandIntegerDot(Orig, F.getIntrinsicID());
Expand Down
116 changes: 116 additions & 0 deletions llvm/test/CodeGen/DirectX/length.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
; RUN: opt -S -dxil-intrinsic-expansion < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK

; Make sure dxil operation function calls for length are generated for half/float.

declare half @llvm.fabs.f16(half)
declare half @llvm.dx.length.v2f16(<2 x half>)
declare half @llvm.dx.length.v3f16(<3 x half>)
declare half @llvm.dx.length.v4f16(<4 x half>)

declare float @llvm.fabs.f32(float)
declare float @llvm.dx.length.v2f32(<2 x float>)
declare float @llvm.dx.length.v3f32(<3 x float>)
declare float @llvm.dx.length.v4f32(<4 x float>)

define noundef half @test_length_half2(<2 x half> noundef %p0) {
entry:
; CHECK: extractelement <2 x half> %{{.*}}, i64 0
; CHECK: fmul half %{{.*}}, %{{.*}}
; CHECK: extractelement <2 x half> %{{.*}}, i64 1
; CHECK: fmul half %{{.*}}, %{{.*}}
; CHECK: fadd half %{{.*}}, %{{.*}}
; EXPCHECK: call half @llvm.sqrt.f16(half %{{.*}})
; DOPCHECK: call half @dx.op.unary.f16(i32 24, half %{{.*}})

%hlsl.length = call half @llvm.dx.length.v2f16(<2 x half> %p0)
ret half %hlsl.length
}

define noundef half @test_length_half3(<3 x half> noundef %p0) {
entry:
; CHECK: extractelement <3 x half> %{{.*}}, i64 0
; CHECK: fmul half %{{.*}}, %{{.*}}
; CHECK: extractelement <3 x half> %{{.*}}, i64 1
; CHECK: fmul half %{{.*}}, %{{.*}}
; CHECK: fadd half %{{.*}}, %{{.*}}
; CHECK: extractelement <3 x half> %{{.*}}, i64 2
; CHECK: fmul half %{{.*}}, %{{.*}}
; CHECK: fadd half %{{.*}}, %{{.*}}
; EXPCHECK: call half @llvm.sqrt.f16(half %{{.*}})
; DOPCHECK: call half @dx.op.unary.f16(i32 24, half %{{.*}})

%hlsl.length = call half @llvm.dx.length.v3f16(<3 x half> %p0)
ret half %hlsl.length
}

define noundef half @test_length_half4(<4 x half> noundef %p0) {
entry:
; CHECK: extractelement <4 x half> %{{.*}}, i64 0
; CHECK: fmul half %{{.*}}, %{{.*}}
; CHECK: extractelement <4 x half> %{{.*}}, i64 1
; CHECK: fmul half %{{.*}}, %{{.*}}
; CHECK: fadd half %{{.*}}, %{{.*}}
; CHECK: extractelement <4 x half> %{{.*}}, i64 2
; CHECK: fmul half %{{.*}}, %{{.*}}
; CHECK: fadd half %{{.*}}, %{{.*}}
; CHECK: extractelement <4 x half> %{{.*}}, i64 3
; CHECK: fmul half %{{.*}}, %{{.*}}
; CHECK: fadd half %{{.*}}, %{{.*}}
; EXPCHECK: call half @llvm.sqrt.f16(half %{{.*}})
; DOPCHECK: call half @dx.op.unary.f16(i32 24, half %{{.*}})

%hlsl.length = call half @llvm.dx.length.v4f16(<4 x half> %p0)
ret half %hlsl.length
}

define noundef float @test_length_float2(<2 x float> noundef %p0) {
entry:
; CHECK: extractelement <2 x float> %{{.*}}, i64 0
; CHECK: fmul float %{{.*}}, %{{.*}}
; CHECK: extractelement <2 x float> %{{.*}}, i64 1
; CHECK: fmul float %{{.*}}, %{{.*}}
; CHECK: fadd float %{{.*}}, %{{.*}}
; EXPCHECK: call float @llvm.sqrt.f32(float %{{.*}})
; DOPCHECK: call float @dx.op.unary.f32(i32 24, float %{{.*}})

%hlsl.length = call float @llvm.dx.length.v2f32(<2 x float> %p0)
ret float %hlsl.length
}

define noundef float @test_length_float3(<3 x float> noundef %p0) {
entry:
; CHECK: extractelement <3 x float> %{{.*}}, i64 0
; CHECK: fmul float %{{.*}}, %{{.*}}
; CHECK: extractelement <3 x float> %{{.*}}, i64 1
; CHECK: fmul float %{{.*}}, %{{.*}}
; CHECK: fadd float %{{.*}}, %{{.*}}
; CHECK: extractelement <3 x float> %{{.*}}, i64 2
; CHECK: fmul float %{{.*}}, %{{.*}}
; CHECK: fadd float %{{.*}}, %{{.*}}
; EXPCHECK: call float @llvm.sqrt.f32(float %{{.*}})
; DOPCHECK: call float @dx.op.unary.f32(i32 24, float %{{.*}})

%hlsl.length = call float @llvm.dx.length.v3f32(<3 x float> %p0)
ret float %hlsl.length
}

define noundef float @test_length_float4(<4 x float> noundef %p0) {
entry:
; CHECK: extractelement <4 x float> %{{.*}}, i64 0
; CHECK: fmul float %{{.*}}, %{{.*}}
; CHECK: extractelement <4 x float> %{{.*}}, i64 1
; CHECK: fmul float %{{.*}}, %{{.*}}
; CHECK: fadd float %{{.*}}, %{{.*}}
; CHECK: extractelement <4 x float> %{{.*}}, i64 2
; CHECK: fmul float %{{.*}}, %{{.*}}
; CHECK: fadd float %{{.*}}, %{{.*}}
; CHECK: extractelement <4 x float> %{{.*}}, i64 3
; CHECK: fmul float %{{.*}}, %{{.*}}
; CHECK: fadd float %{{.*}}, %{{.*}}
; EXPCHECK: call float @llvm.sqrt.f32(float %{{.*}})
; DOPCHECK: call float @dx.op.unary.f32(i32 24, float %{{.*}})

%hlsl.length = call float @llvm.dx.length.v4f32(<4 x float> %p0)
ret float %hlsl.length
}
Loading
Loading