Skip to content

Add normalize builtins and normalize HLSL function to DirectX and SPIR-V backend #102683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 13, 2024

Conversation

bob80905
Copy link
Contributor

@bob80905 bob80905 commented Aug 9, 2024

This PR adds the normalize intrinsic and an HLSL function that uses it.
The SPIRV backend is also implemented.

Used #101256 as a reference, along with #102243
Fixes #99139

@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang:codegen IR generation bugs: mangling, exceptions, etc. backend:DirectX HLSL HLSL Language Support backend:SPIR-V llvm:ir labels Aug 9, 2024
@llvmbot
Copy link
Member

llvmbot commented Aug 9, 2024

@llvm/pr-subscribers-backend-spir-v
@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-clang

@llvm/pr-subscribers-clang-codegen

Author: Joshua Batista (bob80905)

Changes

This PR adds the normalize intrinsic and an HLSL function that uses it.
The SPIRV backend is also implemented.

Used #101256 as a reference, along with #102243
Fixes #99139


Patch is 25.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102683.diff

14 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+23)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+32)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+12)
  • (added) clang/test/CodeGenHLSL/builtins/normalize.hlsl (+73)
  • (added) clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl (+31)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+72)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+22)
  • (added) llvm/test/CodeGen/DirectX/normalize.ll (+118)
  • (added) llvm/test/CodeGen/DirectX/normalize_error.ll (+10)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/normalize.ll (+29)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index b025a7681bfac3..0a874d8638df43 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4725,6 +4725,12 @@ def HLSLMad : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLNormalize : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_normalize"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_rcp"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index d1af7fde157b64..58689842dbacad 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18586,6 +18586,29 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         CGM.getHLSLRuntime().getLengthIntrinsic(), ArrayRef<Value *>{X},
         nullptr, "hlsl.length");
   }
