diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp index 2c11373504e8c..aaf994b23cf3c 100644 --- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp +++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp @@ -78,13 +78,13 @@ bool DXContainerGlobals::runOnModule(Module &M) { } GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) { - const uint64_t FeatureFlags = - static_cast(getAnalysis() - .getShaderFlags() - .getFeatureFlags()); + uint64_t CombinedFeatureFlags = getAnalysis() + .getShaderFlags() + .getCombinedFlags() + .getFeatureFlags(); Constant *FeatureFlagsConstant = - ConstantInt::get(M.getContext(), APInt(64, FeatureFlags)); + ConstantInt::get(M.getContext(), APInt(64, CombinedFeatureFlags)); return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0"); } diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp index 9fa137b4c025e..d6917dce98abd 100644 --- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp +++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp @@ -13,36 +13,54 @@ #include "DXILShaderFlags.h" #include "DirectX.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Module.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" using namespace llvm; using namespace llvm::dxil; -static void updateFlags(ComputedShaderFlags &Flags, const Instruction &I) { - Type *Ty = I.getType(); - if (Ty->isDoubleTy()) { - Flags.Doubles = true; +static void updateFunctionFlags(ComputedShaderFlags &CSF, + const Instruction &I) { + if (!CSF.Doubles) + CSF.Doubles = I.getType()->isDoubleTy(); + + if (!CSF.Doubles) { + for (Value *Op : I.operands()) + CSF.Doubles |= Op->getType()->isDoubleTy(); + } + if (CSF.Doubles) { switch (I.getOpcode()) { case Instruction::FDiv: case Instruction::UIToFP: case Instruction::SIToFP: case Instruction::FPToUI: case Instruction::FPToSI: - Flags.DX11_1_DoubleExtensions 
= true; + // TODO: To be set if I is a call to DXIL intrinsic DXIL::Opcode::Fma + // https://github.com/llvm/llvm-project/issues/114554 + CSF.DX11_1_DoubleExtensions = true; break; } } } -ComputedShaderFlags ComputedShaderFlags::computeFlags(Module &M) { - ComputedShaderFlags Flags; - for (const auto &F : M) +void ModuleShaderFlags::initialize(const Module &M) { + // Collect shader flags for each of the functions + for (const auto &F : M.getFunctionList()) { + if (F.isDeclaration()) + continue; + ComputedShaderFlags CSF; for (const auto &BB : F) for (const auto &I : BB) - updateFlags(Flags, I); - return Flags; + updateFunctionFlags(CSF, I); + // Insert shader flag mask for function F + FunctionFlags.push_back({&F, CSF}); + // Update combined shader flags mask + CombinedSFMask.merge(CSF); + } + llvm::sort(FunctionFlags); } void ComputedShaderFlags::print(raw_ostream &OS) const { @@ -63,20 +81,58 @@ void ComputedShaderFlags::print(raw_ostream &OS) const { OS << ";\n"; } +/// Return the shader flags mask of the specified function Func. +const ComputedShaderFlags & +ModuleShaderFlags::getFunctionFlags(const Function *Func) const { + const auto Iter = llvm::lower_bound( + FunctionFlags, Func, + [](const std::pair FSM, + const Function *FindFunc) { return (FSM.first < FindFunc); }); + assert((Iter != FunctionFlags.end() && Iter->first == Func) && + "No Shader Flags Mask exists for function"); + return Iter->second; +} + +//===----------------------------------------------------------------------===// +// ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass + +// Provide an explicit template instantiation for the static ID. 
AnalysisKey ShaderFlagsAnalysis::Key; -ComputedShaderFlags ShaderFlagsAnalysis::run(Module &M, - ModuleAnalysisManager &AM) { - return ComputedShaderFlags::computeFlags(M); +ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M, + ModuleAnalysisManager &AM) { + ModuleShaderFlags MSFI; + MSFI.initialize(M); + return MSFI; } PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M, ModuleAnalysisManager &AM) { - ComputedShaderFlags Flags = AM.getResult(M); - Flags.print(OS); + const ModuleShaderFlags &FlagsInfo = AM.getResult(M); + // Print description of combined shader flags for all module functions + OS << "; Combined Shader Flags for Module\n"; + FlagsInfo.getCombinedFlags().print(OS); + // Print shader flags mask for each of the module functions + OS << "; Shader Flags for Module Functions\n"; + for (const auto &F : M.getFunctionList()) { + if (F.isDeclaration()) + continue; + auto SFMask = FlagsInfo.getFunctionFlags(&F); + OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(), + (uint64_t)(SFMask)); + } + return PreservedAnalyses::all(); } +//===----------------------------------------------------------------------===// +// ShaderFlagsAnalysisWrapper + +bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) { + MSFI.initialize(M); + return false; +} + char ShaderFlagsAnalysisWrapper::ID = 0; INITIALIZE_PASS(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis", diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h index 1df7d27de13d3..2d60137f8b191 100644 --- a/llvm/lib/Target/DirectX/DXILShaderFlags.h +++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h @@ -14,12 +14,14 @@ #ifndef LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H #define LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H +#include "llvm/IR/Function.h" #include "llvm/IR/PassManager.h" #include "llvm/Pass.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include +#include namespace llvm { 
class Module; @@ -43,15 +45,23 @@ struct ComputedShaderFlags { constexpr uint64_t getMask(int Bit) const { return Bit != -1 ? 1ull << Bit : 0; } + + uint64_t getModuleFlags() const { + uint64_t ModuleFlags = 0; +#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \ + ModuleFlags |= FlagName ? getMask(DxilModuleBit) : 0ull; +#include "llvm/BinaryFormat/DXContainerConstants.def" + return ModuleFlags; + } + operator uint64_t() const { - uint64_t FlagValue = 0; + uint64_t FlagValue = getModuleFlags(); #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \ FlagValue |= FlagName ? getMask(DxilModuleBit) : 0ull; -#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \ - FlagValue |= FlagName ? getMask(DxilModuleBit) : 0ull; #include "llvm/BinaryFormat/DXContainerConstants.def" return FlagValue; } + uint64_t getFeatureFlags() const { uint64_t FeatureFlags = 0; #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \ @@ -60,11 +70,33 @@ struct ComputedShaderFlags { return FeatureFlags; } - static ComputedShaderFlags computeFlags(Module &M); + void merge(const uint64_t IVal) { +#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \ + FlagName |= (IVal & getMask(DxilModuleBit)); +#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \ + FlagName |= (IVal & getMask(DxilModuleBit)); +#include "llvm/BinaryFormat/DXContainerConstants.def" + return; + } + void print(raw_ostream &OS = dbgs()) const; LLVM_DUMP_METHOD void dump() const { print(); } }; +struct ModuleShaderFlags { + void initialize(const Module &); + const ComputedShaderFlags &getFunctionFlags(const Function *) const; + const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; } + +private: + /// Vector of sorted Function-Shader Flag mask pairs representing properties + /// of each of the functions in the module. 
Shader Flags of each function + /// represent both module-level and function-level flags + SmallVector> FunctionFlags; + /// Combined Shader Flag Mask of all functions of the module + ComputedShaderFlags CombinedSFMask{}; +}; + class ShaderFlagsAnalysis : public AnalysisInfoMixin { friend AnalysisInfoMixin; static AnalysisKey Key; @@ -72,9 +104,9 @@ class ShaderFlagsAnalysis : public AnalysisInfoMixin { public: ShaderFlagsAnalysis() = default; - using Result = ComputedShaderFlags; + using Result = ModuleShaderFlags; - ComputedShaderFlags run(Module &M, ModuleAnalysisManager &AM); + ModuleShaderFlags run(Module &M, ModuleAnalysisManager &AM); }; /// Printer pass for ShaderFlagsAnalysis results. @@ -92,19 +124,16 @@ class ShaderFlagsAnalysisPrinter /// This is required because the passes that will depend on this are codegen /// passes which run through the legacy pass manager. class ShaderFlagsAnalysisWrapper : public ModulePass { - ComputedShaderFlags Flags; + ModuleShaderFlags MSFI; public: static char ID; ShaderFlagsAnalysisWrapper() : ModulePass(ID) {} - const ComputedShaderFlags &getShaderFlags() { return Flags; } + const ModuleShaderFlags &getShaderFlags() { return MSFI; } - bool runOnModule(Module &M) override { - Flags = ComputedShaderFlags::computeFlags(M); - return false; - } + bool runOnModule(Module &M) override; void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesAll(); diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp index be370e10df694..4ba10d123e8d2 100644 --- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp +++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp @@ -286,11 +286,6 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD, MDTuple *Properties = nullptr; if (ShaderFlags != 0) { SmallVector MDVals; - // FIXME: ShaderFlagsAnalysis pass needs to collect and provide - // ShaderFlags for each entry function. 
Currently, ShaderFlags value - // provided by ShaderFlagsAnalysis pass is created by walking *all* the - // function instructions of the module. Is it is correct to use this value - // for metadata of the empty library entry? MDVals.append( getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx)); Properties = MDNode::get(Ctx, MDVals); @@ -302,7 +297,7 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD, static void translateMetadata(Module &M, const DXILResourceMap &DRM, const Resources &MDResources, - const ComputedShaderFlags &ShaderFlags, + const ModuleShaderFlags &ShaderFlags, const ModuleMetadataInfo &MMDI) { LLVMContext &Ctx = M.getContext(); IRBuilder<> IRB(Ctx); @@ -318,23 +313,27 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM, // See https://github.com/llvm/llvm-project/issues/57928 MDTuple *Signatures = nullptr; - if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) + if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) { + // Get the combined shader flag mask of all functions in the library to be + // used as shader flags mask value associated with top-level library entry + // metadata. + uint64_t CombinedMask = ShaderFlags.getCombinedFlags(); EntryFnMDNodes.emplace_back( - emitTopLevelLibraryNode(M, ResourceMD, ShaderFlags)); - else if (MMDI.EntryPropertyVec.size() > 1) { + emitTopLevelLibraryNode(M, ResourceMD, CombinedMask)); + } else if (MMDI.EntryPropertyVec.size() > 1) { M.getContext().diagnose(DiagnosticInfoTranslateMD( M, "Non-library shader: One and only one entry expected")); } for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) { - // FIXME: ShaderFlagsAnalysis pass needs to collect and provide - // ShaderFlags for each entry function. For now, assume shader flags value - // of entry functions being compiled for lib_* shader profile viz., - // EntryPro.Entry is 0. - uint64_t EntryShaderFlags = - (MMDI.ShaderProfile == Triple::EnvironmentType::Library) ? 
0 - : ShaderFlags; + const ComputedShaderFlags &EntrySFMask = + ShaderFlags.getFunctionFlags(EntryProp.Entry); + + // If ShaderProfile is Library, mask is already consolidated in the + // top-level library node. Hence it is not emitted. + uint64_t EntryShaderFlags = 0; if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) { + EntryShaderFlags = EntrySFMask; if (EntryProp.ShaderStage != MMDI.ShaderProfile) { M.getContext().diagnose(DiagnosticInfoTranslateMD( M, @@ -361,8 +360,7 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M, ModuleAnalysisManager &MAM) { const DXILResourceMap &DRM = MAM.getResult(M); const dxil::Resources &MDResources = MAM.getResult(M); - const ComputedShaderFlags &ShaderFlags = - MAM.getResult(M); + const ModuleShaderFlags &ShaderFlags = MAM.getResult(M); const dxil::ModuleMetadataInfo MMDI = MAM.getResult(M); translateMetadata(M, DRM, MDResources, ShaderFlags, MMDI); @@ -393,7 +391,7 @@ class DXILTranslateMetadataLegacy : public ModulePass { getAnalysis().getResourceMap(); const dxil::Resources &MDResources = getAnalysis().getDXILResource(); - const ComputedShaderFlags &ShaderFlags = + const ModuleShaderFlags &ShaderFlags = getAnalysis().getShaderFlags(); dxil::ModuleMetadataInfo MMDI = getAnalysis().getModuleMetadata(); diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions-obj-test.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions-obj-test.ll new file mode 100644 index 0000000000000..02a4c2090499a --- /dev/null +++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions-obj-test.ll @@ -0,0 +1,16 @@ +; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s + +target triple = "dxil-pc-shadermodel6.7-library" +define double @div(double %a, double %b) #0 { + %res = fdiv double %a, %b + ret double %res +} + +attributes #0 = { convergent norecurse nounwind "hlsl.export"} + +; CHECK: - Name: SFI0 +; CHECK-NEXT: Size: 8 +; CHECK-NEXT: Flags: +; CHECK: Doubles: true +; CHECK: DX11_1_DoubleExtensions: true + 
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll index a8d5f9c78f0b4..6332ef806a0d8 100644 --- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll +++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll @@ -1,27 +1,45 @@ ; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s -; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s --check-prefix=DXC target triple = "dxil-pc-shadermodel6.7-library" -; CHECK: ; Shader Flags Value: 0x00000044 +; CHECK: ; Combined Shader Flags for Module +; CHECK-NEXT: ; Shader Flags Value: 0x00000044 + ; CHECK: ; Note: shader requires additional functionality: ; CHECK-NEXT: ; Double-precision floating point ; CHECK-NEXT: ; Double-precision extensions for 11.1 ; CHECK-NEXT: ; Note: extra DXIL module flags: -; CHECK-NEXT: {{^;$}} -define double @div(double %a, double %b) #0 { +; CHECK-NEXT: ; +; CHECK-NEXT: ; Shader Flags for Module Functions + +; CHECK: ; Function test_fdiv_double : 0x00000044 +define double @test_fdiv_double(double %a, double %b) #0 { %res = fdiv double %a, %b ret double %res } -attributes #0 = { convergent norecurse nounwind "hlsl.export"} +; CHECK: ; Function test_uitofp_i64 : 0x00000044 +define double @test_uitofp_i64(i64 %a) #0 { + %r = uitofp i64 %a to double + ret double %r +} + +; CHECK: ; Function test_sitofp_i64 : 0x00000044 +define double @test_sitofp_i64(i64 %a) #0 { + %r = sitofp i64 %a to double + ret double %r +} -; DXC: - Name: SFI0 -; DXC-NEXT: Size: 8 -; DXC-NEXT: Flags: -; DXC-NEXT: Doubles: true -; DXC-NOT: {{[A-Za-z]+: +true}} -; DXC: DX11_1_DoubleExtensions: true -; DXC-NOT: {{[A-Za-z]+: +true}} -; DXC: NextUnusedBit: false -; DXC: ... 
+; CHECK: ; Function test_fptoui_i32 : 0x00000044 +define i32 @test_fptoui_i32(double %a) #0 { + %r = fptoui double %a to i32 + ret i32 %r +} + +; CHECK: ; Function test_fptosi_i64 : 0x00000044 +define i64 @test_fptosi_i64(double %a) #0 { + %r = fptosi double %a to i64 + ret i64 %r +} + +attributes #0 = { convergent norecurse nounwind "hlsl.export"} diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/doubles.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/doubles.ll index e9b44240e10b9..1c131f0774938 100644 --- a/llvm/test/CodeGen/DirectX/ShaderFlags/doubles.ll +++ b/llvm/test/CodeGen/DirectX/ShaderFlags/doubles.ll @@ -3,11 +3,15 @@ target triple = "dxil-pc-shadermodel6.7-library" -; CHECK: ; Shader Flags Value: 0x00000004 -; CHECK: ; Note: shader requires additional functionality: -; CHECK-NEXT: ; Double-precision floating point -; CHECK-NEXT: ; Note: extra DXIL module flags: -; CHECK-NEXT: {{^;$}} +;CHECK: ; Combined Shader Flags for Module +;CHECK-NEXT: ; Shader Flags Value: 0x00000004 +;CHECK-NEXT: ; +;CHECK-NEXT: ; Note: shader requires additional functionality: +;CHECK-NEXT: ; Double-precision floating point +;CHECK-NEXT: ; Note: extra DXIL module flags: +;CHECK-NEXT: ; +;CHECK-NEXT: ; Shader Flags for Module Functions +;CHECK-NEXT: ; Function add : 0x00000004 define double @add(double %a, double %b) #0 { %sum = fadd double %a, %b diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/no_flags.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/no_flags.ll index f7baa1b64f9cd..f99d4fca84da2 100644 --- a/llvm/test/CodeGen/DirectX/ShaderFlags/no_flags.ll +++ b/llvm/test/CodeGen/DirectX/ShaderFlags/no_flags.ll @@ -2,7 +2,12 @@ target triple = "dxil-pc-shadermodel6.7-library" -; CHECK: ; Shader Flags Value: 0x00000000 +;CHECK: ; Combined Shader Flags for Module +;CHECK-NEXT: ; Shader Flags Value: 0x00000000 +;CHECK-NEXT: ; +;CHECK-NEXT: ; Shader Flags for Module Functions +;CHECK-NEXT: ; Function add : 0x00000000 + define i32 @add(i32 %a, i32 %b) { %sum = add i32 %a, %b ret i32 %sum