Skip to content

Commit 2711415

Browse files
committed
[HLSL] Implement a header only distance intrinsic
Addressing RFC comments, replace LangBuiltin with TargetBuiltin
1 parent f6365a4 commit 2711415

18 files changed

+664
-290
lines changed

clang/include/clang/Basic/Builtins.td

-6
Original file line numberDiff line numberDiff line change
@@ -4865,12 +4865,6 @@ def HLSLIsinf : LangBuiltin<"HLSL_LANG"> {
48654865
let Prototype = "void(...)";
48664866
}
48674867

4868-
def HLSLLength : LangBuiltin<"HLSL_LANG"> {
4869-
let Spellings = ["__builtin_hlsl_length"];
4870-
let Attributes = [NoThrow, Const];
4871-
let Prototype = "void(...)";
4872-
}
4873-
48744868
def HLSLLerp : LangBuiltin<"HLSL_LANG"> {
48754869
let Spellings = ["__builtin_hlsl_lerp"];
48764870
let Attributes = [NoThrow, Const];

clang/include/clang/Basic/BuiltinsSPIRV.td

+6
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,9 @@ def SPIRVDistance : Builtin {
1313
let Attributes = [NoThrow, Const];
1414
let Prototype = "void(...)";
1515
}
16+
17+
def SPIRVLength : Builtin {
18+
let Spellings = ["__builtin_spirv_length"];
19+
let Attributes = [NoThrow, Const];
20+
let Prototype = "void(...)";
21+
}

clang/lib/CodeGen/CGBuiltin.cpp

+10-14
Original file line numberDiff line numberDiff line change
@@ -19332,20 +19332,6 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1933219332
/*ReturnType=*/X->getType(), CGM.getHLSLRuntime().getLerpIntrinsic(),
1933319333
ArrayRef<Value *>{X, Y, S}, nullptr, "hlsl.lerp");
1933419334
}
19335-
case Builtin::BI__builtin_hlsl_length: {
19336-
Value *X = EmitScalarExpr(E->getArg(0));
19337-
19338-
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
19339-
"length operand must have a float representation");
19340-
// if the operand is a scalar, we can use the fabs llvm intrinsic directly
19341-
if (!E->getArg(0)->getType()->isVectorType())
19342-
return EmitFAbs(*this, X);
19343-
19344-
return Builder.CreateIntrinsic(
19345-
/*ReturnType=*/X->getType()->getScalarType(),
19346-
CGM.getHLSLRuntime().getLengthIntrinsic(), ArrayRef<Value *>{X},
19347-
nullptr, "hlsl.length");
19348-
}
1934919335
case Builtin::BI__builtin_hlsl_normalize: {
1935019336
Value *X = EmitScalarExpr(E->getArg(0));
1935119337

@@ -20498,6 +20484,16 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
2049820484
/*ReturnType=*/X->getType()->getScalarType(), Intrinsic::spv_distance,
2049920485
ArrayRef<Value *>{X, Y}, nullptr, "spv.distance");
2050020486
}
20487+
case SPIRV::BI__builtin_spirv_length: {
20488+
Value *X = EmitScalarExpr(E->getArg(0));
20489+
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
20490+
"length operand must have a float representation");
20491+
assert(E->getArg(0)->getType()->isVectorType() &&
20492+
"length operand must be a vector");
20493+
return Builder.CreateIntrinsic(
20494+
/*ReturnType=*/X->getType()->getScalarType(), Intrinsic::spv_length,
20495+
ArrayRef<Value *>{X}, nullptr, "hlsl.length");
20496+
}
2050120497
}
2050220498
return nullptr;
2050320499
}

clang/lib/CodeGen/CGHLSLRuntime.h

-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ class CGHLSLRuntime {
7777
GENERATE_HLSL_INTRINSIC_FUNCTION(Cross, cross)
7878
GENERATE_HLSL_INTRINSIC_FUNCTION(Degrees, degrees)
7979
GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
80-
GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length)
8180
GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
8281
GENERATE_HLSL_INTRINSIC_FUNCTION(Normalize, normalize)
8382
GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)

clang/lib/Headers/hlsl/hlsl_detail.h

+42
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@ namespace hlsl {
1313

1414
namespace __detail {
1515

16+
template <typename T, typename U> struct is_same {
17+
static const bool value = false;
18+
};
19+
20+
template <typename T> struct is_same<T, T> {
21+
static const bool value = true;
22+
};
23+
1624
template <bool B, typename T> struct enable_if {};
1725

1826
template <typename T> struct enable_if<true, T> {
@@ -33,6 +41,40 @@ constexpr enable_if_t<sizeof(U) == sizeof(T), U> bit_cast(T F) {
3341
return __builtin_bit_cast(U, F);
3442
}
3543

44+
template <typename T>
45+
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
46+
length_impl(T X) {
47+
return __builtin_elementwise_abs(X);
48+
}
49+
50+
template <typename T, int N>
51+
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
52+
length_vec_impl(vector<T, N> X) {
53+
#if (__has_builtin(__builtin_spirv_length))
54+
return __builtin_spirv_length(X);
55+
#else
56+
vector<T, N> XSquared = X * X;
57+
T XSquaredSum = __builtin_hlsl_reduce_add(XSquared);
58+
return __builtin_elementwise_sqrt(XSquaredSum);
59+
#endif
60+
}
61+
62+
template <typename T>
63+
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
64+
distance_impl(T X, T Y) {
65+
return length_impl(X - Y);
66+
}
67+
68+
template <typename T, int N>
69+
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
70+
distance_vec_impl(vector<T, N> X, vector<T, N> Y) {
71+
#if (__has_builtin(__builtin_spirv_distance))
72+
return __builtin_spirv_distance(X, Y);
73+
#else
74+
return length_vec_impl(X - Y);
75+
#endif
76+
}
77+
3678
} // namespace __detail
3779
} // namespace hlsl
3880
#endif //_HLSL_HLSL_DETAILS_H_

clang/lib/Headers/hlsl/hlsl_intrinsics.h

+38-19
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,34 @@ float3 degrees(float3);
871871
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_degrees)
872872
float4 degrees(float4);
873873