+  case Builtin::BI__builtin_hlsl_normalize: {
+    Value *X = EmitScalarExpr(E->getArg(0));
+
+    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+           "normalize operand must have a float representation");
+
+    // scalar inputs should expect a scalar return type
+    if (!E->getArg(0)->getType()->isVectorType())
+      return Builder.CreateIntrinsic(
+          /*ReturnType=*/X->getType()->getScalarType(),
+          CGM.getHLSLRuntime().getNormalizeIntrinsic(), ArrayRef<Value *>{X},
+          nullptr, "hlsl.normalize");
+
+    // construct a vector return type for vector inputs
+    auto *XVecTy = E->getArg(0)->getType()->getAs<VectorType>();
+    llvm::Type *retType = X->getType()->getScalarType();
+    retType = llvm::VectorType::get(
+        retType, ElementCount::getFixed(XVecTy->getNumElements()));
+
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/retType, CGM.getHLSLRuntime().getNormalizeIntrinsic(),
+        ArrayRef<Value *>{X}, nullptr, "hlsl.normalize");
+  }
   case Builtin::BI__builtin_hlsl_elementwise_frac: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
     if (!E->getArg(0)->getType()->hasFloatingRepresentation())
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 527e73a0e21fc4..80ca432f4b509c 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -77,6 +77,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(Normalize, normalize)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
   GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
 
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index e35a5262f92809..678cdc77f8a71b 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1352,6 +1352,38 @@ double3 min(double3, double3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 double4 min(double4, double4);
 
+//===----------------------------------------------------------------------===//
+// normalize builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T normalize(T x)
+/// \brief Returns the normalized unit vector of the specified floating-point
+/// vector. \param x [in] The vector of floats.
+///
+/// Normalize is based on the following formula: x / length(x).
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half normalize(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half2 normalize(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half3 normalize(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half4 normalize(half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float normalize(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float2 normalize(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float3 normalize(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float4 normalize(float4);
+
 //===----------------------------------------------------------------------===//
 // pow builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index a9c0c57e88221d..61f68a415a7d6c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1108,6 +1108,18 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_normalize: {
+    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+      return true;
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+
+    ExprResult A = TheCall->getArg(0);
+    QualType ArgTyA = A.get()->getType();
+
+    TheCall->setType(ArgTyA);
+    break;
+  }
   // Note these are llvm builtins that we want to catch invalid intrinsic
   // generation. Normal handling of these builitns will occur elsewhere.
   case Builtin::BI__builtin_elementwise_bitreverse: {
diff --git a/clang/test/CodeGenHLSL/builtins/normalize.hlsl b/clang/test/CodeGenHLSL/builtins/normalize.hlsl
new file mode 100644
index 00000000000000..f46a35866f45d9
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/normalize.hlsl
@@ -0,0 +1,73 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s \ 
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: call half @llvm.dx.normalize.f16(half
+// NO_HALF: call float @llvm.dx.normalize.f32(float
+// NATIVE_HALF: ret half
+// NO_HALF: ret float
+half test_normalize_half(half p0)
+{
+	return normalize(p0);
+}
+// NATIVE_HALF: define noundef <2 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <2 x half> @llvm.dx.normalize.v2f16
+// NO_HALF: %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(
+// NATIVE_HALF: ret <2 x half> %hlsl.normalize
+// NO_HALF: ret <2 x float> %hlsl.normalize
+half2 test_normalize_half2(half2 p0)
+{
+	return normalize(p0);
+}
+// NATIVE_HALF: define noundef <3 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <3 x half> @llvm.dx.normalize.v3f16
+// NO_HALF: %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(
+// NATIVE_HALF: ret <3 x half> %hlsl.normalize
+// NO_HALF: ret <3 x float> %hlsl.normalize
+half3 test_normalize_half3(half3 p0)
+{
+	return normalize(p0);
+}
+// NATIVE_HALF: define noundef <4 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <4 x half> @llvm.dx.normalize.v4f16
+// NO_HALF: %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(
+// NATIVE_HALF: ret <4 x half> %hlsl.normalize
+// NO_HALF: ret <4 x float> %hlsl.normalize
+half4 test_normalize_half4(half4 p0)
+{
+	return normalize(p0);
+}
+
+// CHECK: define noundef float @
+// CHECK: call float @llvm.dx.normalize.f32(float
+// CHECK: ret float
+float test_normalize_float(float p0)
+{
+	return normalize(p0);
+}
+// CHECK: define noundef <2 x float> @
+// CHECK: %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(
+// CHECK: ret <2 x float> %hlsl.normalize
+float2 test_normalize_float2(float2 p0)
+{
+	return normalize(p0);
+}
+// CHECK: define noundef <3 x float> @
+// CHECK: %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(
+// CHECK: ret <3 x float> %hlsl.normalize
+float3 test_normalize_float3(float3 p0)
+{
+	return normalize(p0);
+}
+// CHECK: define noundef <4 x float> @
+// CHECK: %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(
+// CHECK: ret <4 x float> %hlsl.normalize
+float4 test_length_float4(float4 p0)
+{
+	return normalize(p0);
+}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
new file mode 100644
index 00000000000000..b348297d37eb1c
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
@@ -0,0 +1,31 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+
+void test_too_few_arg()
+{
+  return __builtin_hlsl_normalize();
+  // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+void test_too_many_arg(float2 p0)
+{
+  return __builtin_hlsl_normalize(p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+bool builtin_bool_to_float_type_promotion(bool p1)
+{
+  return __builtin_hlsl_normalize(p1);
+  // expected-error@-1 {passing 'bool' to parameter of incompatible type 'float'}}
+}
+
+bool builtin_normalize_int_to_float_promotion(int p1)
+{
+  return __builtin_hlsl_normalize(p1);
+  // expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
+}
+
+bool2 builtin_normalize_int2_to_float2_promotion(int2 p1)
+{
+  return __builtin_hlsl_normalize(p1);
+  // expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
+}
\ No newline at end of file
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 312c3862f240d8..904801e6e9e95f 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -58,6 +58,7 @@ def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType
 def int_dx_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
 def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
+def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
 def int_dx_rcp  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
 def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
 }
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 3f77ef6bfcdbe2..1b5e463822749e 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -64,5 +64,6 @@ let TargetPrefix = "spv" in {
   def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>], 
     [IntrNoMem, IntrWillReturn] >;
   def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
+  def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
   def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
 }
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index ac85859af8a53e..e80166e0ff0569 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -43,6 +43,7 @@ static bool isIntrinsicExpansion(Function &F) {
   case Intrinsic::dx_uclamp:
   case Intrinsic::dx_lerp:
   case Intrinsic::dx_length:
+  case Intrinsic::dx_normalize:
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
     return true;
@@ -229,6 +230,75 @@ static bool expandLog10Intrinsic(CallInst *Orig) {
   return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
 }
 
+static bool expandNormalizeIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  Type *Ty = Orig->getType();
+  Type *EltTy = Ty->getScalarType();
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+
+  auto *XVec = dyn_cast<FixedVectorType>(Ty);
+  if (!XVec) {
+    if (auto *constantFP = dyn_cast<ConstantFP>(X)) {
+      const APFloat &fpVal = constantFP->getValueAPF();
+      if (fpVal.isZero())
+        report_fatal_error(Twine("Invalid input scalar: length is zero"),
+                           /* gen_crash_diag=*/false);
+    }
+    Value *Result = Builder.CreateFDiv(X, X);
+
+    Orig->replaceAllUsesWith(Result);
+    Orig->eraseFromParent();
+    return true;
+  }
+
+  Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
+  unsigned XVecSize = XVec->getNumElements();
+  Value *DotProduct = nullptr;
+  switch (XVecSize) {
+  case 1:
+    report_fatal_error(Twine("Invalid input vector: length is zero"),
+                       /* gen_crash_diag=*/false);
+    break;
+  case 2:
+    DotProduct = Builder.CreateIntrinsic(
+        EltTy, Intrinsic::dx_dot2, ArrayRef<Value *>{X, X}, nullptr, "dx.dot2");
+    break;
+  case 3:
+    DotProduct = Builder.CreateIntrinsic(
+        EltTy, Intrinsic::dx_dot3, ArrayRef<Value *>{X, X}, nullptr, "dx.dot3");
+    break;
+  case 4:
+    DotProduct = Builder.CreateIntrinsic(
+        EltTy, Intrinsic::dx_dot4, ArrayRef<Value *>{X, X}, nullptr, "dx.dot4");
+    break;
+  default:
+    report_fatal_error(Twine("Invalid input vector: vector size is invalid."),
+                       /* gen_crash_diag=*/false);
+  }
+
+  Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
+                                                ArrayRef<Value *>{DotProduct},
+                                                nullptr, "dx.rsqrt");
+
+  // verify that the length is non-zero
+  // (if the reciprocal sqrt of the length is non-zero, then the length is
+  // non-zero)
+  if (auto *constantFP = dyn_cast<ConstantFP>(Multiplicand)) {
+    const APFloat &fpVal = constantFP->getValueAPF();
+    if (fpVal.isZero())
+      report_fatal_error(Twine("Invalid input vector: length is zero"),
+                         /* gen_crash_diag=*/false);
+  }
+
+  Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
+  Value *Result = Builder.CreateFMul(X, MultiplicandVec);
+
+  Orig->replaceAllUsesWith(Result);
+  Orig->eraseFromParent();
+  return true;
+}
+
 static bool expandPowIntrinsic(CallInst *Orig) {
 
   Value *X = Orig->getOperand(0);
@@ -314,6 +384,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
     return expandLerpIntrinsic(Orig);
   case Intrinsic::dx_length:
     return expandLengthIntrinsic(Orig);
+  case Intrinsic::dx_normalize:
+    return expandNormalizeIntrinsic(Orig);
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
     return expandIntegerDot(Orig, F.getIntrinsicID());
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index ed786bd33aa05b..6e27b6c12f8335 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -238,6 +238,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectLog10(Register ResVReg, const SPIRVType *ResType,
                    MachineInstr &I) const;
 
+  bool selectNormalize(Register ResVReg, const SPIRVType *ResType,
+                   MachineInstr &I) const;
+
   bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
                          MachineInstr &I) const;
 
@@ -1349,6 +1352,23 @@ bool SPIRVInstructionSelector::selectFrac(Register ResVReg,
       .constrainAllUses(TII, TRI, RBI);
 }
 
+bool SPIRVInstructionSelector::selectNormalize(Register ResVReg,
+                                               const SPIRVType *ResType,
+                                               MachineInstr &I) const {
+
+  assert(I.getNumOperands() == 3);
+  assert(I.getOperand(2).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+      .addImm(GL::Normalize)
+      .addUse(I.getOperand(2).getReg())
+      .constrainAllUses(TII, TRI, RBI);
+}
+
 bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
                                            const SPIRVType *ResType,
                                            MachineInstr &I) const {
@@ -2080,6 +2100,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectFmix(ResVReg, ResType, I);
   case Intrinsic::spv_frac:
     return selectFrac(ResVReg, ResType, I);
+  case Intrinsic::spv_normalize:
+    return selectNormalize(ResVReg, ResType, I);
   case Intrinsic::spv_rsqrt:
     return selectRsqrt(ResVReg, ResType, I);
   case Intrinsic::spv_lifetime_start:
diff --git a/llvm/test/CodeGen/DirectX/normalize.ll b/llvm/test/CodeGen/DirectX/normalize.ll
new file mode 100644
index 00000000000000..8b4a6692e8725f
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/normalize.ll
@@ -0,0 +1,118 @@
+; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
+; RUN: opt -S  -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+
+; Make sure dxil operation function calls for normalize are generated for half/float.
+
+declare half @llvm.dx.normalize.f16(half)
+declare <2 x half> @llvm.dx.normalize.v2f16(<2 x half>)
+declare <3 x half> @llvm.dx.normalize.v3f16(<3 x half>)
+declare <4 x half> @llvm.dx.normalize.v4f16(<4 x half>)
+
+declare float @llvm.dx.normalize.f32(float)
+declare <2 x float> @llvm.dx.normalize.v2f32(<2 x float>)
+declare <3 x float> @llvm.dx.normalize.v3f32(<3 x float>)
+declare <4 x float> @llvm.dx.normalize.v4f32(<4 x float>)
+
+define noundef half @test_normalize_half(half noundef %p0) {
+entry:
+  ; CHECK: fdiv half %p0, %p0
+  %hlsl.normalize = call half @llvm.dx.normalize.f16(half %p0)
+  ret half %hlsl.normalize
+}
+
+define noundef <2 x half> @test_normalize_half2(<2 x half> noundef %p0) {
+entry:
+  ; CHECK: extractelement <2 x half> %{{.*}}, i64 0
+  ; EXPCHECK: call half @llvm.dx.dot2.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
+  ; DOPCHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+  ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+  ; CHECK: insertelement <2 x half> poison, half %{{.*}}, i64 0
+  ; CHECK: shufflevector <2 x half> %{{.*}}, <2 x half> poison, <2 x i32> zeroinitializer
+  ; CHECK: fmul <2 x half> %{{.*}}, %{{.*}}  
+
+  %hlsl.normalize = call <2 x half> @llvm.dx.normalize.v2f16(<2 x half> %p0)
+  ret <2 x half> %hlsl.normalize
+}
+
+define noundef <3 x half> @test_normalize_half3(<3 x half> noundef %p0) {
+entry:
+  ; CHECK: extractelement <3 x half> %{{.*}}, i64 0
+  ; EXPCHECK: call half @llvm.dx.dot3.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}})
+  ; DOPCHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+  ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+  ; CHECK: insertelement <3 x half> poison, half %{{.*}}, i64 0
+  ; CHECK: shufflevector <3 x half> %{{.*}}, <3 x half> poison, <3 x i32> zeroinitializer
+  ; CHECK: fmul <3 x half> %{{.*}}, %{{.*}}
+
+  %hlsl.normalize = call <3 x half> @llvm.dx.normalize.v3f16(<3 x half> %p0)
+  ret <3 x half> %hlsl.normalize
+}
+
+define noundef <4 x half> @test_normalize_half4(<4 x half> noundef %p0) {
+entry:
+  ; CHECK: extractelement <4 x half> %{{.*}}, i64 0
+  ; EXPCHECK: call half @llvm.dx.dot4.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}})
+  ; DOPCHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+  ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+  ; CHECK: insertelement <4 x half> poison, half %{{.*}}, i64 0
+  ; CHECK:...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 9, 2024

@llvm/pr-subscribers-backend-x86

Author: Joshua Batista (bob80905)

Changes

This PR adds the normalize intrinsic and an HLSL function that uses it.
The SPIRV backend is also implemented.

Used #101256 as a reference, along with #102243
Fixes #99139


Patch is 25.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102683.diff

14 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+23)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+32)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+12)
  • (added) clang/test/CodeGenHLSL/builtins/normalize.hlsl (+73)
  • (added) clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl (+31)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+72)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+22)
  • (added) llvm/test/CodeGen/DirectX/normalize.ll (+118)
  • (added) llvm/test/CodeGen/DirectX/normalize_error.ll (+10)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/normalize.ll (+29)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index b025a7681bfac3..0a874d8638df43 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4725,6 +4725,12 @@ def HLSLMad : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLNormalize : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_normalize"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_rcp"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index d1af7fde157b64..58689842dbacad 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18586,6 +18586,29 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         CGM.getHLSLRuntime().getLengthIntrinsic(), ArrayRef<Value *>{X},
         nullptr, "hlsl.length");
   }
