Skip to content

[DirectX] Infrastructure to collect shader flags for each function #112967

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions llvm/lib/Target/DirectX/DXContainerGlobals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,13 @@ bool DXContainerGlobals::runOnModule(Module &M) {
}

GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) {
const uint64_t FeatureFlags =
static_cast<uint64_t>(getAnalysis<ShaderFlagsAnalysisWrapper>()
.getShaderFlags()
.getFeatureFlags());
uint64_t CombinedFeatureFlags = getAnalysis<ShaderFlagsAnalysisWrapper>()
.getShaderFlags()
.getCombinedFlags()
.getFeatureFlags();

Constant *FeatureFlagsConstant =
ConstantInt::get(M.getContext(), APInt(64, FeatureFlags));
ConstantInt::get(M.getContext(), APInt(64, CombinedFeatureFlags));
return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0");
}

Expand Down
86 changes: 71 additions & 15 deletions llvm/lib/Target/DirectX/DXILShaderFlags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,36 +13,54 @@

#include "DXILShaderFlags.h"
#include "DirectX.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;
using namespace llvm::dxil;

static void updateFlags(ComputedShaderFlags &Flags, const Instruction &I) {
Type *Ty = I.getType();
if (Ty->isDoubleTy()) {
Flags.Doubles = true;
static void updateFunctionFlags(ComputedShaderFlags &CSF,
const Instruction &I) {
if (!CSF.Doubles)
CSF.Doubles = I.getType()->isDoubleTy();

if (!CSF.Doubles) {
for (Value *Op : I.operands())
CSF.Doubles |= Op->getType()->isDoubleTy();
}
if (CSF.Doubles) {
switch (I.getOpcode()) {
case Instruction::FDiv:
case Instruction::UIToFP:
case Instruction::SIToFP:
case Instruction::FPToUI:
case Instruction::FPToSI:
Flags.DX11_1_DoubleExtensions = true;
// TODO: To be set if I is a call to DXIL intrinsic DXIL::Opcode::Fma
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have an issue for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have an issue for this?

#114554.

// https://github.com/llvm/llvm-project/issues/114554
CSF.DX11_1_DoubleExtensions = true;
break;
}
}
}

ComputedShaderFlags ComputedShaderFlags::computeFlags(Module &M) {
ComputedShaderFlags Flags;
for (const auto &F : M)
void ModuleShaderFlags::initialize(const Module &M) {
// Collect shader flags for each of the functions
for (const auto &F : M.getFunctionList()) {
if (F.isDeclaration())
continue;
ComputedShaderFlags CSF;
for (const auto &BB : F)
for (const auto &I : BB)
updateFlags(Flags, I);
return Flags;
updateFunctionFlags(CSF, I);
// Insert shader flag mask for function F
FunctionFlags.push_back({&F, CSF});
// Update combined shader flags mask
CombinedSFMask.merge(CSF);
}
llvm::sort(FunctionFlags);
}

void ComputedShaderFlags::print(raw_ostream &OS) const {
Expand All @@ -63,20 +81,58 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
OS << ";\n";
}

/// Return the shader flags mask of the specified function \p Func.
///
/// FunctionFlags is searched with a binary search keyed on the function
/// pointer, so it must already be sorted (initialize() sorts it after
/// populating it).
///
/// \param Func Function whose flags are requested. It must have an entry in
///        FunctionFlags; asking for any other function trips the assertion
///        below.
/// \returns A reference to the flags stored for \p Func inside FunctionFlags.
const ComputedShaderFlags &
ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
  // Bind each candidate entry by const reference: taking the pair by value
  // would copy a ComputedShaderFlags object on every comparison.
  const auto Iter = llvm::lower_bound(
      FunctionFlags, Func,
      [](const std::pair<const Function *, ComputedShaderFlags> &FSM,
         const Function *FindFunc) { return FSM.first < FindFunc; });
  assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
         "No Shader Flags Mask exists for function");
  return Iter->second;
}

//===----------------------------------------------------------------------===//
// ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass

// Define the static AnalysisKey that identifies ShaderFlagsAnalysis.
AnalysisKey ShaderFlagsAnalysis::Key;

ComputedShaderFlags ShaderFlagsAnalysis::run(Module &M,
ModuleAnalysisManager &AM) {
return ComputedShaderFlags::computeFlags(M);
ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M,
ModuleAnalysisManager &AM) {
ModuleShaderFlags MSFI;
MSFI.initialize(M);
return MSFI;
}

PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
ModuleAnalysisManager &AM) {
ComputedShaderFlags Flags = AM.getResult<ShaderFlagsAnalysis>(M);
Flags.print(OS);
const ModuleShaderFlags &FlagsInfo = AM.getResult<ShaderFlagsAnalysis>(M);
// Print description of combined shader flags for all module functions
OS << "; Combined Shader Flags for Module\n";
FlagsInfo.getCombinedFlags().print(OS);
// Print shader flags mask for each of the module functions
OS << "; Shader Flags for Module Functions\n";
for (const auto &F : M.getFunctionList()) {
if (F.isDeclaration())
continue;
auto SFMask = FlagsInfo.getFunctionFlags(&F);
OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(),
(uint64_t)(SFMask));
}

