-
Notifications
You must be signed in to change notification settings - Fork 13.5k
Add normalize builtins and normalize HLSL function to DirectX and SPIR-V backend #102683
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-backend-spir-v @llvm/pr-subscribers-clang-codegen Author: Joshua Batista (bob80905) ChangesThis PR adds the normalize intrinsic and an HLSL function that uses it. Used #101256 as a reference, along with #102243 Patch is 25.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102683.diff 14 Files Affected:
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index b025a7681bfac3..0a874d8638df43 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4725,6 +4725,12 @@ def HLSLMad : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
+def HLSLNormalize : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_normalize"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_rcp"];
let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index d1af7fde157b64..58689842dbacad 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18586,6 +18586,29 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
CGM.getHLSLRuntime().getLengthIntrinsic(), ArrayRef<Value *>{X},
nullptr, "hlsl.length");
}
+ case Builtin::BI__builtin_hlsl_normalize: {
+ Value *X = EmitScalarExpr(E->getArg(0));
+
+ assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+ "normalize operand must have a float representation");
+
+ // scalar inputs should expect a scalar return type
+ if (!E->getArg(0)->getType()->isVectorType())
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/X->getType()->getScalarType(),
+ CGM.getHLSLRuntime().getNormalizeIntrinsic(), ArrayRef<Value *>{X},
+ nullptr, "hlsl.normalize");
+
+ // construct a vector return type for vector inputs
+ auto *XVecTy = E->getArg(0)->getType()->getAs<VectorType>();
+ llvm::Type *retType = X->getType()->getScalarType();
+ retType = llvm::VectorType::get(
+ retType, ElementCount::getFixed(XVecTy->getNumElements()));
+
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/retType, CGM.getHLSLRuntime().getNormalizeIntrinsic(),
+ ArrayRef<Value *>{X}, nullptr, "hlsl.normalize");
+ }
case Builtin::BI__builtin_hlsl_elementwise_frac: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
if (!E->getArg(0)->getType()->hasFloatingRepresentation())
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 527e73a0e21fc4..80ca432f4b509c 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -77,6 +77,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length)
GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(Normalize, normalize)
GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index e35a5262f92809..678cdc77f8a71b 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1352,6 +1352,38 @@ double3 min(double3, double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
double4 min(double4, double4);
+//===----------------------------------------------------------------------===//
+// normalize builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T normalize(T x)
+/// \brief Returns the normalized unit vector of the specified floating-point
+/// vector. \param x [in] The vector of floats.
+///
+/// Normalize is based on the following formula: x / length(x).
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half normalize(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half2 normalize(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half3 normalize(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half4 normalize(half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float normalize(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float2 normalize(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float3 normalize(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float4 normalize(float4);
+
//===----------------------------------------------------------------------===//
// pow builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index a9c0c57e88221d..61f68a415a7d6c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1108,6 +1108,18 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_normalize: {
+ if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+ return true;
+ if (SemaRef.checkArgCount(TheCall, 1))
+ return true;
+
+ ExprResult A = TheCall->getArg(0);
+ QualType ArgTyA = A.get()->getType();
+
+ TheCall->setType(ArgTyA);
+ break;
+ }
// Note these are llvm builtins that we want to catch invalid intrinsic
// generation. Normal handling of these builitns will occur elsewhere.
case Builtin::BI__builtin_elementwise_bitreverse: {
diff --git a/clang/test/CodeGenHLSL/builtins/normalize.hlsl b/clang/test/CodeGenHLSL/builtins/normalize.hlsl
new file mode 100644
index 00000000000000..f46a35866f45d9
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/normalize.hlsl
@@ -0,0 +1,73 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN: --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: call half @llvm.dx.normalize.f16(half
+// NO_HALF: call float @llvm.dx.normalize.f32(float
+// NATIVE_HALF: ret half
+// NO_HALF: ret float
+half test_normalize_half(half p0)
+{
+ return normalize(p0);
+}
+// NATIVE_HALF: define noundef <2 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <2 x half> @llvm.dx.normalize.v2f16
+// NO_HALF: %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(
+// NATIVE_HALF: ret <2 x half> %hlsl.normalize
+// NO_HALF: ret <2 x float> %hlsl.normalize
+half2 test_normalize_half2(half2 p0)
+{
+ return normalize(p0);
+}
+// NATIVE_HALF: define noundef <3 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <3 x half> @llvm.dx.normalize.v3f16
+// NO_HALF: %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(
+// NATIVE_HALF: ret <3 x half> %hlsl.normalize
+// NO_HALF: ret <3 x float> %hlsl.normalize
+half3 test_normalize_half3(half3 p0)
+{
+ return normalize(p0);
+}
+// NATIVE_HALF: define noundef <4 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <4 x half> @llvm.dx.normalize.v4f16
+// NO_HALF: %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(
+// NATIVE_HALF: ret <4 x half> %hlsl.normalize
+// NO_HALF: ret <4 x float> %hlsl.normalize
+half4 test_normalize_half4(half4 p0)
+{
+ return normalize(p0);
+}
+
+// CHECK: define noundef float @
+// CHECK: call float @llvm.dx.normalize.f32(float
+// CHECK: ret float
+float test_normalize_float(float p0)
+{
+ return normalize(p0);
+}
+// CHECK: define noundef <2 x float> @
+// CHECK: %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(
+// CHECK: ret <2 x float> %hlsl.normalize
+float2 test_normalize_float2(float2 p0)
+{
+ return normalize(p0);
+}
+// CHECK: define noundef <3 x float> @
+// CHECK: %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(
+// CHECK: ret <3 x float> %hlsl.normalize
+float3 test_normalize_float3(float3 p0)
+{
+ return normalize(p0);
+}
+// CHECK: define noundef <4 x float> @
+// CHECK: %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(
+// CHECK: ret <4 x float> %hlsl.normalize
+float4 test_length_float4(float4 p0)
+{
+ return normalize(p0);
+}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
new file mode 100644
index 00000000000000..b348297d37eb1c
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
@@ -0,0 +1,31 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+
+void test_too_few_arg()
+{
+ return __builtin_hlsl_normalize();
+ // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+void test_too_many_arg(float2 p0)
+{
+ return __builtin_hlsl_normalize(p0, p0);
+ // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+bool builtin_bool_to_float_type_promotion(bool p1)
+{
+ return __builtin_hlsl_normalize(p1);
+ // expected-error@-1 {passing 'bool' to parameter of incompatible type 'float'}}
+}
+
+bool builtin_normalize_int_to_float_promotion(int p1)
+{
+ return __builtin_hlsl_normalize(p1);
+ // expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
+}
+
+bool2 builtin_normalize_int2_to_float2_promotion(int2 p1)
+{
+ return __builtin_hlsl_normalize(p1);
+ // expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
+}
\ No newline at end of file
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 312c3862f240d8..904801e6e9e95f 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -58,6 +58,7 @@ def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType
def int_dx_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
+def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_dx_rcp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 3f77ef6bfcdbe2..1b5e463822749e 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -64,5 +64,6 @@ let TargetPrefix = "spv" in {
def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;
def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
+ def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
}
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index ac85859af8a53e..e80166e0ff0569 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -43,6 +43,7 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::dx_uclamp:
case Intrinsic::dx_lerp:
case Intrinsic::dx_length:
+ case Intrinsic::dx_normalize:
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return true;
@@ -229,6 +230,75 @@ static bool expandLog10Intrinsic(CallInst *Orig) {
return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
}
+static bool expandNormalizeIntrinsic(CallInst *Orig) {
+ Value *X = Orig->getOperand(0);
+ Type *Ty = Orig->getType();
+ Type *EltTy = Ty->getScalarType();
+ IRBuilder<> Builder(Orig->getParent());
+ Builder.SetInsertPoint(Orig);
+
+ auto *XVec = dyn_cast<FixedVectorType>(Ty);
+ if (!XVec) {
+ if (auto *constantFP = dyn_cast<ConstantFP>(X)) {
+ const APFloat &fpVal = constantFP->getValueAPF();
+ if (fpVal.isZero())
+ report_fatal_error(Twine("Invalid input scalar: length is zero"),
+ /* gen_crash_diag=*/false);
+ }
+ Value *Result = Builder.CreateFDiv(X, X);
+
+ Orig->replaceAllUsesWith(Result);
+ Orig->eraseFromParent();
+ return true;
+ }
+
+ Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
+ unsigned XVecSize = XVec->getNumElements();
+ Value *DotProduct = nullptr;
+ switch (XVecSize) {
+ case 1:
+ report_fatal_error(Twine("Invalid input vector: length is zero"),
+ /* gen_crash_diag=*/false);
+ break;
+ case 2:
+ DotProduct = Builder.CreateIntrinsic(
+ EltTy, Intrinsic::dx_dot2, ArrayRef<Value *>{X, X}, nullptr, "dx.dot2");
+ break;
+ case 3:
+ DotProduct = Builder.CreateIntrinsic(
+ EltTy, Intrinsic::dx_dot3, ArrayRef<Value *>{X, X}, nullptr, "dx.dot3");
+ break;
+ case 4:
+ DotProduct = Builder.CreateIntrinsic(
+ EltTy, Intrinsic::dx_dot4, ArrayRef<Value *>{X, X}, nullptr, "dx.dot4");
+ break;
+ default:
+ report_fatal_error(Twine("Invalid input vector: vector size is invalid."),
+ /* gen_crash_diag=*/false);
+ }
+
+ Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
+ ArrayRef<Value *>{DotProduct},
+ nullptr, "dx.rsqrt");
+
+ // verify that the length is non-zero
+ // (if the reciprocal sqrt of the length is non-zero, then the length is
+ // non-zero)
+ if (auto *constantFP = dyn_cast<ConstantFP>(Multiplicand)) {
+ const APFloat &fpVal = constantFP->getValueAPF();
+ if (fpVal.isZero())
+ report_fatal_error(Twine("Invalid input vector: length is zero"),
+ /* gen_crash_diag=*/false);
+ }
+
+ Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
+ Value *Result = Builder.CreateFMul(X, MultiplicandVec);
+
+ Orig->replaceAllUsesWith(Result);
+ Orig->eraseFromParent();
+ return true;
+}
+
static bool expandPowIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
@@ -314,6 +384,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
return expandLerpIntrinsic(Orig);
case Intrinsic::dx_length:
return expandLengthIntrinsic(Orig);
+ case Intrinsic::dx_normalize:
+ return expandNormalizeIntrinsic(Orig);
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return expandIntegerDot(Orig, F.getIntrinsicID());
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index ed786bd33aa05b..6e27b6c12f8335 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -238,6 +238,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectLog10(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectNormalize(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
@@ -1349,6 +1352,23 @@ bool SPIRVInstructionSelector::selectFrac(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}
+bool SPIRVInstructionSelector::selectNormalize(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+
+ assert(I.getNumOperands() == 3);
+ assert(I.getOperand(2).isReg());
+ MachineBasicBlock &BB = *I.getParent();
+
+ return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+ .addImm(GL::Normalize)
+ .addUse(I.getOperand(2).getReg())
+ .constrainAllUses(TII, TRI, RBI);
+}
+
bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -2080,6 +2100,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectFmix(ResVReg, ResType, I);
case Intrinsic::spv_frac:
return selectFrac(ResVReg, ResType, I);
+ case Intrinsic::spv_normalize:
+ return selectNormalize(ResVReg, ResType, I);
case Intrinsic::spv_rsqrt:
return selectRsqrt(ResVReg, ResType, I);
case Intrinsic::spv_lifetime_start:
diff --git a/llvm/test/CodeGen/DirectX/normalize.ll b/llvm/test/CodeGen/DirectX/normalize.ll
new file mode 100644
index 00000000000000..8b4a6692e8725f
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/normalize.ll
@@ -0,0 +1,118 @@
+; RUN: opt -S -dxil-intrinsic-expansion < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+
+; Make sure dxil operation function calls for normalize are generated for half/float.
+
+declare half @llvm.dx.normalize.f16(half)
+declare <2 x half> @llvm.dx.normalize.v2f16(<2 x half>)
+declare <3 x half> @llvm.dx.normalize.v3f16(<3 x half>)
+declare <4 x half> @llvm.dx.normalize.v4f16(<4 x half>)
+
+declare float @llvm.dx.normalize.f32(float)
+declare <2 x float> @llvm.dx.normalize.v2f32(<2 x float>)
+declare <3 x float> @llvm.dx.normalize.v3f32(<3 x float>)
+declare <4 x float> @llvm.dx.normalize.v4f32(<4 x float>)
+
+define noundef half @test_normalize_half(half noundef %p0) {
+entry:
+ ; CHECK: fdiv half %p0, %p0
+ %hlsl.normalize = call half @llvm.dx.normalize.f16(half %p0)
+ ret half %hlsl.normalize
+}
+
+define noundef <2 x half> @test_normalize_half2(<2 x half> noundef %p0) {
+entry:
+ ; CHECK: extractelement <2 x half> %{{.*}}, i64 0
+ ; EXPCHECK: call half @llvm.dx.dot2.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
+ ; DOPCHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+ ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+ ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+ ; CHECK: insertelement <2 x half> poison, half %{{.*}}, i64 0
+ ; CHECK: shufflevector <2 x half> %{{.*}}, <2 x half> poison, <2 x i32> zeroinitializer
+ ; CHECK: fmul <2 x half> %{{.*}}, %{{.*}}
+
+ %hlsl.normalize = call <2 x half> @llvm.dx.normalize.v2f16(<2 x half> %p0)
+ ret <2 x half> %hlsl.normalize
+}
+
+define noundef <3 x half> @test_normalize_half3(<3 x half> noundef %p0) {
+entry:
+ ; CHECK: extractelement <3 x half> %{{.*}}, i64 0
+ ; EXPCHECK: call half @llvm.dx.dot3.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}})
+ ; DOPCHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+ ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+ ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+ ; CHECK: insertelement <3 x half> poison, half %{{.*}}, i64 0
+ ; CHECK: shufflevector <3 x half> %{{.*}}, <3 x half> poison, <3 x i32> zeroinitializer
+ ; CHECK: fmul <3 x half> %{{.*}}, %{{.*}}
+
+ %hlsl.normalize = call <3 x half> @llvm.dx.normalize.v3f16(<3 x half> %p0)
+ ret <3 x half> %hlsl.normalize
+}
+
+define noundef <4 x half> @test_normalize_half4(<4 x half> noundef %p0) {
+entry:
+ ; CHECK: extractelement <4 x half> %{{.*}}, i64 0
+ ; EXPCHECK: call half @llvm.dx.dot4.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}})
+ ; DOPCHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+ ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+ ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+ ; CHECK: insertelement <4 x half> poison, half %{{.*}}, i64 0
+ ; CHECK:...
[truncated]
|
@llvm/pr-subscribers-backend-x86 Author: Joshua Batista (bob80905) ChangesThis PR adds the normalize intrinsic and an HLSL function that uses it. Used #101256 as a reference, along with #102243 Patch is 25.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102683.diff 14 Files Affected:
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index b025a7681bfac3..0a874d8638df43 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4725,6 +4725,12 @@ def HLSLMad : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
+def HLSLNormalize : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_normalize"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_rcp"];
let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index d1af7fde157b64..58689842dbacad 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18586,6 +18586,29 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
CGM.getHLSLRuntime().getLengthIntrinsic(), ArrayRef<Value *>{X},
nullptr, "hlsl.length");
}
+ case Builtin::BI__builtin_hlsl_normalize: {
+ Value *X = EmitScalarExpr(E->getArg(0));
+
+ assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+ "normalize operand must have a float representation");
+
+ // scalar inputs should expect a scalar return type
+ if (!E->getArg(0)->getType()->isVectorType())
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/X->getType()->getScalarType(),
+ CGM.getHLSLRuntime().getNormalizeIntrinsic(), ArrayRef<Value *>{X},
+ nullptr, "hlsl.normalize");
+
+ // construct a vector return type for vector inputs
+ auto *XVecTy = E->getArg(0)->getType()->getAs<VectorType>();
+ llvm::Type *retType = X->getType()->getScalarType();
+ retType = llvm::VectorType::get(
+ retType, ElementCount::getFixed(XVecTy->getNumElements()));
+
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/retType, CGM.getHLSLRuntime().getNormalizeIntrinsic(),
+ ArrayRef<Value *>{X}, nullptr, "hlsl.normalize");
+ }
case Builtin::BI__builtin_hlsl_elementwise_frac: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
if (!E->getArg(0)->getType()->hasFloatingRepresentation())
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 527e73a0e21fc4..80ca432f4b509c 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -77,6 +77,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length)
GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(Normalize, normalize)
GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index e35a5262f92809..678cdc77f8a71b 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1352,6 +1352,38 @@ double3 min(double3, double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
double4 min(double4, double4);
+//===----------------------------------------------------------------------===//
+// normalize builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T normalize(T x)
+/// \brief Returns the normalized unit vector of the specified floating-point
+/// vector. \param x [in] The vector of floats.
+///
+/// Normalize is based on the following formula: x / length(x).
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half normalize(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half2 normalize(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half3 normalize(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half4 normalize(half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float normalize(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float2 normalize(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float3 normalize(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float4 normalize(float4);
+
//===----------------------------------------------------------------------===//
// pow builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index a9c0c57e88221d..61f68a415a7d6c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1108,6 +1108,18 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_normalize: {
+ if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+ return true;
+ if (SemaRef.checkArgCount(TheCall, 1))
+ return true;
+
+ ExprResult A = TheCall->getArg(0);
+ QualType ArgTyA = A.get()->getType();
+
+ TheCall->setType(ArgTyA);
+ break;
+ }
// Note these are llvm builtins that we want to catch invalid intrinsic
// generation. Normal handling of these builitns will occur elsewhere.
case Builtin::BI__builtin_elementwise_bitreverse: {
diff --git a/clang/test/CodeGenHLSL/builtins/normalize.hlsl b/clang/test/CodeGenHLSL/builtins/normalize.hlsl
new file mode 100644
index 00000000000000..f46a35866f45d9
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/normalize.hlsl
@@ -0,0 +1,73 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN: --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: call half @llvm.dx.normalize.f16(half
+// NO_HALF: call float @llvm.dx.normalize.f32(float
+// NATIVE_HALF: ret half
+// NO_HALF: ret float
+half test_normalize_half(half p0)
+{
+ return normalize(p0);
+}
+// NATIVE_HALF: define noundef <2 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <2 x half> @llvm.dx.normalize.v2f16
+// NO_HALF: %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(
+// NATIVE_HALF: ret <2 x half> %hlsl.normalize
+// NO_HALF: ret <2 x float> %hlsl.normalize
+half2 test_normalize_half2(half2 p0)
+{
+ return normalize(p0);
+}
+// NATIVE_HALF: define noundef <3 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <3 x half> @llvm.dx.normalize.v3f16
+// NO_HALF: %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(
+// NATIVE_HALF: ret <3 x half> %hlsl.normalize
+// NO_HALF: ret <3 x float> %hlsl.normalize
+half3 test_normalize_half3(half3 p0)
+{
+ return normalize(p0);
+}
+// NATIVE_HALF: define noundef <4 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <4 x half> @llvm.dx.normalize.v4f16
+// NO_HALF: %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(
+// NATIVE_HALF: ret <4 x half> %hlsl.normalize
+// NO_HALF: ret <4 x float> %hlsl.normalize
+half4 test_normalize_half4(half4 p0)
+{
+ return normalize(p0);
+}
+
+// CHECK: define noundef float @
+// CHECK: call float @llvm.dx.normalize.f32(float
+// CHECK: ret float
+float test_normalize_float(float p0)
+{
+ return normalize(p0);
+}
+// CHECK: define noundef <2 x float> @
+// CHECK: %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(
+// CHECK: ret <2 x float> %hlsl.normalize
+float2 test_normalize_float2(float2 p0)
+{
+ return normalize(p0);
+}
+// CHECK: define noundef <3 x float> @
+// CHECK: %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(
+// CHECK: ret <3 x float> %hlsl.normalize
+float3 test_normalize_float3(float3 p0)
+{
+ return normalize(p0);
+}
+// CHECK: define noundef <4 x float> @
+// CHECK: %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(
+// CHECK: ret <4 x float> %hlsl.normalize
+float4 test_length_float4(float4 p0)
+{
+ return normalize(p0);
+}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
new file mode 100644
index 00000000000000..b348297d37eb1c
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
@@ -0,0 +1,31 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+
+void test_too_few_arg()
+{
+ return __builtin_hlsl_normalize();
+ // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+void test_too_many_arg(float2 p0)
+{
+ return __builtin_hlsl_normalize(p0, p0);
+ // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+bool builtin_bool_to_float_type_promotion(bool p1)
+{
+ return __builtin_hlsl_normalize(p1);
+ // expected-error@-1 {passing 'bool' to parameter of incompatible type 'float'}}
+}
+
+bool builtin_normalize_int_to_float_promotion(int p1)
+{
+ return __builtin_hlsl_normalize(p1);
+ // expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
+}
+
+bool2 builtin_normalize_int2_to_float2_promotion(int2 p1)
+{
+ return __builtin_hlsl_normalize(p1);
+ // expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
+}
\ No newline at end of file
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 312c3862f240d8..904801e6e9e95f 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -58,6 +58,7 @@ def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType
def int_dx_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
+def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_dx_rcp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 3f77ef6bfcdbe2..1b5e463822749e 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -64,5 +64,6 @@ let TargetPrefix = "spv" in {
def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;
def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
+ def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
}
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index ac85859af8a53e..e80166e0ff0569 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -43,6 +43,7 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::dx_uclamp:
case Intrinsic::dx_lerp:
case Intrinsic::dx_length:
+ case Intrinsic::dx_normalize:
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return true;
@@ -229,6 +230,75 @@ static bool expandLog10Intrinsic(CallInst *Orig) {
return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
}
+static bool expandNormalizeIntrinsic(CallInst *Orig) {
+ Value *X = Orig->getOperand(0);
+ Type *Ty = Orig->getType();
+ Type *EltTy = Ty->getScalarType();
+ IRBuilder<> Builder(Orig->getParent());
+ Builder.SetInsertPoint(Orig);
+
+ auto *XVec = dyn_cast<FixedVectorType>(Ty);
+ if (!XVec) {
+ if (auto *constantFP = dyn_cast<ConstantFP>(X)) {
+ const APFloat &fpVal = constantFP->getValueAPF();
+ if (fpVal.isZero())
+ report_fatal_error(Twine("Invalid input scalar: length is zero"),
+ /* gen_crash_diag=*/false);
+ }
+ Value *Result = Builder.CreateFDiv(X, X);
+
+ Orig->replaceAllUsesWith(Result);
+ Orig->eraseFromParent();
+ return true;
+ }
+
+ Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
+ unsigned XVecSize = XVec->getNumElements();
+ Value *DotProduct = nullptr;
+ switch (XVecSize) {
+ case 1:
+ report_fatal_error(Twine("Invalid input vector: length is zero"),
+ /* gen_crash_diag=*/false);
+ break;
+ case 2:
+ DotProduct = Builder.CreateIntrinsic(
+ EltTy, Intrinsic::dx_dot2, ArrayRef<Value *>{X, X}, nullptr, "dx.dot2");
+ break;
+ case 3:
+ DotProduct = Builder.CreateIntrinsic(
+ EltTy, Intrinsic::dx_dot3, ArrayRef<Value *>{X, X}, nullptr, "dx.dot3");
+ break;
+ case 4:
+ DotProduct = Builder.CreateIntrinsic(
+ EltTy, Intrinsic::dx_dot4, ArrayRef<Value *>{X, X}, nullptr, "dx.dot4");
+ break;
+ default:
+ report_fatal_error(Twine("Invalid input vector: vector size is invalid."),
+ /* gen_crash_diag=*/false);
+ }
+
+ Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
+ ArrayRef<Value *>{DotProduct},
+ nullptr, "dx.rsqrt");
+
+ // verify that the length is non-zero
+ // (if the reciprocal sqrt of the length is non-zero, then the length is
+ // non-zero)
+ if (auto *constantFP = dyn_cast<ConstantFP>(Multiplicand)) {
+ const APFloat &fpVal = constantFP->getValueAPF();
+ if (fpVal.isZero())
+ report_fatal_error(Twine("Invalid input vector: length is zero"),
+ /* gen_crash_diag=*/false);
+ }
+
+ Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
+ Value *Result = Builder.CreateFMul(X, MultiplicandVec);
+
+ Orig->replaceAllUsesWith(Result);
+ Orig->eraseFromParent();
+ return true;
+}
+
static bool expandPowIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
@@ -314,6 +384,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
return expandLerpIntrinsic(Orig);
case Intrinsic::dx_length:
return expandLengthIntrinsic(Orig);
+ case Intrinsic::dx_normalize:
+ return expandNormalizeIntrinsic(Orig);
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return expandIntegerDot(Orig, F.getIntrinsicID());
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index ed786bd33aa05b..6e27b6c12f8335 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -238,6 +238,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectLog10(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectNormalize(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
@@ -1349,6 +1352,23 @@ bool SPIRVInstructionSelector::selectFrac(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}
+bool SPIRVInstructionSelector::selectNormalize(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+
+ assert(I.getNumOperands() == 3);
+ assert(I.getOperand(2).isReg());
+ MachineBasicBlock &BB = *I.getParent();
+
+ return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+ .addImm(GL::Normalize)
+ .addUse(I.getOperand(2).getReg())
+ .constrainAllUses(TII, TRI, RBI);
+}
+
bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -2080,6 +2100,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectFmix(ResVReg, ResType, I);
case Intrinsic::spv_frac:
return selectFrac(ResVReg, ResType, I);
+ case Intrinsic::spv_normalize:
+ return selectNormalize(ResVReg, ResType, I);
case Intrinsic::spv_rsqrt:
return selectRsqrt(ResVReg, ResType, I);
case Intrinsic::spv_lifetime_start:
diff --git a/llvm/test/CodeGen/DirectX/normalize.ll b/llvm/test/CodeGen/DirectX/normalize.ll
new file mode 100644
index 00000000000000..8b4a6692e8725f
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/normalize.ll
@@ -0,0 +1,118 @@
+; RUN: opt -S -dxil-intrinsic-expansion < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+
+; Make sure dxil operation function calls for normalize are generated for half/float.
+
+declare half @llvm.dx.normalize.f16(half)
+declare <2 x half> @llvm.dx.normalize.v2f16(<2 x half>)
+declare <3 x half> @llvm.dx.normalize.v3f16(<3 x half>)
+declare <4 x half> @llvm.dx.normalize.v4f16(<4 x half>)
+
+declare float @llvm.dx.normalize.f32(float)
+declare <2 x float> @llvm.dx.normalize.v2f32(<2 x float>)
+declare <3 x float> @llvm.dx.normalize.v3f32(<3 x float>)
+declare <4 x float> @llvm.dx.normalize.v4f32(<4 x float>)
+
+define noundef half @test_normalize_half(half noundef %p0) {
+entry:
+ ; CHECK: fdiv half %p0, %p0
+ %hlsl.normalize = call half @llvm.dx.normalize.f16(half %p0)
+ ret half %hlsl.normalize
+}
+
+define noundef <2 x half> @test_normalize_half2(<2 x half> noundef %p0) {
+entry:
+ ; CHECK: extractelement <2 x half> %{{.*}}, i64 0
+ ; EXPCHECK: call half @llvm.dx.dot2.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
+ ; DOPCHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+ ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+ ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+ ; CHECK: insertelement <2 x half> poison, half %{{.*}}, i64 0
+ ; CHECK: shufflevector <2 x half> %{{.*}}, <2 x half> poison, <2 x i32> zeroinitializer
+ ; CHECK: fmul <2 x half> %{{.*}}, %{{.*}}
+
+ %hlsl.normalize = call <2 x half> @llvm.dx.normalize.v2f16(<2 x half> %p0)
+ ret <2 x half> %hlsl.normalize
+}
+
+define noundef <3 x half> @test_normalize_half3(<3 x half> noundef %p0) {
+entry:
+ ; CHECK: extractelement <3 x half> %{{.*}}, i64 0
+ ; EXPCHECK: call half @llvm.dx.dot3.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}})
+ ; DOPCHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+ ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+ ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+ ; CHECK: insertelement <3 x half> poison, half %{{.*}}, i64 0
+ ; CHECK: shufflevector <3 x half> %{{.*}}, <3 x half> poison, <3 x i32> zeroinitializer
+ ; CHECK: fmul <3 x half> %{{.*}}, %{{.*}}
+
+ %hlsl.normalize = call <3 x half> @llvm.dx.normalize.v3f16(<3 x half> %p0)
+ ret <3 x half> %hlsl.normalize
+}
+
+define noundef <4 x half> @test_normalize_half4(<4 x half> noundef %p0) {
+entry:
+ ; CHECK: extractelement <4 x half> %{{.*}}, i64 0
+ ; EXPCHECK: call half @llvm.dx.dot4.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}})
+ ; DOPCHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+ ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+ ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+ ; CHECK: insertelement <4 x half> poison, half %{{.*}}, i64 0
+ ; CHECK:...
[truncated]
|
@llvm/pr-subscribers-backend-directx Author: Joshua Batista (bob80905) ChangesThis PR adds the normalize intrinsic and an HLSL function that uses it. Used #101256 as a reference, along with #102243 Patch is 25.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102683.diff 14 Files Affected:
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index b025a7681bfac3..0a874d8638df43 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4725,6 +4725,12 @@ def HLSLMad : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
+def HLSLNormalize : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_normalize"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_rcp"];
let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index d1af7fde157b64..58689842dbacad 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18586,6 +18586,29 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
CGM.getHLSLRuntime().getLengthIntrinsic(), ArrayRef<Value *>{X},
nullptr, "hlsl.length");
}
+ case Builtin::BI__builtin_hlsl_normalize: {
+ Value *X = EmitScalarExpr(E->getArg(0));
+
+ assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+ "normalize operand must have a float representation");
+
+ // scalar inputs should expect a scalar return type
+ if (!E->getArg(0)->getType()->isVectorType())
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/X->getType()->getScalarType(),
+ CGM.getHLSLRuntime().getNormalizeIntrinsic(), ArrayRef<Value *>{X},
+ nullptr, "hlsl.normalize");
+
+ // construct a vector return type for vector inputs
+ auto *XVecTy = E->getArg(0)->getType()->getAs<VectorType>();
+ llvm::Type *retType = X->getType()->getScalarType();
+ retType = llvm::VectorType::get(
+ retType, ElementCount::getFixed(XVecTy->getNumElements()));
+
+ return Builder.CreateIntrinsic(
+ /*ReturnType=*/retType, CGM.getHLSLRuntime().getNormalizeIntrinsic(),
+ ArrayRef<Value *>{X}, nullptr, "hlsl.normalize");
+ }
case Builtin::BI__builtin_hlsl_elementwise_frac: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
if (!E->getArg(0)->getType()->hasFloatingRepresentation())
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 527e73a0e21fc4..80ca432f4b509c 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -77,6 +77,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length)
GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(Normalize, normalize)
GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index e35a5262f92809..678cdc77f8a71b 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1352,6 +1352,38 @@ double3 min(double3, double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
double4 min(double4, double4);
+//===----------------------------------------------------------------------===//
+// normalize builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T normalize(T x)
+/// \brief Returns the normalized unit vector of the specified floating-point
+/// vector. \param x [in] The vector of floats.
+///
+/// Normalize is based on the following formula: x / length(x).
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half normalize(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half2 normalize(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half3 normalize(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half4 normalize(half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float normalize(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float2 normalize(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float3 normalize(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float4 normalize(float4);
+
//===----------------------------------------------------------------------===//
// pow builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index a9c0c57e88221d..61f68a415a7d6c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1108,6 +1108,18 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_normalize: {
+ if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+ return true;
+ if (SemaRef.checkArgCount(TheCall, 1))
+ return true;
+
+ ExprResult A = TheCall->getArg(0);
+ QualType ArgTyA = A.get()->getType();
+
+ TheCall->setType(ArgTyA);
+ break;
+ }
// Note these are llvm builtins that we want to catch invalid intrinsic
// generation. Normal handling of these builitns will occur elsewhere.
case Builtin::BI__builtin_elementwise_bitreverse: {
diff --git a/clang/test/CodeGenHLSL/builtins/normalize.hlsl b/clang/test/CodeGenHLSL/builtins/normalize.hlsl
new file mode 100644
index 00000000000000..f46a35866f45d9
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/normalize.hlsl
@@ -0,0 +1,73 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN: --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: call half @llvm.dx.normalize.f16(half
+// NO_HALF: call float @llvm.dx.normalize.f32(float
+// NATIVE_HALF: ret half
+// NO_HALF: ret float
+half test_normalize_half(half p0)
+{
+ return normalize(p0);
+}
+// NATIVE_HALF: define noundef <2 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <2 x half> @llvm.dx.normalize.v2f16
+// NO_HALF: %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(
+// NATIVE_HALF: ret <2 x half> %hlsl.normalize
+// NO_HALF: ret <2 x float> %hlsl.normalize
+half2 test_normalize_half2(half2 p0)
+{
+ return normalize(p0);
+}
+// NATIVE_HALF: define noundef <3 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <3 x half> @llvm.dx.normalize.v3f16
+// NO_HALF: %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(
+// NATIVE_HALF: ret <3 x half> %hlsl.normalize
+// NO_HALF: ret <3 x float> %hlsl.normalize
+half3 test_normalize_half3(half3 p0)
+{
+ return normalize(p0);
+}
+// NATIVE_HALF: define noundef <4 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <4 x half> @llvm.dx.normalize.v4f16
+// NO_HALF: %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(
+// NATIVE_HALF: ret <4 x half> %hlsl.normalize
+// NO_HALF: ret <4 x float> %hlsl.normalize
+half4 test_normalize_half4(half4 p0)
+{
+ return normalize(p0);
+}
+
+// CHECK: define noundef float @
+// CHECK: call float @llvm.dx.normalize.f32(float
+// CHECK: ret float
+float test_normalize_float(float p0)
+{
+ return normalize(p0);
+}
+// CHECK: define noundef <2 x float> @
+// CHECK: %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(
+// CHECK: ret <2 x float> %hlsl.normalize
+float2 test_normalize_float2(float2 p0)
+{
+ return normalize(p0);
+}
+// CHECK: define noundef <3 x float> @
+// CHECK: %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(
+// CHECK: ret <3 x float> %hlsl.normalize
+float3 test_normalize_float3(float3 p0)
+{
+ return normalize(p0);
+}
+// CHECK: define noundef <4 x float> @
+// CHECK: %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(
+// CHECK: ret <4 x float> %hlsl.normalize
+float4 test_length_float4(float4 p0)
+{
+ return normalize(p0);
+}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
new file mode 100644
index 00000000000000..b348297d37eb1c
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
@@ -0,0 +1,31 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+
+void test_too_few_arg()
+{
+ return __builtin_hlsl_normalize();
+ // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+void test_too_many_arg(float2 p0)
+{
+ return __builtin_hlsl_normalize(p0, p0);
+ // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+bool builtin_bool_to_float_type_promotion(bool p1)
+{
+ return __builtin_hlsl_normalize(p1);
+ // expected-error@-1 {passing 'bool' to parameter of incompatible type 'float'}}
+}
+
+bool builtin_normalize_int_to_float_promotion(int p1)
+{
+ return __builtin_hlsl_normalize(p1);
+ // expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
+}
+
+bool2 builtin_normalize_int2_to_float2_promotion(int2 p1)
+{
+ return __builtin_hlsl_normalize(p1);
+ // expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
+}
\ No newline at end of file
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 312c3862f240d8..904801e6e9e95f 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -58,6 +58,7 @@ def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType
def int_dx_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
+def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_dx_rcp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 3f77ef6bfcdbe2..1b5e463822749e 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -64,5 +64,6 @@ let TargetPrefix = "spv" in {
def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;
def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
+ def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
}
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index ac85859af8a53e..e80166e0ff0569 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -43,6 +43,7 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::dx_uclamp:
case Intrinsic::dx_lerp:
case Intrinsic::dx_length:
+ case Intrinsic::dx_normalize:
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return true;
@@ -229,6 +230,75 @@ static bool expandLog10Intrinsic(CallInst *Orig) {
return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
}
+static bool expandNormalizeIntrinsic(CallInst *Orig) {
+ Value *X = Orig->getOperand(0);
+ Type *Ty = Orig->getType();
+ Type *EltTy = Ty->getScalarType();
+ IRBuilder<> Builder(Orig->getParent());
+ Builder.SetInsertPoint(Orig);
+
+ auto *XVec = dyn_cast<FixedVectorType>(Ty);
+ if (!XVec) {
+ if (auto *constantFP = dyn_cast<ConstantFP>(X)) {
+ const APFloat &fpVal = constantFP->getValueAPF();
+ if (fpVal.isZero())
+ report_fatal_error(Twine("Invalid input scalar: length is zero"),
+ /* gen_crash_diag=*/false);
+ }
+ Value *Result = Builder.CreateFDiv(X, X);
+
+ Orig->replaceAllUsesWith(Result);
+ Orig->eraseFromParent();
+ return true;
+ }
+
+ Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
+ unsigned XVecSize = XVec->getNumElements();
+ Value *DotProduct = nullptr;
+ switch (XVecSize) {
+ case 1:
+ report_fatal_error(Twine("Invalid input vector: length is zero"),
+ /* gen_crash_diag=*/false);
+ break;
+ case 2:
+ DotProduct = Builder.CreateIntrinsic(
+ EltTy, Intrinsic::dx_dot2, ArrayRef<Value *>{X, X}, nullptr, "dx.dot2");
+ break;
+ case 3:
+ DotProduct = Builder.CreateIntrinsic(
+ EltTy, Intrinsic::dx_dot3, ArrayRef<Value *>{X, X}, nullptr, "dx.dot3");
+ break;
+ case 4:
+ DotProduct = Builder.CreateIntrinsic(
+ EltTy, Intrinsic::dx_dot4, ArrayRef<Value *>{X, X}, nullptr, "dx.dot4");
+ break;
+ default:
+ report_fatal_error(Twine("Invalid input vector: vector size is invalid."),
+ /* gen_crash_diag=*/false);
+ }
+
+ Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
+ ArrayRef<Value *>{DotProduct},
+ nullptr, "dx.rsqrt");
+
+ // verify that the length is non-zero
+ // (if the reciprocal sqrt of the length is non-zero, then the length is
+ // non-zero)
+ if (auto *constantFP = dyn_cast<ConstantFP>(Multiplicand)) {
+ const APFloat &fpVal = constantFP->getValueAPF();
+ if (fpVal.isZero())
+ report_fatal_error(Twine("Invalid input vector: length is zero"),
+ /* gen_crash_diag=*/false);
+ }
+
+ Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
+ Value *Result = Builder.CreateFMul(X, MultiplicandVec);
+
+ Orig->replaceAllUsesWith(Result);
+ Orig->eraseFromParent();
+ return true;
+}
+
static bool expandPowIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
@@ -314,6 +384,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
return expandLerpIntrinsic(Orig);
case Intrinsic::dx_length:
return expandLengthIntrinsic(Orig);
+ case Intrinsic::dx_normalize:
+ return expandNormalizeIntrinsic(Orig);
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return expandIntegerDot(Orig, F.getIntrinsicID());
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index ed786bd33aa05b..6e27b6c12f8335 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -238,6 +238,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectLog10(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectNormalize(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
@@ -1349,6 +1352,23 @@ bool SPIRVInstructionSelector::selectFrac(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}
+bool SPIRVInstructionSelector::selectNormalize(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+
+ assert(I.getNumOperands() == 3);
+ assert(I.getOperand(2).isReg());
+ MachineBasicBlock &BB = *I.getParent();
+
+ return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+ .addImm(GL::Normalize)
+ .addUse(I.getOperand(2).getReg())
+ .constrainAllUses(TII, TRI, RBI);
+}
+
bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -2080,6 +2100,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectFmix(ResVReg, ResType, I);
case Intrinsic::spv_frac:
return selectFrac(ResVReg, ResType, I);
+ case Intrinsic::spv_normalize:
+ return selectNormalize(ResVReg, ResType, I);
case Intrinsic::spv_rsqrt:
return selectRsqrt(ResVReg, ResType, I);
case Intrinsic::spv_lifetime_start:
diff --git a/llvm/test/CodeGen/DirectX/normalize.ll b/llvm/test/CodeGen/DirectX/normalize.ll
new file mode 100644
index 00000000000000..8b4a6692e8725f
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/normalize.ll
@@ -0,0 +1,118 @@
+; RUN: opt -S -dxil-intrinsic-expansion < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+
+; Make sure dxil operation function calls for normalize are generated for half/float.
+
+declare half @llvm.dx.normalize.f16(half)
+declare <2 x half> @llvm.dx.normalize.v2f16(<2 x half>)
+declare <3 x half> @llvm.dx.normalize.v3f16(<3 x half>)
+declare <4 x half> @llvm.dx.normalize.v4f16(<4 x half>)
+
+declare float @llvm.dx.normalize.f32(float)
+declare <2 x float> @llvm.dx.normalize.v2f32(<2 x float>)
+declare <3 x float> @llvm.dx.normalize.v3f32(<3 x float>)
+declare <4 x float> @llvm.dx.normalize.v4f32(<4 x float>)
+
+define noundef half @test_normalize_half(half noundef %p0) {
+entry:
+ ; CHECK: fdiv half %p0, %p0
+ %hlsl.normalize = call half @llvm.dx.normalize.f16(half %p0)
+ ret half %hlsl.normalize
+}
+
+define noundef <2 x half> @test_normalize_half2(<2 x half> noundef %p0) {
+entry:
+ ; CHECK: extractelement <2 x half> %{{.*}}, i64 0
+ ; EXPCHECK: call half @llvm.dx.dot2.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
+ ; DOPCHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+ ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+ ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+ ; CHECK: insertelement <2 x half> poison, half %{{.*}}, i64 0
+ ; CHECK: shufflevector <2 x half> %{{.*}}, <2 x half> poison, <2 x i32> zeroinitializer
+ ; CHECK: fmul <2 x half> %{{.*}}, %{{.*}}
+
+ %hlsl.normalize = call <2 x half> @llvm.dx.normalize.v2f16(<2 x half> %p0)
+ ret <2 x half> %hlsl.normalize
+}
+
+define noundef <3 x half> @test_normalize_half3(<3 x half> noundef %p0) {
+entry:
+ ; CHECK: extractelement <3 x half> %{{.*}}, i64 0
+ ; EXPCHECK: call half @llvm.dx.dot3.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}})
+ ; DOPCHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+ ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+ ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+ ; CHECK: insertelement <3 x half> poison, half %{{.*}}, i64 0
+ ; CHECK: shufflevector <3 x half> %{{.*}}, <3 x half> poison, <3 x i32> zeroinitializer
+ ; CHECK: fmul <3 x half> %{{.*}}, %{{.*}}
+
+ %hlsl.normalize = call <3 x half> @llvm.dx.normalize.v3f16(<3 x half> %p0)
+ ret <3 x half> %hlsl.normalize
+}
+
+define noundef <4 x half> @test_normalize_half4(<4 x half> noundef %p0) {
+entry:
+ ; CHECK: extractelement <4 x half> %{{.*}}, i64 0
+ ; EXPCHECK: call half @llvm.dx.dot4.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}})
+ ; DOPCHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+ ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+ ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+ ; CHECK: insertelement <4 x half> poison, half %{{.*}}, i64 0
+ ; CHECK:...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
1261609
to
13102f6
Compare
|
||
Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0); | ||
unsigned XVecSize = XVec->getNumElements(); | ||
Value *DotProduct = nullptr; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For lines 256 to 279 is something that will have to be cleaned up into a helper function. I think we are going to have some code duplication here @pow2clk is moving getDotProductIntrinsic
here as an expandFdot()
ideally we would have one function for this. If we move forward with your PR before Greg's someone is going to need to clean this up.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Understood, I'll make the change if Greg's comes in first.
this is looking pretty good will do a second pass tomorrow. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM on the SPIR-V side, thanks for this addition!
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/2337 Here is the relevant piece of the build log for the reference:
|
This PR adds the normalize intrinsic and an HLSL function that uses it.
The SPIRV backend is also implemented.
Used #101256 as a reference, along with #102243
Fixes #99139