+  case Builtin::BI__builtin_hlsl_normalize: {
+    Value *X = EmitScalarExpr(E->getArg(0));
+
+    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+           "normalize operand must have a float representation");
+
+    // scalar inputs should expect a scalar return type
+    if (!E->getArg(0)->getType()->isVectorType())
+      return Builder.CreateIntrinsic(
+          /*ReturnType=*/X->getType()->getScalarType(),
+          CGM.getHLSLRuntime().getNormalizeIntrinsic(), ArrayRef<Value *>{X},
+          nullptr, "hlsl.normalize");
+
+    // construct a vector return type for vector inputs
+    auto *XVecTy = E->getArg(0)->getType()->getAs<VectorType>();
+    llvm::Type *retType = X->getType()->getScalarType();
+    retType = llvm::VectorType::get(
+        retType, ElementCount::getFixed(XVecTy->getNumElements()));
+
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/retType, CGM.getHLSLRuntime().getNormalizeIntrinsic(),
+        ArrayRef<Value *>{X}, nullptr, "hlsl.normalize");
+  }
   case Builtin::BI__builtin_hlsl_elementwise_frac: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
     if (!E->getArg(0)->getType()->hasFloatingRepresentation())
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 527e73a0e21fc4..80ca432f4b509c 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -77,6 +77,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(Normalize, normalize)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
   GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
 
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index e35a5262f92809..678cdc77f8a71b 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1352,6 +1352,38 @@ double3 min(double3, double3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 double4 min(double4, double4);
 
+//===----------------------------------------------------------------------===//
+// normalize builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T normalize(T x)
+/// \brief Returns the normalized unit vector of the specified floating-point
+/// vector. \param x [in] The vector of floats.
+///
+/// Normalize is based on the following formula: x / length(x).
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half normalize(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half2 normalize(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half3 normalize(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half4 normalize(half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float normalize(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float2 normalize(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float3 normalize(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float4 normalize(float4);
+
 //===----------------------------------------------------------------------===//
 // pow builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index a9c0c57e88221d..61f68a415a7d6c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1108,6 +1108,18 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_normalize: {
+    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+      return true;
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+
+    ExprResult A = TheCall->getArg(0);
+    QualType ArgTyA = A.get()->getType();
+
+    TheCall->setType(ArgTyA);
+    break;
+  }
   // Note these are llvm builtins that we want to catch invalid intrinsic
   // generation. Normal handling of these builitns will occur elsewhere.
   case Builtin::BI__builtin_elementwise_bitreverse: {
diff --git a/clang/test/CodeGenHLSL/builtins/normalize.hlsl b/clang/test/CodeGenHLSL/builtins/normalize.hlsl
new file mode 100644
index 00000000000000..f46a35866f45d9
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/normalize.hlsl
@@ -0,0 +1,73 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s \ 
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: call half @llvm.dx.normalize.f16(half
+// NO_HALF: call float @llvm.dx.normalize.f32(float
+// NATIVE_HALF: ret half
+// NO_HALF: ret float
+half test_normalize_half(half p0)
+{
+	return normalize(p0);
+}
+// NATIVE_HALF: define noundef <2 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <2 x half> @llvm.dx.normalize.v2f16
+// NO_HALF: %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(
+// NATIVE_HALF: ret <2 x half> %hlsl.normalize
+// NO_HALF: ret <2 x float> %hlsl.normalize
+half2 test_normalize_half2(half2 p0)
+{
+	return normalize(p0);
+}
+// NATIVE_HALF: define noundef <3 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <3 x half> @llvm.dx.normalize.v3f16
+// NO_HALF: %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(
+// NATIVE_HALF: ret <3 x half> %hlsl.normalize
+// NO_HALF: ret <3 x float> %hlsl.normalize
+half3 test_normalize_half3(half3 p0)
+{
+	return normalize(p0);
+}
+// NATIVE_HALF: define noundef <4 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <4 x half> @llvm.dx.normalize.v4f16
+// NO_HALF: %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(
+// NATIVE_HALF: ret <4 x half> %hlsl.normalize
+// NO_HALF: ret <4 x float> %hlsl.normalize
+half4 test_normalize_half4(half4 p0)
+{
+	return normalize(p0);
+}
+
+// CHECK: define noundef float @
+// CHECK: call float @llvm.dx.normalize.f32(float
+// CHECK: ret float
+float test_normalize_float(float p0)
+{
+	return normalize(p0);
+}
+// CHECK: define noundef <2 x float> @
+// CHECK: %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(
+// CHECK: ret <2 x float> %hlsl.normalize
+float2 test_normalize_float2(float2 p0)
+{
+	return normalize(p0);
+}
+// CHECK: define noundef <3 x float> @
+// CHECK: %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(
+// CHECK: ret <3 x float> %hlsl.normalize
+float3 test_normalize_float3(float3 p0)
+{
+	return normalize(p0);
+}
+// CHECK: define noundef <4 x float> @
+// CHECK: %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(
+// CHECK: ret <4 x float> %hlsl.normalize
+float4 test_length_float4(float4 p0)
+{
+	return normalize(p0);
+}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
new file mode 100644
index 00000000000000..b348297d37eb1c
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
@@ -0,0 +1,31 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+
+void test_too_few_arg()
+{
+  return __builtin_hlsl_normalize();
+  // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+void test_too_many_arg(float2 p0)
+{
+  return __builtin_hlsl_normalize(p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+bool builtin_bool_to_float_type_promotion(bool p1)
+{
+  return __builtin_hlsl_normalize(p1);
+  // expected-error@-1 {passing 'bool' to parameter of incompatible type 'float'}}
+}
+
+bool builtin_normalize_int_to_float_promotion(int p1)
+{
+  return __builtin_hlsl_normalize(p1);
+  // expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
+}
+
+bool2 builtin_normalize_int2_to_float2_promotion(int2 p1)
+{
+  return __builtin_hlsl_normalize(p1);
+  // expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
+}
\ No newline at end of file
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 312c3862f240d8..904801e6e9e95f 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -58,6 +58,7 @@ def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType
 def int_dx_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
 def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
+def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
 def int_dx_rcp  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
 def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
 }
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 3f77ef6bfcdbe2..1b5e463822749e 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -64,5 +64,6 @@ let TargetPrefix = "spv" in {
   def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>], 
     [IntrNoMem, IntrWillReturn] >;
   def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
+  def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
   def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
 }
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index ac85859af8a53e..e80166e0ff0569 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -43,6 +43,7 @@ static bool isIntrinsicExpansion(Function &F) {
   case Intrinsic::dx_uclamp:
   case Intrinsic::dx_lerp:
   case Intrinsic::dx_length:
+  case Intrinsic::dx_normalize:
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
     return true;
@@ -229,6 +230,75 @@ static bool expandLog10Intrinsic(CallInst *Orig) {
   return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
 }
 
+static bool expandNormalizeIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  Type *Ty = Orig->getType();
+  Type *EltTy = Ty->getScalarType();
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+
+  auto *XVec = dyn_cast<FixedVectorType>(Ty);
+  if (!XVec) {
+    if (auto *constantFP = dyn_cast<ConstantFP>(X)) {
+      const APFloat &fpVal = constantFP->getValueAPF();
+      if (fpVal.isZero())
+        report_fatal_error(Twine("Invalid input scalar: length is zero"),
+                           /* gen_crash_diag=*/false);
+    }
+    Value *Result = Builder.CreateFDiv(X, X);
+
+    Orig->replaceAllUsesWith(Result);
+    Orig->eraseFromParent();
+    return true;
+  }
+
+  Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
+  unsigned XVecSize = XVec->getNumElements();
+  Value *DotProduct = nullptr;
+  switch (XVecSize) {
+  case 1:
+    report_fatal_error(Twine("Invalid input vector: length is zero"),
+                       /* gen_crash_diag=*/false);
+    break;
+  case 2:
+    DotProduct = Builder.CreateIntrinsic(
+        EltTy, Intrinsic::dx_dot2, ArrayRef<Value *>{X, X}, nullptr, "dx.dot2");
+    break;
+  case 3:
+    DotProduct = Builder.CreateIntrinsic(
+        EltTy, Intrinsic::dx_dot3, ArrayRef<Value *>{X, X}, nullptr, "dx.dot3");
+    break;
+  case 4:
+    DotProduct = Builder.CreateIntrinsic(
+        EltTy, Intrinsic::dx_dot4, ArrayRef<Value *>{X, X}, nullptr, "dx.dot4");
+    break;
+  default:
+    report_fatal_error(Twine("Invalid input vector: vector size is invalid."),
+                       /* gen_crash_diag=*/false);
+  }
+
+  Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
+                                                ArrayRef<Value *>{DotProduct},
+                                                nullptr, "dx.rsqrt");
+
+  // verify that the length is non-zero
+  // (if the reciprocal sqrt of the length is non-zero, then the length is
+  // non-zero)
+  if (auto *constantFP = dyn_cast<ConstantFP>(Multiplicand)) {
+    const APFloat &fpVal = constantFP->getValueAPF();
+    if (fpVal.isZero())
+      report_fatal_error(Twine("Invalid input vector: length is zero"),
+                         /* gen_crash_diag=*/false);
+  }
+
+  Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
+  Value *Result = Builder.CreateFMul(X, MultiplicandVec);
+
+  Orig->replaceAllUsesWith(Result);
+  Orig->eraseFromParent();
+  return true;
+}
+
 static bool expandPowIntrinsic(CallInst *Orig) {
 
   Value *X = Orig->getOperand(0);
@@ -314,6 +384,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
     return expandLerpIntrinsic(Orig);
   case Intrinsic::dx_length:
     return expandLengthIntrinsic(Orig);
+  case Intrinsic::dx_normalize:
+    return expandNormalizeIntrinsic(Orig);
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
     return expandIntegerDot(Orig, F.getIntrinsicID());
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index ed786bd33aa05b..6e27b6c12f8335 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -238,6 +238,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectLog10(Register ResVReg, const SPIRVType *ResType,
                    MachineInstr &I) const;
 
+  bool selectNormalize(Register ResVReg, const SPIRVType *ResType,
+                   MachineInstr &I) const;
+
   bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
                          MachineInstr &I) const;
 
@@ -1349,6 +1352,23 @@ bool SPIRVInstructionSelector::selectFrac(Register ResVReg,
       .constrainAllUses(TII, TRI, RBI);
 }
 
+bool SPIRVInstructionSelector::selectNormalize(Register ResVReg,
+                                               const SPIRVType *ResType,
+                                               MachineInstr &I) const {
+
+  assert(I.getNumOperands() == 3);
+  assert(I.getOperand(2).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+      .addImm(GL::Normalize)
+      .addUse(I.getOperand(2).getReg())
+      .constrainAllUses(TII, TRI, RBI);
+}
+
 bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
                                            const SPIRVType *ResType,
                                            MachineInstr &I) const {
@@ -2080,6 +2100,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectFmix(ResVReg, ResType, I);
   case Intrinsic::spv_frac:
     return selectFrac(ResVReg, ResType, I);
+  case Intrinsic::spv_normalize:
+    return selectNormalize(ResVReg, ResType, I);
   case Intrinsic::spv_rsqrt:
     return selectRsqrt(ResVReg, ResType, I);
   case Intrinsic::spv_lifetime_start:
diff --git a/llvm/test/CodeGen/DirectX/normalize.ll b/llvm/test/CodeGen/DirectX/normalize.ll
new file mode 100644
index 00000000000000..8b4a6692e8725f
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/normalize.ll
@@ -0,0 +1,118 @@
+; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
+; RUN: opt -S  -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+
+; Make sure dxil operation function calls for normalize are generated for half/float.
+
+declare half @llvm.dx.normalize.f16(half)
+declare <2 x half> @llvm.dx.normalize.v2f16(<2 x half>)
+declare <3 x half> @llvm.dx.normalize.v3f16(<3 x half>)
+declare <4 x half> @llvm.dx.normalize.v4f16(<4 x half>)
+
+declare float @llvm.dx.normalize.f32(float)
+declare <2 x float> @llvm.dx.normalize.v2f32(<2 x float>)
+declare <3 x float> @llvm.dx.normalize.v3f32(<3 x float>)
+declare <4 x float> @llvm.dx.normalize.v4f32(<4 x float>)
+
+define noundef half @test_normalize_half(half noundef %p0) {
+entry:
+  ; CHECK: fdiv half %p0, %p0
+  %hlsl.normalize = call half @llvm.dx.normalize.f16(half %p0)
+  ret half %hlsl.normalize
+}
+
+define noundef <2 x half> @test_normalize_half2(<2 x half> noundef %p0) {
+entry:
+  ; CHECK: extractelement <2 x half> %{{.*}}, i64 0
+  ; EXPCHECK: call half @llvm.dx.dot2.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
+  ; DOPCHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+  ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+  ; CHECK: insertelement <2 x half> poison, half %{{.*}}, i64 0
+  ; CHECK: shufflevector <2 x half> %{{.*}}, <2 x half> poison, <2 x i32> zeroinitializer
+  ; CHECK: fmul <2 x half> %{{.*}}, %{{.*}}  
+
+  %hlsl.normalize = call <2 x half> @llvm.dx.normalize.v2f16(<2 x half> %p0)
+  ret <2 x half> %hlsl.normalize
+}
+
+define noundef <3 x half> @test_normalize_half3(<3 x half> noundef %p0) {
+entry:
+  ; CHECK: extractelement <3 x half> %{{.*}}, i64 0
+  ; EXPCHECK: call half @llvm.dx.dot3.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}})
+  ; DOPCHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+  ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+  ; CHECK: insertelement <3 x half> poison, half %{{.*}}, i64 0
+  ; CHECK: shufflevector <3 x half> %{{.*}}, <3 x half> poison, <3 x i32> zeroinitializer
+  ; CHECK: fmul <3 x half> %{{.*}}, %{{.*}}
+
+  %hlsl.normalize = call <3 x half> @llvm.dx.normalize.v3f16(<3 x half> %p0)
+  ret <3 x half> %hlsl.normalize
+}
+
+define noundef <4 x half> @test_normalize_half4(<4 x half> noundef %p0) {
+entry:
+  ; CHECK: extractelement <4 x half> %{{.*}}, i64 0
+  ; EXPCHECK: call half @llvm.dx.dot4.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}})
+  ; DOPCHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+  ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+  ; CHECK: insertelement <4 x half> poison, half %{{.*}}, i64 0
+  ; CHECK:...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 9, 2024

@llvm/pr-subscribers-backend-directx

Author: Joshua Batista (bob80905)

Changes

This PR adds the normalize intrinsic and an HLSL function that uses it.
The SPIRV backend is also implemented.

Used #101256 as a reference, along with #102243
Fixes #99139


Patch is 25.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102683.diff

14 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+23)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+32)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+12)
  • (added) clang/test/CodeGenHLSL/builtins/normalize.hlsl (+73)
  • (added) clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl (+31)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+72)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+22)
  • (added) llvm/test/CodeGen/DirectX/normalize.ll (+118)
  • (added) llvm/test/CodeGen/DirectX/normalize_error.ll (+10)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/normalize.ll (+29)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index b025a7681bfac3..0a874d8638df43 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4725,6 +4725,12 @@ def HLSLMad : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLNormalize : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_normalize"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_rcp"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index d1af7fde157b64..58689842dbacad 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18586,6 +18586,29 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         CGM.getHLSLRuntime().getLengthIntrinsic(), ArrayRef<Value *>{X},
         nullptr, "hlsl.length");
   }