return PreservedAnalyses::all();
}

//===----------------------------------------------------------------------===//
// ShaderFlagsAnalysisWrapper

/// Compute the shader flags of \p M and cache them in MSFI, from where
/// getShaderFlags() hands them out.
/// \returns false: this pass only collects information about the module.
bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
  // initialize() takes the module by const reference, so the IR is unchanged.
  constexpr bool ModuleChanged = false;
  MSFI.initialize(M);
  return ModuleChanged;
}

char ShaderFlagsAnalysisWrapper::ID = 0;

INITIALIZE_PASS(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
Expand Down
53 changes: 41 additions & 12 deletions llvm/lib/Target/DirectX/DXILShaderFlags.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
#ifndef LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
#define LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H

#include "llvm/IR/Function.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <memory>

namespace llvm {
class Module;
Expand All @@ -43,15 +45,23 @@ struct ComputedShaderFlags {
constexpr uint64_t getMask(int Bit) const {
return Bit != -1 ? 1ull << Bit : 0;
}

/// Collect the module flags that are currently set into a single bitmask.
///
/// DXContainerConstants.def is expanded with DXIL_MODULE_FLAG defined, so
/// every module flag entry in it contributes getMask(DxilModuleBit) when its
/// FlagName member is set.
/// \returns The OR of the masks of all set module flags; 0 when none is set.
uint64_t getModuleFlags() const {
  uint64_t Mask = 0;
#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str)                         \
  if (FlagName)                                                                \
    Mask |= getMask(DxilModuleBit);
#include "llvm/BinaryFormat/DXContainerConstants.def"
  return Mask;
}

operator uint64_t() const {
uint64_t FlagValue = 0;
uint64_t FlagValue = getModuleFlags();
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \
FlagValue |= FlagName ? getMask(DxilModuleBit) : 0ull;
#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
FlagValue |= FlagName ? getMask(DxilModuleBit) : 0ull;
#include "llvm/BinaryFormat/DXContainerConstants.def"
return FlagValue;
}

uint64_t getFeatureFlags() const {
uint64_t FeatureFlags = 0;
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \
Expand All @@ -60,21 +70,43 @@ struct ComputedShaderFlags {
return FeatureFlags;
}

static ComputedShaderFlags computeFlags(Module &M);
/// Turn on every flag of this object whose bit is present in \p IVal.
///
/// Flags are only ever switched on here; a flag that is already set stays
/// set. A flag whose DxilModuleBit is -1 has a getMask() of 0 and is
/// therefore never set by this function.
/// \param IVal Bitmask indexed by DXIL module bit position, i.e. the layout
///        produced by operator uint64_t().
void merge(const uint64_t IVal) {
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str)          \
  if (IVal & getMask(DxilModuleBit))                                           \
    FlagName = true;
#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str)                         \
  if (IVal & getMask(DxilModuleBit))                                           \
    FlagName = true;
#include "llvm/BinaryFormat/DXContainerConstants.def"
}

void print(raw_ostream &OS = dbgs()) const;
LLVM_DUMP_METHOD void dump() const { print(); }
};

/// Shader flags gathered for a module: one ComputedShaderFlags per function
/// plus the combination of all of them.
struct ModuleShaderFlags {
  /// Compute and record the shader flags for the functions of \p M.
  void initialize(const Module &M);

  /// Look up the shader flags recorded for \p Func.
  const ComputedShaderFlags &getFunctionFlags(const Function *Func) const;

  /// Shader flags combined across all functions of the module.
  const ComputedShaderFlags &getCombinedFlags() const {
    return CombinedSFMask;
  }

private:
  /// Sorted (function, shader flags) pairs describing the properties of each
  /// function in the module. The flags stored for a function cover both the
  /// module-level and the function-level flags.
  SmallVector<std::pair<const Function *, ComputedShaderFlags>> FunctionFlags;

  /// Combined shader flags mask of all functions of the module.
  ComputedShaderFlags CombinedSFMask{};
};

class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
friend AnalysisInfoMixin<ShaderFlagsAnalysis>;
static AnalysisKey Key;

public:
ShaderFlagsAnalysis() = default;

using Result = ComputedShaderFlags;
using Result = ModuleShaderFlags;

ComputedShaderFlags run(Module &M, ModuleAnalysisManager &AM);
ModuleShaderFlags run(Module &M, ModuleAnalysisManager &AM);
};

/// Printer pass for ShaderFlagsAnalysis results.
Expand All @@ -92,19 +124,16 @@ class ShaderFlagsAnalysisPrinter
/// This is required because the passes that will depend on this are codegen
/// passes which run through the legacy pass manager.
class ShaderFlagsAnalysisWrapper : public ModulePass {
ComputedShaderFlags Flags;
ModuleShaderFlags MSFI;

public:
static char ID;

ShaderFlagsAnalysisWrapper() : ModulePass(ID) {}

const ComputedShaderFlags &getShaderFlags() { return Flags; }
const ModuleShaderFlags &getShaderFlags() { return MSFI; }

bool runOnModule(Module &M) override {
Flags = ComputedShaderFlags::computeFlags(M);
return false;
}
bool runOnModule(Module &M) override;

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesAll();
Expand Down
36 changes: 17 additions & 19 deletions llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,6 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
MDTuple *Properties = nullptr;
if (ShaderFlags != 0) {
SmallVector<Metadata *> MDVals;
// FIXME: ShaderFlagsAnalysis pass needs to collect and provide
// ShaderFlags for each entry function. Currently, ShaderFlags value
// provided by ShaderFlagsAnalysis pass is created by walking *all* the
// function instructions of the module. Is it correct to use this value
// for metadata of the empty library entry?
MDVals.append(
getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx));
Properties = MDNode::get(Ctx, MDVals);
Expand All @@ -302,7 +297,7 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,

static void translateMetadata(Module &M, const DXILResourceMap &DRM,
const Resources &MDResources,
const ComputedShaderFlags &ShaderFlags,
const ModuleShaderFlags &ShaderFlags,
const ModuleMetadataInfo &MMDI) {
LLVMContext &Ctx = M.getContext();
IRBuilder<> IRB(Ctx);
Expand All @@ -318,23 +313,27 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
// See https://github.com/llvm/llvm-project/issues/57928
MDTuple *Signatures = nullptr;

if (MMDI.ShaderProfile == Triple::EnvironmentType::Library)
if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) {
// Get the combined shader flag mask of all functions in the library to be
// used as shader flags mask value associated with top-level library entry
// metadata.
uint64_t CombinedMask = ShaderFlags.getCombinedFlags();
EntryFnMDNodes.emplace_back(
emitTopLevelLibraryNode(M, ResourceMD, ShaderFlags));
else if (MMDI.EntryPropertyVec.size() > 1) {
emitTopLevelLibraryNode(M, ResourceMD, CombinedMask));
} else if (MMDI.EntryPropertyVec.size() > 1) {
M.getContext().diagnose(DiagnosticInfoTranslateMD(
M, "Non-library shader: One and only one entry expected"));
}

for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
// FIXME: ShaderFlagsAnalysis pass needs to collect and provide
// ShaderFlags for each entry function. For now, assume shader flags value
// of entry functions being compiled for lib_* shader profile viz.,
// EntryProp.Entry is 0.
uint64_t EntryShaderFlags =
(MMDI.ShaderProfile == Triple::EnvironmentType::Library) ? 0
: ShaderFlags;
const ComputedShaderFlags &EntrySFMask =
ShaderFlags.getFunctionFlags(EntryProp.Entry);

// If ShaderProfile is Library, mask is already consolidated in the
// top-level library node. Hence it is not emitted.
uint64_t EntryShaderFlags = 0;
if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
EntryShaderFlags = EntrySFMask;
if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
M.getContext().diagnose(DiagnosticInfoTranslateMD(
M,
Expand All @@ -361,8 +360,7 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
ModuleAnalysisManager &MAM) {
const DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
const dxil::Resources &MDResources = MAM.getResult<DXILResourceMDAnalysis>(M);
const ComputedShaderFlags &ShaderFlags =
MAM.getResult<ShaderFlagsAnalysis>(M);
const ModuleShaderFlags &ShaderFlags = MAM.getResult<ShaderFlagsAnalysis>(M);
const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);

translateMetadata(M, DRM, MDResources, ShaderFlags, MMDI);
Expand Down Expand Up @@ -393,7 +391,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
getAnalysis<DXILResourceWrapperPass>().getResourceMap();
const dxil::Resources &MDResources =
getAnalysis<DXILResourceMDWrapper>().getDXILResource();
const ComputedShaderFlags &ShaderFlags =
const ModuleShaderFlags &ShaderFlags =
getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
dxil::ModuleMetadataInfo MMDI =
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
; Verify that a library shader performing a double-precision division reports
; both the Doubles and the DX11_1_DoubleExtensions flags in the SFI0 part of
; the object emitted by llc.
; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s

target triple = "dxil-pc-shadermodel6.7-library"
; The only function in the module: an fdiv on double operands, carrying the
; "hlsl.export" attribute.
define double @div(double %a, double %b) #0 {
%res = fdiv double %a, %b
ret double %res
}

attributes #0 = { convergent norecurse nounwind "hlsl.export"}

; Expected SFI0 part: 8 bytes in size, with both double-related flags set.
; CHECK: - Name: SFI0
; CHECK-NEXT: Size: 8
; CHECK-NEXT: Flags:
; CHECK: Doubles: true
; CHECK: DX11_1_DoubleExtensions: true

Loading