874+
//===----------------------------------------------------------------------===//
875+
// distance builtins
876+
//===----------------------------------------------------------------------===//
877+
878+
/// \fn K distance(T X, T Y)
879+
/// \brief Returns a distance scalar between two vectors of \a X and \a Y.
880+
/// \param X The X input value.
881+
/// \param Y The Y input value.
882+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
883+
const inline half distance(half X, half Y) {
884+
return __detail::distance_impl(X, Y);
885+
}
886+
887+
const inline float distance(float X, float Y) {
888+
return __detail::distance_impl(X, Y);
889+
}
890+
891+
template <int N>
892+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
893+
const inline half distance(vector<half, N> X, vector<half, N> Y) {
894+
return __detail::distance_vec_impl(X, Y);
895+
}
896+
897+
template <int N>
898+
const inline float distance(vector<float, N> X, vector<float, N> Y) {
899+
return __detail::distance_vec_impl(X, Y);
900+
}
901+
874902
//===----------------------------------------------------------------------===//
875903
// dot product builtins
876904
//===----------------------------------------------------------------------===//
@@ -1296,28 +1324,19 @@ float4 lerp(float4, float4, float4);
12961324
/// \param x [in] The vector of floats, or a scalar float.
12971325
///
12981326
/// Length is based on the following formula: sqrt(x[0]^2 + x[1]^2 + ...).
1299-
13001327
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
1301-
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
1302-
half length(half);
1303-
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
1304-
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
1305-
half length(half2);
1306-
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
1307-
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
1308-
half length(half3);
1328+
const inline half length(half X) { return __detail::length_impl(X); }
1329+
const inline float length(float X) { return __detail::length_impl(X); }
1330+
1331+
template <int N>
13091332
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
1310-
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
1311-
half length(half4);
1333+
const inline half length(vector<half, N> X) {
1334+
return __detail::length_vec_impl(X);
1335+
}
13121336

1313-
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
1314-
float length(float);
1315-
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
1316-
float length(float2);
1317-
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
1318-
float length(float3);
1319-
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
1320-
float length(float4);
1337+
template <int N> const inline float length(vector<float, N> X) {
1338+
return __detail::length_vec_impl(X);
1339+
}
13211340

13221341
//===----------------------------------------------------------------------===//
13231342
// log builtins

clang/lib/Sema/SemaHLSL.cpp

+3-18
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "clang/Basic/DiagnosticSema.h"
2424
#include "clang/Basic/LLVM.h"
2525
#include "clang/Basic/SourceLocation.h"
26+
#include "clang/Basic/TargetBuiltins.h"
2627
#include "clang/Basic/TargetInfo.h"
2728
#include "clang/Sema/Initialization.h"
2829
#include "clang/Sema/ParsedAttr.h"
@@ -2100,24 +2101,6 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
21002101
return true;
21012102
break;
21022103
}
2103-
case Builtin::BI__builtin_hlsl_length: {
2104-
if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
2105-
return true;
2106-
if (SemaRef.checkArgCount(TheCall, 1))
2107-
return true;
2108-
2109-
ExprResult A = TheCall->getArg(0);
2110-
QualType ArgTyA = A.get()->getType();
2111-
QualType RetTy;
2112-
2113-
if (auto *VTy = ArgTyA->getAs<VectorType>())
2114-
RetTy = VTy->getElementType();
2115-
else
2116-
RetTy = TheCall->getArg(0)->getType();
2117-
2118-
TheCall->setType(RetTy);
2119-
break;
2120-
}
21212104
case Builtin::BI__builtin_hlsl_mad: {
21222105
if (SemaRef.checkArgCount(TheCall, 3))
21232106
return true;
@@ -2220,6 +2203,8 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
22202203
return true;
22212204
break;
22222205
}
2206+
case SPIRV::BI__builtin_spirv_distance:
2207+
case SPIRV::BI__builtin_spirv_length:
22232208
case Builtin::BI__builtin_elementwise_acos:
22242209
case Builtin::BI__builtin_elementwise_asin:
22252210
case Builtin::BI__builtin_elementwise_atan:

clang/lib/Sema/SemaSPIRV.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,24 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
5151
TheCall->setType(RetTy);
5252
break;
5353
}
54+
case SPIRV::BI__builtin_spirv_length: {
55+
if (SemaRef.checkArgCount(TheCall, 1))
56+
return true;
57+
ExprResult A = TheCall->getArg(0);
58+
QualType ArgTyA = A.get()->getType();
59+
auto *VTy = ArgTyA->getAs<VectorType>();
60+
if (VTy == nullptr) {
61+
SemaRef.Diag(A.get()->getBeginLoc(),
62+
diag::err_typecheck_convert_incompatible)
63+
<< ArgTyA
64+
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
65+
<< 0 << 0;
66+
return true;
67+
}
68+
QualType RetTy = VTy->getElementType();
69+
TheCall->setType(RetTy);
70+
break;
71+
}
5472
}
5573
return false;
5674
}

0 commit comments

Comments
 (0)