+  case Builtin::BI__builtin_hlsl_normalize: {
+    Value *X = EmitScalarExpr(E->getArg(0));
+
+    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+           "normalize operand must have a float representation");
+
+    // scalar inputs should expect a scalar return type
+    if (!E->getArg(0)->getType()->isVectorType())
+      return Builder.CreateIntrinsic(
+          /*ReturnType=*/X->getType()->getScalarType(),
+          CGM.getHLSLRuntime().getNormalizeIntrinsic(), ArrayRef<Value *>{X},
+          nullptr, "hlsl.normalize");
+
+    // construct a vector return type for vector inputs
+    auto *XVecTy = E->getArg(0)->getType()->getAs<VectorType>();
+    llvm::Type *retType = X->getType()->getScalarType();
+    retType = llvm::VectorType::get(
+        retType, ElementCount::getFixed(XVecTy->getNumElements()));
+
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/retType, CGM.getHLSLRuntime().getNormalizeIntrinsic(),
+        ArrayRef<Value *>{X}, nullptr, "hlsl.normalize");
+  }
   case Builtin::BI__builtin_hlsl_elementwise_frac: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));
     if (!E->getArg(0)->getType()->hasFloatingRepresentation())
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 527e73a0e21fc4..80ca432f4b509c 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -77,6 +77,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(Normalize, normalize)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
   GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
 
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index e35a5262f92809..678cdc77f8a71b 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1352,6 +1352,38 @@ double3 min(double3, double3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_min)
 double4 min(double4, double4);
 
