Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 0c5a6af

Browse files
committedMar 20, 2025·
WIP: Mostly wrote everything.
1 parent 7baee96 commit 0c5a6af

File tree

9 files changed

+90
-31
lines changed

9 files changed

+90
-31
lines changed
 

‎clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19682,11 +19682,16 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1968219682
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
1968319683
}
1968419684
case Builtin::BI__builtin_hlsl_dot2add: {
19685+
llvm::Triple::ArchType Arch = CGM.getTarget().getTriple().getArch();
19686+
if (Arch != llvm::Triple::dxil) {
19687+
llvm_unreachable("Intrinsic dot2add can be executed as a builtin only on dxil");
19688+
}
1968519689
Value *A = EmitScalarExpr(E->getArg(0));
1968619690
Value *B = EmitScalarExpr(E->getArg(1));
1968719691
Value *C = EmitScalarExpr(E->getArg(2));
1968819692

19689-
Intrinsic::ID ID = CGM.getHLSLRuntime().getDot2AddIntrinsic();
19693+
//llvm::Intrinsic::dx_##IntrinsicPostfix
19694+
Intrinsic::ID ID = llvm ::Intrinsic::dx_dot2add;
1969019695
return Builder.CreateIntrinsic(
1969119696
/*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
1969219697
"hlsl.dot2add");

‎clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ class CGHLSLRuntime {
9999
GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
100100
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
101101
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
102-
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot2Add, dot2add)
103102
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
104103
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
105104
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)

‎clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -982,16 +982,6 @@ uint64_t dot(uint64_t3, uint64_t3);
982982
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
983983
uint64_t dot(uint64_t4, uint64_t4);
984984

985-
//===----------------------------------------------------------------------===//
986-
// dot2add builtins
987-
//===----------------------------------------------------------------------===//
988-
989-
/// \fn float dot2add(half2 a, half2 b, float c)
990-
991-
_HLSL_AVAILABILITY(shadermodel, 6.4)
992-
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot2add)
993-
float dot2add(half2, half2, float);
994-
995985
//===----------------------------------------------------------------------===//
996986
// dot4add builtins
997987
//===----------------------------------------------------------------------===//

‎clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ distance_vec_impl(vector<T, N> X, vector<T, N> Y) {
4545
return length_vec_impl(X - Y);
4646
}
4747

48+
constexpr float dot2add_impl(half2 a, half2 b, float c) {
49+
#if defined(__DIRECTX__)
50+
return __builtin_hlsl_dot2add(a, b, c);
51+
#else
52+
return dot(a, b) + c;
53+
#endif
54+
}
55+
4856
template <typename T> constexpr T reflect_impl(T I, T N) {
4957
return I - 2 * N * I * N;
5058
}

‎clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,18 @@ const inline float distance(__detail::HLSL_FIXED_VECTOR<float, N> X,
117117
return __detail::distance_vec_impl(X, Y);
118118
}
119119

120+
//===----------------------------------------------------------------------===//
121+
// dot2add builtins
122+
//===----------------------------------------------------------------------===//
123+
124+
/// \fn float dot2add(half2 a, half2 b, float c)
125+
/// \brief Dot product of 2 vector of type half and add a float scalar value.
126+
127+
_HLSL_AVAILABILITY(shadermodel, 6.4)
128+
const inline float dot2add(half2 a, half2 b, float c) {
129+
return __detail::dot2add_impl(a, b, c);
130+
}
131+
120132
//===----------------------------------------------------------------------===//
121133
// fmod builtins
122134
//===----------------------------------------------------------------------===//

‎clang/lib/Sema/SemaHLSL.cpp

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1989,7 +1989,7 @@ void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
19891989
}
19901990

19911991
// Helper function for CheckHLSLBuiltinFunctionCall
1992-
static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
1992+
static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall, unsigned NumArgs) {
19931993
assert(TheCall->getNumArgs() > 1);
19941994
ExprResult A = TheCall->getArg(0);
19951995

@@ -1999,7 +1999,7 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
19991999
SourceLocation BuiltinLoc = TheCall->getBeginLoc();
20002000

20012001
bool AllBArgAreVectors = true;
2002-
for (unsigned i = 1; i < TheCall->getNumArgs(); ++i) {
2002+
for (unsigned i = 1; i < NumArgs; ++i) {
20032003
ExprResult B = TheCall->getArg(i);
20042004
QualType ArgTyB = B.get()->getType();
20052005
auto *VecTyB = ArgTyB->getAs<VectorType>();
@@ -2049,6 +2049,10 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
20492049
return false;
20502050
}
20512051

2052+
static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
2053+
return CheckVectorElementCallArgs(S, TheCall, TheCall->getNumArgs());
2054+
}
2055+
20522056
static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
20532057
assert(TheCall->getNumArgs() > 1);
20542058
QualType ArgTy0 = TheCall->getArg(0)->getType();
@@ -2091,10 +2095,10 @@ static bool CheckArgTypeIsCorrect(
20912095
return false;
20922096
}
20932097

2094-
static bool CheckAllArgTypesAreCorrect(
2095-
Sema *S, CallExpr *TheCall, QualType ExpectedType,
2098+
static bool CheckArgTypesAreCorrect(
2099+
Sema *S, CallExpr *TheCall, unsigned NumArgs, QualType ExpectedType,
20962100
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
2097-
for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
2101+
for (unsigned i = 0; i < NumArgs; ++i) {
20982102
Expr *Arg = TheCall->getArg(i);
20992103
if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
21002104
return true;
@@ -2103,6 +2107,13 @@ static bool CheckAllArgTypesAreCorrect(
21032107
return false;
21042108
}
21052109

2110+
static bool CheckAllArgTypesAreCorrect(
2111+
Sema *S, CallExpr *TheCall, QualType ExpectedType,
2112+
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
2113+
return CheckArgTypesAreCorrect(S, TheCall, TheCall->getNumArgs(),
2114+
ExpectedType, Check);
2115+
}
2116+
21062117
static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
21072118
auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
21082119
return !PassedType->hasFloatingRepresentation();
@@ -2146,15 +2157,17 @@ static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
21462157
return true;
21472158
}
21482159

2149-
static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
2160+
static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall,
2161+
unsigned NumArgs, QualType ExpectedType) {
21502162
auto checkDoubleVector = [](clang::QualType PassedType) -> bool {
21512163
if (const auto *VecTy = PassedType->getAs<VectorType>())
21522164
return VecTy->getElementType()->isDoubleType();
21532165
return false;
21542166
};
2155-
return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
2156-
checkDoubleVector);
2167+
return CheckArgTypesAreCorrect(S, TheCall, NumArgs,
2168+
ExpectedType, checkDoubleVector);
21572169
}
2170+
21582171
static bool CheckFloatingOrIntRepresentation(Sema *S, CallExpr *TheCall) {
21592172
auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool {
21602173
return !PassedType->hasIntegerRepresentation() &&
@@ -2468,7 +2481,21 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
24682481
return true;
24692482
if (SemaRef.BuiltinVectorToScalarMath(TheCall))
24702483
return true;
2471-
if (CheckNoDoubleVectors(&SemaRef, TheCall))
2484+
if (CheckNoDoubleVectors(&SemaRef, TheCall,
2485+
TheCall->getNumArgs(), SemaRef.Context.FloatTy))
2486+
return true;
2487+
break;
2488+
}
2489+
case Builtin::BI__builtin_hlsl_dot2add: {
2490+
if (SemaRef.checkArgCount(TheCall, 3))
2491+
return true;
2492+
if (CheckVectorElementCallArgs(&SemaRef, TheCall, TheCall->getNumArgs() - 1))
2493+
return true;
2494+
if (CheckArgTypeMatches(&SemaRef, TheCall->getArg(2), SemaRef.getASTContext().FloatTy))
2495+
return true;
2496+
if (CheckNoDoubleVectors(&SemaRef, TheCall,
2497+
TheCall->getNumArgs() - 1,
2498+
SemaRef.Context.HalfTy))
24722499
return true;
24732500
break;
24742501
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
2+
3+
bool test_too_few_arg() {
4+
return __builtin_hlsl_dot2add();
5+
// expected-error@-1 {{too few arguments to function call, expected 3, have 0}}
6+
}
7+
8+
bool test_too_many_arg(half2 p1, half2 p2, float p3) {
9+
return __builtin_hlsl_dot2add(p1, p2, p3, p1);
10+
// expected-error@-1 {{too many arguments to function call, expected 3, have 4}}
11+
}

‎llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,6 @@ let TargetPrefix = "spv" in {
8787
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
8888
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
8989
[IntrNoMem, Commutative] >;
90-
91-
def int_spv_dot2add : DefaultAttrsIntrinsic<[llvm_float_ty],
92-
[llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty],
93-
[IntrNoMem, Commutative]>;
94-
9590
def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
9691
def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
9792
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;

‎llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,8 @@ static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
5555
}
5656

5757
static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
58-
IRBuilder<> &Builder) {
59-
// Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
60-
unsigned NumOperands = Orig->getNumOperands() - 1;
58+
IRBuilder<> &Builder,
59+
unsigned NumOperands) {
6160
assert(NumOperands > 0);
6261
Value *Arg0 = Orig->getOperand(0);
6362
[[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
@@ -75,6 +74,12 @@ static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
7574
return NewOperands;
7675
}
7776

77+
static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
78+
IRBuilder<> &Builder) {
79+
// Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
80+
return argVectorFlatten(Orig, Builder, Orig->getNumOperands() - 1);
81+
}
82+
/*
7883
static SmallVector<Value *> argVectorFlattenExcludeLastElement(CallInst *Orig,
7984
IRBuilder<> &Builder) {
8085
// Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
@@ -95,7 +100,7 @@ static SmallVector<Value *> argVectorFlattenExcludeLastElement(CallInst *Orig,
95100
}
96101
return NewOperands;
97102
}
98-
103+
*/
99104
namespace {
100105
class OpLowerer {
101106
Module &M;
@@ -190,7 +195,13 @@ class OpLowerer {
190195
} else if (IsVectorArgExpansion) {
191196
Args = argVectorFlatten(CI, OpBuilder.getIRB());
192197
} else if (F.getIntrinsicID() == Intrinsic::dx_dot2add) {
193-
unsigned NumOperands = CI->getNumOperands() - 1;
198+
// arg[NumOperands-1] is a pointer and is not needed by our flattening.
199+
// arg[NumOperands-2] also does not need to be flattened because it is a scalar.
200+
unsigned NumOperands = CI->getNumOperands() - 2;
201+
Args.push_back(CI->getArgOperand(NumOperands));
202+
Args.append(argVectorFlatten(CI, OpBuilder.getIRB(), NumOperands));
203+
204+
/*unsigned NumOperands = CI->getNumOperands() - 1;
194205
assert(NumOperands > 0);
195206
Value *LastArg = CI->getOperand(NumOperands - 1);
196207
@@ -201,6 +212,7 @@ class OpLowerer {
201212
202213
//Args = populateOperands(LastArg, OpBuilder.getIRB());
203214
Args.append(argVectorFlattenExcludeLastElement(CI, OpBuilder.getIRB()));
215+
*/
204216
} else {
205217
Args.append(CI->arg_begin(), CI->arg_end());
206218
}

0 commit comments

Comments
 (0)
Please sign in to comment.