Skip to content

[HLSL] Implement dot2add intrinsic #131237

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Apr 3, 2025
Merged
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4891,6 +4891,12 @@ def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}

def HLSLDot2Add : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_dot2add"];
let Attributes = [NoThrow, Const];
let Prototype = "float(_ExtVector<2, _Float16>, _ExtVector<2, _Float16>, float)";
}

def HLSLDot4AddI8Packed : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_dot4add_i8packed"];
let Attributes = [NoThrow, Const];
Expand Down
13 changes: 13 additions & 0 deletions clang/lib/CodeGen/CGHLSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,19 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()),
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
}
case Builtin::BI__builtin_hlsl_dot2add: {
llvm::Triple::ArchType Arch = CGM.getTarget().getTriple().getArch();
assert(Arch == llvm::Triple::dxil &&
"Intrinsic dot2add is only allowed for dxil architecture");
Value *A = EmitScalarExpr(E->getArg(0));
Value *B = EmitScalarExpr(E->getArg(1));
Value *C = EmitScalarExpr(E->getArg(2));

Intrinsic::ID ID = llvm ::Intrinsic::dx_dot2add;
return Builder.CreateIntrinsic(
/*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
"dx.dot2add");
}
case Builtin::BI__builtin_hlsl_dot4add_i8packed: {
Value *A = EmitScalarExpr(E->getArg(0));
Value *B = EmitScalarExpr(E->getArg(1));
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ distance_vec_impl(vector<T, N> X, vector<T, N> Y) {
return length_vec_impl(X - Y);
}

constexpr float dot2add_impl(half2 a, half2 b, float c) {
#if defined(__DIRECTX__)
return __builtin_hlsl_dot2add(a, b, c);
#else
return dot(a, b) + c;
#endif
}

template <typename T> constexpr T reflect_impl(T I, T N) {
return I - 2 * N * I * N;
}
Expand Down
15 changes: 15 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,21 @@ const inline float distance(__detail::HLSL_FIXED_VECTOR<float, N> X,
return __detail::distance_vec_impl(X, Y);
}

//===----------------------------------------------------------------------===//
// dot2add builtins
//===----------------------------------------------------------------------===//

/// \fn float dot2add(half2 A, half2 B, float C)
/// \brief Dot product of 2 vector of type half and add a float scalar value.
/// \param A The first input value to dot product.
/// \param B The second input value to dot product.
/// \param C The input value added to the dot product.

_HLSL_AVAILABILITY(shadermodel, 6.4)
const inline float dot2add(half2 A, half2 B, float C) {
return __detail::dot2add_impl(A, B, C);
}

//===----------------------------------------------------------------------===//
// fmod builtins
//===----------------------------------------------------------------------===//
Expand Down
135 changes: 135 additions & 0 deletions clang/test/CodeGenHLSL/builtins/dot2add.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV

// Test basic lowering to runtime function call.

// CHECK-LABEL: define {{.*}}test_default_parameter_type
float test_default_parameter_type(half2 p1, half2 p2, float p3) {
// CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}

// CHECK-LABEL: define {{.*}}test_float_arg2_type
float test_float_arg2_type(half2 p1, float2 p2, float p3) {
// CHECK: %conv = fptrunc reassoc nnan ninf nsz arcp afn <2 x float> %{{.*}} to <2 x half>
// CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}

// CHECK-LABEL: define {{.*}}test_float_arg1_type
float test_float_arg1_type(float2 p1, half2 p2, float p3) {
// CHECK: %conv = fptrunc reassoc nnan ninf nsz arcp afn <2 x float> %{{.*}} to <2 x half>
// CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}

// CHECK-LABEL: define {{.*}}test_double_arg3_type
float test_double_arg3_type(half2 p1, half2 p2, double p3) {
// CHECK: %conv = fptrunc reassoc nnan ninf nsz arcp afn double %{{.*}} to float
// CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}

// CHECK-LABEL: define {{.*}}test_float_arg1_arg2_type
float test_float_arg1_arg2_type(float2 p1, float2 p2, float p3) {
// CHECK: %conv = fptrunc reassoc nnan ninf nsz arcp afn <2 x float> %{{.*}} to <2 x half>
// CHECK: %conv1 = fptrunc reassoc nnan ninf nsz arcp afn <2 x float> %{{.*}} to <2 x half>
// CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}

// CHECK-LABEL: define {{.*}}test_double_arg1_arg2_type
float test_double_arg1_arg2_type(double2 p1, double2 p2, float p3) {
// CHECK: %conv = fptrunc reassoc nnan ninf nsz arcp afn <2 x double> %{{.*}} to <2 x half>
// CHECK: %conv1 = fptrunc reassoc nnan ninf nsz arcp afn <2 x double> %{{.*}} to <2 x half>
// CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}

// CHECK-LABEL: define {{.*}}test_int16_arg1_arg2_type
float test_int16_arg1_arg2_type(int16_t2 p1, int16_t2 p2, float p3) {
// CHECK: %conv = sitofp <2 x i16> %{{.*}} to <2 x half>
// CHECK: %conv1 = sitofp <2 x i16> %{{.*}} to <2 x half>
// CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}

// CHECK-LABEL: define {{.*}}test_int32_arg1_arg2_type
float test_int32_arg1_arg2_type(int32_t2 p1, int32_t2 p2, float p3) {
// CHECK: %conv = sitofp <2 x i32> %{{.*}} to <2 x half>
// CHECK: %conv1 = sitofp <2 x i32> %{{.*}} to <2 x half>
// CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}

// CHECK-LABEL: define {{.*}}test_int64_arg1_arg2_type
float test_int64_arg1_arg2_type(int64_t2 p1, int64_t2 p2, float p3) {
// CHECK: %conv = sitofp <2 x i64> %{{.*}} to <2 x half>
// CHECK: %conv1 = sitofp <2 x i64> %{{.*}} to <2 x half>
// CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}

// CHECK-LABEL: define {{.*}}test_bool_arg1_arg2_type
float test_bool_arg1_arg2_type(bool2 p1, bool2 p2, float p3) {
// CHECK: %loadedv = trunc <2 x i32> %{{.*}} to <2 x i1>
// CHECK: %conv = uitofp <2 x i1> %loadedv to <2 x half>
// CHECK: %loadedv1 = trunc <2 x i32> %{{.*}} to <2 x i1>
// CHECK: %conv2 = uitofp <2 x i1> %loadedv1 to <2 x half>
// CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
13 changes: 13 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/dot2add-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify

float test_too_few_arg() {
return dot2add();
// expected-error@-1 {{no matching function for call to 'dot2add'}}
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 0 were provided}}
}

float test_too_many_arg(half2 p1, half2 p2, float p3) {
return dot2add(p1, p2, p3, p1);
// expected-error@-1 {{no matching function for call to 'dot2add'}}
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 4 were provided}}
}
4 changes: 4 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ def int_dx_udot :
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, Commutative] >;
def int_dx_dot2add :
DefaultAttrsIntrinsic<[llvm_float_ty],
[llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty],
[IntrNoMem, Commutative]>;
def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
def int_dx_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;

Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,17 @@ def RawBufferStore : DXILOp<140, rawBufferStore> {
let stages = [Stages<DXIL1_2, [all_stages]>];
}

def Dot2AddHalf : DXILOp<162, dot2AddHalf> {
let Doc = "dot product of 2 vectors of half having size = 2, returns "
"float";
let intrinsics = [IntrinSelect<int_dx_dot2add>];
let arguments = [FloatTy, HalfTy, HalfTy, HalfTy, HalfTy];
let result = FloatTy;
let overloads = [Overloads<DXIL1_0, []>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}

def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> {
let Doc = "signed dot product of 4 x i8 vectors packed into i32, with "
"accumulate to i32";
Expand Down
19 changes: 15 additions & 4 deletions llvm/lib/Target/DirectX/DXILOpLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,8 @@ static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
return ExtractedElements;
}

static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
IRBuilder<> &Builder) {
// Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
unsigned NumOperands = Orig->getNumOperands() - 1;
static SmallVector<Value *>
argVectorFlatten(CallInst *Orig, IRBuilder<> &Builder, unsigned NumOperands) {
assert(NumOperands > 0);
Value *Arg0 = Orig->getOperand(0);
[[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
Expand All @@ -75,6 +73,12 @@ static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
return NewOperands;
}

static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
IRBuilder<> &Builder) {
// Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
return argVectorFlatten(Orig, Builder, Orig->getNumOperands() - 1);
}

namespace {
class OpLowerer {
Module &M;
Expand Down Expand Up @@ -168,6 +172,13 @@ class OpLowerer {
}
} else if (IsVectorArgExpansion) {
Args = argVectorFlatten(CI, OpBuilder.getIRB());
} else if (F.getIntrinsicID() == Intrinsic::dx_dot2add) {
// arg[NumOperands-1] is a pointer and is not needed by our flattening.
// arg[NumOperands-2] also does not need to be flattened because it is a
// scalar.
unsigned NumOperands = CI->getNumOperands() - 2;
Args.push_back(CI->getArgOperand(NumOperands));
Args.append(argVectorFlatten(CI, OpBuilder.getIRB(), NumOperands));
} else {
Args.append(CI->arg_begin(), CI->arg_end());
}
Expand Down
8 changes: 8 additions & 0 deletions llvm/test/CodeGen/DirectX/dot2add.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s

define noundef float @dot2add_simple(<2 x half> noundef %a, <2 x half> noundef %b, float %c) {
entry:
; CHECK: call float @dx.op.dot2AddHalf(i32 162, float %c, half %0, half %1, half %2, half %3)
%ret = call float @llvm.dx.dot2add(<2 x half> %a, <2 x half> %b, float %c)
ret float %ret
}