+//===----------------------------------------------------------------------===//
+// normalize builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T normalize(T x)
+/// \brief Returns the normalized unit vector of the specified floating-point
+/// vector. \param x [in] The vector of floats.
+///
+/// Normalize is based on the following formula: x / length(x).
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half normalize(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half2 normalize(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half3 normalize(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+half4 normalize(half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float normalize(float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float2 normalize(float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float3 normalize(float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_normalize)
+float4 normalize(float4);
+
 //===----------------------------------------------------------------------===//
 // pow builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index a9c0c57e88221d..61f68a415a7d6c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1108,6 +1108,18 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_normalize: {
+    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+      return true;
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+
+    ExprResult A = TheCall->getArg(0);
+    QualType ArgTyA = A.get()->getType();
+
+    TheCall->setType(ArgTyA);
+    break;
+  }
   // Note these are llvm builtins that we want to catch invalid intrinsic
   // generation. Normal handling of these builitns will occur elsewhere.
   case Builtin::BI__builtin_elementwise_bitreverse: {
diff --git a/clang/test/CodeGenHLSL/builtins/normalize.hlsl b/clang/test/CodeGenHLSL/builtins/normalize.hlsl
new file mode 100644
index 00000000000000..f46a35866f45d9
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/normalize.hlsl
@@ -0,0 +1,73 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s \ 
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+// NATIVE_HALF: define noundef half @
+// NATIVE_HALF: call half @llvm.dx.normalize.f16(half
+// NO_HALF: call float @llvm.dx.normalize.f32(float
+// NATIVE_HALF: ret half
+// NO_HALF: ret float
+half test_normalize_half(half p0)
+{
+	return normalize(p0);
+}
+// NATIVE_HALF: define noundef <2 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <2 x half> @llvm.dx.normalize.v2f16
+// NO_HALF: %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(
+// NATIVE_HALF: ret <2 x half> %hlsl.normalize
+// NO_HALF: ret <2 x float> %hlsl.normalize
+half2 test_normalize_half2(half2 p0)
+{
+	return normalize(p0);
+}
+// NATIVE_HALF: define noundef <3 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <3 x half> @llvm.dx.normalize.v3f16
+// NO_HALF: %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(
+// NATIVE_HALF: ret <3 x half> %hlsl.normalize
+// NO_HALF: ret <3 x float> %hlsl.normalize
+half3 test_normalize_half3(half3 p0)
+{
+	return normalize(p0);
+}
+// NATIVE_HALF: define noundef <4 x half> @
+// NATIVE_HALF: %hlsl.normalize = call <4 x half> @llvm.dx.normalize.v4f16
+// NO_HALF: %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(
+// NATIVE_HALF: ret <4 x half> %hlsl.normalize
+// NO_HALF: ret <4 x float> %hlsl.normalize
+half4 test_normalize_half4(half4 p0)
+{
+	return normalize(p0);
+}
+
+// CHECK: define noundef float @
+// CHECK: call float @llvm.dx.normalize.f32(float
+// CHECK: ret float
+float test_normalize_float(float p0)
+{
+	return normalize(p0);
+}
+// CHECK: define noundef <2 x float> @
+// CHECK: %hlsl.normalize = call <2 x float> @llvm.dx.normalize.v2f32(
+// CHECK: ret <2 x float> %hlsl.normalize
+float2 test_normalize_float2(float2 p0)
+{
+	return normalize(p0);
+}
+// CHECK: define noundef <3 x float> @
+// CHECK: %hlsl.normalize = call <3 x float> @llvm.dx.normalize.v3f32(
+// CHECK: ret <3 x float> %hlsl.normalize
+float3 test_normalize_float3(float3 p0)
+{
+	return normalize(p0);
+}
+// CHECK: define noundef <4 x float> @
+// CHECK: %hlsl.normalize = call <4 x float> @llvm.dx.normalize.v4f32(
+// CHECK: ret <4 x float> %hlsl.normalize
+float4 test_length_float4(float4 p0)
+{
+	return normalize(p0);
+}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
new file mode 100644
index 00000000000000..b348297d37eb1c
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/normalize-errors.hlsl
@@ -0,0 +1,31 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected
+
+void test_too_few_arg()
+{
+  return __builtin_hlsl_normalize();
+  // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+void test_too_many_arg(float2 p0)
+{
+  return __builtin_hlsl_normalize(p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+bool builtin_bool_to_float_type_promotion(bool p1)
+{
+  return __builtin_hlsl_normalize(p1);
+  // expected-error@-1 {passing 'bool' to parameter of incompatible type 'float'}}
+}
+
+bool builtin_normalize_int_to_float_promotion(int p1)
+{
+  return __builtin_hlsl_normalize(p1);
+  // expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
+}
+
+bool2 builtin_normalize_int2_to_float2_promotion(int2 p1)
+{
+  return __builtin_hlsl_normalize(p1);
+  // expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
+}
\ No newline at end of file
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 312c3862f240d8..904801e6e9e95f 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -58,6 +58,7 @@ def int_dx_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType
 def int_dx_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
 def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
+def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
 def int_dx_rcp  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
 def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
 }
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 3f77ef6bfcdbe2..1b5e463822749e 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -64,5 +64,6 @@ let TargetPrefix = "spv" in {
   def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>], 
     [IntrNoMem, IntrWillReturn] >;
   def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
+  def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
   def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
 }
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index ac85859af8a53e..e80166e0ff0569 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -43,6 +43,7 @@ static bool isIntrinsicExpansion(Function &F) {
   case Intrinsic::dx_uclamp:
   case Intrinsic::dx_lerp:
   case Intrinsic::dx_length:
+  case Intrinsic::dx_normalize:
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
     return true;
@@ -229,6 +230,75 @@ static bool expandLog10Intrinsic(CallInst *Orig) {
   return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
 }
 
+static bool expandNormalizeIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  Type *Ty = Orig->getType();
+  Type *EltTy = Ty->getScalarType();
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+
+  auto *XVec = dyn_cast<FixedVectorType>(Ty);
+  if (!XVec) {
+    if (auto *constantFP = dyn_cast<ConstantFP>(X)) {
+      const APFloat &fpVal = constantFP->getValueAPF();
+      if (fpVal.isZero())
+        report_fatal_error(Twine("Invalid input scalar: length is zero"),
+                           /* gen_crash_diag=*/false);
+    }
+    Value *Result = Builder.CreateFDiv(X, X);
+
+    Orig->replaceAllUsesWith(Result);
+    Orig->eraseFromParent();
+    return true;
+  }
+
+  Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
+  unsigned XVecSize = XVec->getNumElements();
+  Value *DotProduct = nullptr;
+  switch (XVecSize) {
+  case 1:
+    report_fatal_error(Twine("Invalid input vector: length is zero"),
+                       /* gen_crash_diag=*/false);
+    break;
+  case 2:
+    DotProduct = Builder.CreateIntrinsic(
+        EltTy, Intrinsic::dx_dot2, ArrayRef<Value *>{X, X}, nullptr, "dx.dot2");
+    break;
+  case 3:
+    DotProduct = Builder.CreateIntrinsic(
+        EltTy, Intrinsic::dx_dot3, ArrayRef<Value *>{X, X}, nullptr, "dx.dot3");
+    break;
+  case 4:
+    DotProduct = Builder.CreateIntrinsic(
+        EltTy, Intrinsic::dx_dot4, ArrayRef<Value *>{X, X}, nullptr, "dx.dot4");
+    break;
+  default:
+    report_fatal_error(Twine("Invalid input vector: vector size is invalid."),
+                       /* gen_crash_diag=*/false);
+  }
+
+  Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
+                                                ArrayRef<Value *>{DotProduct},
+                                                nullptr, "dx.rsqrt");
+
+  // verify that the length is non-zero
+  // (if the reciprocal sqrt of the length is non-zero, then the length is
+  // non-zero)
+  if (auto *constantFP = dyn_cast<ConstantFP>(Multiplicand)) {
+    const APFloat &fpVal = constantFP->getValueAPF();
+    if (fpVal.isZero())
+      report_fatal_error(Twine("Invalid input vector: length is zero"),
+                         /* gen_crash_diag=*/false);
+  }
+
+  Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
+  Value *Result = Builder.CreateFMul(X, MultiplicandVec);
+
+  Orig->replaceAllUsesWith(Result);
+  Orig->eraseFromParent();
+  return true;
+}
+
 static bool expandPowIntrinsic(CallInst *Orig) {
 
   Value *X = Orig->getOperand(0);
@@ -314,6 +384,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
     return expandLerpIntrinsic(Orig);
   case Intrinsic::dx_length:
     return expandLengthIntrinsic(Orig);
+  case Intrinsic::dx_normalize:
+    return expandNormalizeIntrinsic(Orig);
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
     return expandIntegerDot(Orig, F.getIntrinsicID());
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index ed786bd33aa05b..6e27b6c12f8335 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -238,6 +238,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectLog10(Register ResVReg, const SPIRVType *ResType,
                    MachineInstr &I) const;
 
+  bool selectNormalize(Register ResVReg, const SPIRVType *ResType,
+                   MachineInstr &I) const;
+
   bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
                          MachineInstr &I) const;
 
@@ -1349,6 +1352,23 @@ bool SPIRVInstructionSelector::selectFrac(Register ResVReg,
       .constrainAllUses(TII, TRI, RBI);
 }
 
+bool SPIRVInstructionSelector::selectNormalize(Register ResVReg,
+                                               const SPIRVType *ResType,
+                                               MachineInstr &I) const {
+
+  assert(I.getNumOperands() == 3);
+  assert(I.getOperand(2).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+      .addImm(GL::Normalize)
+      .addUse(I.getOperand(2).getReg())
+      .constrainAllUses(TII, TRI, RBI);
+}
+
 bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
                                            const SPIRVType *ResType,
                                            MachineInstr &I) const {
@@ -2080,6 +2100,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectFmix(ResVReg, ResType, I);
   case Intrinsic::spv_frac:
     return selectFrac(ResVReg, ResType, I);
+  case Intrinsic::spv_normalize:
+    return selectNormalize(ResVReg, ResType, I);
   case Intrinsic::spv_rsqrt:
     return selectRsqrt(ResVReg, ResType, I);
   case Intrinsic::spv_lifetime_start:
diff --git a/llvm/test/CodeGen/DirectX/normalize.ll b/llvm/test/CodeGen/DirectX/normalize.ll
new file mode 100644
index 00000000000000..8b4a6692e8725f
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/normalize.ll
@@ -0,0 +1,118 @@
+; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
+; RUN: opt -S  -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+
+; Make sure dxil operation function calls for normalize are generated for half/float.
+
+declare half @llvm.dx.normalize.f16(half)
+declare <2 x half> @llvm.dx.normalize.v2f16(<2 x half>)
+declare <3 x half> @llvm.dx.normalize.v3f16(<3 x half>)
+declare <4 x half> @llvm.dx.normalize.v4f16(<4 x half>)
+
+declare float @llvm.dx.normalize.f32(float)
+declare <2 x float> @llvm.dx.normalize.v2f32(<2 x float>)
+declare <3 x float> @llvm.dx.normalize.v3f32(<3 x float>)
+declare <4 x float> @llvm.dx.normalize.v4f32(<4 x float>)
+
+define noundef half @test_normalize_half(half noundef %p0) {
+entry:
+  ; CHECK: fdiv half %p0, %p0
+  %hlsl.normalize = call half @llvm.dx.normalize.f16(half %p0)
+  ret half %hlsl.normalize
+}
+
+define noundef <2 x half> @test_normalize_half2(<2 x half> noundef %p0) {
+entry:
+  ; CHECK: extractelement <2 x half> %{{.*}}, i64 0
+  ; EXPCHECK: call half @llvm.dx.dot2.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
+  ; DOPCHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+  ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+  ; CHECK: insertelement <2 x half> poison, half %{{.*}}, i64 0
+  ; CHECK: shufflevector <2 x half> %{{.*}}, <2 x half> poison, <2 x i32> zeroinitializer
+  ; CHECK: fmul <2 x half> %{{.*}}, %{{.*}}  
+
+  %hlsl.normalize = call <2 x half> @llvm.dx.normalize.v2f16(<2 x half> %p0)
+  ret <2 x half> %hlsl.normalize
+}
+
+define noundef <3 x half> @test_normalize_half3(<3 x half> noundef %p0) {
+entry:
+  ; CHECK: extractelement <3 x half> %{{.*}}, i64 0
+  ; EXPCHECK: call half @llvm.dx.dot3.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}})
+  ; DOPCHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+  ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+  ; CHECK: insertelement <3 x half> poison, half %{{.*}}, i64 0
+  ; CHECK: shufflevector <3 x half> %{{.*}}, <3 x half> poison, <3 x i32> zeroinitializer
+  ; CHECK: fmul <3 x half> %{{.*}}, %{{.*}}
+
+  %hlsl.normalize = call <3 x half> @llvm.dx.normalize.v3f16(<3 x half> %p0)
+  ret <3 x half> %hlsl.normalize
+}
+
+define noundef <4 x half> @test_normalize_half4(<4 x half> noundef %p0) {
+entry:
+  ; CHECK: extractelement <4 x half> %{{.*}}, i64 0
+  ; EXPCHECK: call half @llvm.dx.dot4.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}})
+  ; DOPCHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
+  ; EXPCHECK: call half @llvm.dx.rsqrt.f16(half %{{.*}})
+  ; DOPCHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
+  ; CHECK: insertelement <4 x half> poison, half %{{.*}}, i64 0
+  ; CHECK:...
[truncated]

Copy link

github-actions bot commented Aug 9, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.


Value *Elt = Builder.CreateExtractElement(X, (uint64_t)0);
unsigned XVecSize = XVec->getNumElements();
Value *DotProduct = nullptr;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For lines 256 to 279 is something that will have to be cleaned up into a helper function. I think we are going to have some code duplication here @pow2clk is moving getDotProductIntrinsic here as an expandFdot() ideally we would have one function for this. If we move forward with your PR before Greg's someone is going to need to clean this up.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Understood, I'll make the change if Greg's comes in first.

@farzonl
Copy link
Member

farzonl commented Aug 12, 2024

this is looking pretty good will do a second pass tomorrow.

Copy link
Contributor

@Keenuts Keenuts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM on the SPIR-V side, thanks for this addition!

@bob80905 bob80905 merged commit 1b2d11d into llvm:main Aug 13, 2024
9 checks passed
@llvm-ci
Copy link
Collaborator

llvm-ci commented Aug 14, 2024

LLVM Buildbot has detected a new failure on builder mlir-nvidia-gcc7 running on mlir-nvidia while building clang,llvm at step 5 "build-check-mlir-build-only".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/2337

Here is the relevant piece of the build log for the reference:

Step 5 (build-check-mlir-build-only) failure: build (failure)
...
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/examples/transform/Ch4/lib/MyExtension.cpp:66:31: warning: suggest parentheses around ‘&&’ within ‘||’ [-Wparentheses]
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/examples/transform/Ch4/lib/MyExtension.cpp:66:31: warning: suggest parentheses around ‘&&’ within ‘||’ [-Wparentheses]
671.336 [23/16/4474] Linking CXX static library lib/libMyExtensionCh4.a
682.402 [22/16/4475] Building CXX object tools/mlir/test/lib/Dialect/Test/CMakeFiles/MLIRTestFromLLVMIRTranslation.dir/TestFromLLVMIRTranslation.cpp.o
682.481 [21/16/4476] Linking CXX static library lib/libMLIRTestFromLLVMIRTranslation.a
683.891 [20/16/4477] Building CXX object tools/mlir/test/lib/Dialect/Test/CMakeFiles/MLIRTestToLLVMIRTranslation.dir/TestToLLVMIRTranslation.cpp.o
684.013 [19/16/4478] Linking CXX static library lib/libMLIRTestToLLVMIRTranslation.a
685.653 [18/16/4479] Linking CXX executable bin/mlir-transform-opt
685.727 [17/16/4480] Linking CXX executable bin/mlir-translate
698.042 [16/16/4481] Building CXX object tools/llc/CMakeFiles/llc.dir/llc.cpp.o
command timed out: 1200 seconds without output running [b'ninja', b'-j', b'16', b'check-mlir-build-only'], attempting to kill
process killed by signal 9
program finished with exit code -1
elapsedTime=4901.860113

bob80905 added a commit that referenced this pull request Sep 12, 2024
#106471)

This PR adds the step intrinsic and an HLSL function that uses it.
The SPIRV backend is also implemented.

Used #102683 as a reference.
Fixes #99157
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:DirectX backend:SPIR-V backend:X86 clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:ir
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

Implement the normalize HLSL Function
7 participants