Skip to content

Commit b799e9d

Browse files
authored
[DX] Support pipeline state masks (#66425)
The DXContainer pipeline state information encodes a bunch of mask vectors that are used to track things about the inputs and outputs from each shader. This adds support for reading and writing them throught he YAML test interfaces. The writing logic in MC is extremely primitive and we'll want to revisit the API for that, but since I'm not sure how we'll want to generate the mask bits from DXIL during code generation I didn't want to spend too much time on the API. Fixes #59479
1 parent 1b18e98 commit b799e9d

29 files changed

+1050
-84
lines changed

llvm/include/llvm/MC/DXContainerPSVInfo.h

+14
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "llvm/BinaryFormat/DXContainer.h"
1515
#include "llvm/TargetParser/Triple.h"
1616

17+
#include <array>
1718
#include <numeric>
1819
#include <stdint.h>
1920

@@ -51,6 +52,19 @@ struct PSVRuntimeInfo {
5152
SmallVector<PSVSignatureElement> OutputElements;
5253
SmallVector<PSVSignatureElement> PatchOrPrimElements;
5354

55+
// TODO: Make this interface user-friendly.
56+
// The interface here is bad, and we'll want to change this in the future. We
57+
// probably will want to build out these mask vectors as vectors of bools and
58+
// have this utility object convert them to the bit masks. I don't want to
59+
// over-engineer this API now since we don't know what the data coming in to
60+
// feed it will look like, so I kept it extremely simple for the immediate use
61+
// case.
62+
std::array<SmallVector<uint32_t>, 4> OutputVectorMasks;
63+
SmallVector<uint32_t> PatchOrPrimMasks;
64+
std::array<SmallVector<uint32_t>, 4> InputOutputMap;
65+
SmallVector<uint32_t> InputPatchMap;
66+
SmallVector<uint32_t> PatchOutputMap;
67+
5468
// Serialize PSVInfo into the provided raw_ostream. The version field
5569
// specifies the data version to encode, the default value specifies encoding
5670
// the highest supported version.

llvm/include/llvm/Object/DXContainer.h

+76-2
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,25 @@
2121
#include "llvm/Support/Error.h"
2222
#include "llvm/Support/MemoryBufferRef.h"
2323
#include "llvm/TargetParser/Triple.h"
24+
#include <array>
2425
#include <variant>
2526

2627
namespace llvm {
2728
namespace object {
2829

2930
namespace DirectX {
31+
32+
namespace detail {
33+
template <typename T>
34+
std::enable_if_t<std::is_arithmetic<T>::value, void> swapBytes(T &value) {
35+
sys::swapByteOrder(value);
36+
}
37+
38+
template <typename T>
39+
std::enable_if_t<std::is_class<T>::value, void> swapBytes(T &value) {
40+
value.swapBytes();
41+
}
42+
} // namespace detail
3043
class PSVRuntimeInfo {
3144

3245
// This class provides a view into the underlying resource array. The Resource
@@ -35,7 +48,7 @@ class PSVRuntimeInfo {
3548
// swaps it as appropriate.
3649
template <typename T> struct ViewArray {
3750
StringRef Data;
38-
uint32_t Stride; // size of each element in the list.
51+
uint32_t Stride = sizeof(T); // size of each element in the list.
3952

4053
ViewArray() = default;
4154
ViewArray(StringRef D, size_t S) : Data(D), Stride(S) {}
@@ -65,7 +78,7 @@ class PSVRuntimeInfo {
6578
memcpy(static_cast<void *>(&Val), Current,
6679
std::min(Stride, MaxStride()));
6780
if (sys::IsBigEndianHost)
68-
Val.swapBytes();
81+
detail::swapBytes(Val);
6982
return Val;
7083
}
7184

@@ -120,6 +133,12 @@ class PSVRuntimeInfo {
120133
SigElementArray SigOutputElements;
121134
SigElementArray SigPatchOrPrimElements;
122135

136+
std::array<ViewArray<uint32_t>, 4> OutputVectorMasks;
137+
ViewArray<uint32_t> PatchOrPrimMasks;
138+
std::array<ViewArray<uint32_t>, 4> InputOutputMap;
139+
ViewArray<uint32_t> InputPatchMap;
140+
ViewArray<uint32_t> PatchOutputMap;
141+
123142
public:
124143
PSVRuntimeInfo(StringRef D) : Data(D), Size(0) {}
125144

@@ -140,6 +159,22 @@ class PSVRuntimeInfo {
140159

141160
const InfoStruct &getInfo() const { return BasicInfo; }
142161

162+
template <typename T> const T *getInfoAs() const {
163+
if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo))
164+
return static_cast<const T *>(P);
165+
if (std::is_same<T, dxbc::PSV::v2::RuntimeInfo>::value)
166+
return nullptr;
167+
168+
if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(&BasicInfo))
169+
return static_cast<const T *>(P);
170+
if (std::is_same<T, dxbc::PSV::v1::RuntimeInfo>::value)
171+
return nullptr;
172+
173+
if (const auto *P = std::get_if<dxbc::PSV::v0::RuntimeInfo>(&BasicInfo))
174+
return static_cast<const T *>(P);
175+
return nullptr;
176+
}
177+
143178
StringRef getStringTable() const { return StringTable; }
144179
ArrayRef<uint32_t> getSemanticIndexTable() const {
145180
return SemanticIndexTable;
@@ -155,7 +190,46 @@ class PSVRuntimeInfo {
155190
return SigPatchOrPrimElements;
156191
}
157192

193+
ViewArray<uint32_t> getOutputVectorMasks(size_t Idx) const {
194+
assert(Idx < 4);
195+
return OutputVectorMasks[Idx];
196+
}
197+
198+
ViewArray<uint32_t> getPatchOrPrimMasks() const { return PatchOrPrimMasks; }
199+
200+
ViewArray<uint32_t> getInputOutputMap(size_t Idx) const {
201+
assert(Idx < 4);
202+
return InputOutputMap[Idx];
203+
}
204+
205+
ViewArray<uint32_t> getInputPatchMap() const { return InputPatchMap; }
206+
ViewArray<uint32_t> getPatchOutputMap() const { return PatchOutputMap; }
207+
158208
uint32_t getSigElementStride() const { return SigInputElements.Stride; }
209+
210+
bool usesViewID() const {
211+
if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
212+
return P->UsesViewID != 0;
213+
return false;
214+
}
215+
216+
uint8_t getInputVectorCount() const {
217+
if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
218+
return P->SigInputVectors;
219+
return 0;
220+
}
221+
222+
ArrayRef<uint8_t> getOutputVectorCounts() const {
223+
if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
224+
return ArrayRef<uint8_t>(P->SigOutputVectors);
225+
return ArrayRef<uint8_t>();
226+
}
227+
228+
uint8_t getPatchConstOrPrimVectorCount() const {
229+
if (const auto *P = getInfoAs<dxbc::PSV::v1::RuntimeInfo>())
230+
return P->GeomData.SigPatchConstOrPrimVectors;
231+
return 0;
232+
}
159233
};
160234

161235
} // namespace DirectX

llvm/include/llvm/ObjectYAML/DXContainerYAML.h

+9
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "llvm/BinaryFormat/DXContainer.h"
2020
#include "llvm/ObjectYAML/YAML.h"
2121
#include "llvm/Support/YAMLTraits.h"
22+
#include <array>
2223
#include <cstdint>
2324
#include <optional>
2425
#include <string>
@@ -113,6 +114,13 @@ struct PSVInfo {
113114
SmallVector<SignatureElement> SigOutputElements;
114115
SmallVector<SignatureElement> SigPatchOrPrimElements;
115116

117+
using MaskVector = SmallVector<llvm::yaml::Hex32>;
118+
std::array<MaskVector, 4> OutputVectorMasks;
119+
MaskVector PatchOrPrimMasks;
120+
std::array<MaskVector, 4> InputOutputMap;
121+
MaskVector InputPatchMap;
122+
MaskVector PatchOutputMap;
123+
116124
void mapInfoForVersion(yaml::IO &IO);
117125

118126
PSVInfo();
@@ -143,6 +151,7 @@ struct Object {
143151
LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::DXContainerYAML::Part)
144152
LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::DXContainerYAML::ResourceBindInfo)
145153
LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::DXContainerYAML::SignatureElement)
154+
LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::DXContainerYAML::PSVInfo::MaskVector)
146155
LLVM_YAML_DECLARE_ENUM_TRAITS(llvm::dxbc::PSV::SemanticKind)
147156
LLVM_YAML_DECLARE_ENUM_TRAITS(llvm::dxbc::PSV::ComponentType)
148157
LLVM_YAML_DECLARE_ENUM_TRAITS(llvm::dxbc::PSV::InterpolationMode)

llvm/include/llvm/Support/EndianStream.h

+9
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ namespace support {
2525

2626
namespace endian {
2727

28+
template <typename value_type>
29+
inline void write_array(raw_ostream &os, ArrayRef<value_type> values,
30+
endianness endian) {
31+
for (const auto orig : values) {
32+
value_type value = byte_swap<value_type>(orig, endian);
33+
os.write((const char *)&value, sizeof(value_type));
34+
}
35+
}
36+
2837
template <typename value_type>
2938
inline void write(raw_ostream &os, value_type value, endianness endian) {
3039
value = byte_swap<value_type>(value, endian);

llvm/lib/MC/DXContainerPSVInfo.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -147,4 +147,17 @@ void PSVRuntimeInfo::write(raw_ostream &OS, uint32_t Version) const {
147147
OS.write(reinterpret_cast<const char *>(&SignatureElements[0]),
148148
SignatureElements.size() * sizeof(v0::SignatureElement));
149149
}
150+
151+
for (const auto &MaskVector : OutputVectorMasks)
152+
support::endian::write_array(OS, ArrayRef<uint32_t>(MaskVector),
153+
support::little);
154+
support::endian::write_array(OS, ArrayRef<uint32_t>(PatchOrPrimMasks),
155+
support::little);
156+
for (const auto &MaskVector : InputOutputMap)
157+
support::endian::write_array(OS, ArrayRef<uint32_t>(MaskVector),
158+
support::little);
159+
support::endian::write_array(OS, ArrayRef<uint32_t>(InputPatchMap),
160+
support::little);
161+
support::endian::write_array(OS, ArrayRef<uint32_t>(PatchOutputMap),
162+
support::little);
150163
}

llvm/lib/Object/DXContainer.cpp

+62
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,68 @@ Error DirectX::PSVRuntimeInfo::parse(uint16_t ShaderKind) {
321321
Current += PSize;
322322
}
323323

324+
ArrayRef<uint8_t> OutputVectorCounts = getOutputVectorCounts();
325+
uint8_t PatchConstOrPrimVectorCount = getPatchConstOrPrimVectorCount();
326+
uint8_t InputVectorCount = getInputVectorCount();
327+
328+
auto maskDwordSize = [](uint8_t Vector) {
329+
return (static_cast<uint32_t>(Vector) + 7) >> 3;
330+
};
331+
332+
auto mapTableSize = [maskDwordSize](uint8_t X, uint8_t Y) {
333+
return maskDwordSize(Y) * X * 4;
334+
};
335+
336+
if (usesViewID()) {
337+
for (uint32_t I = 0; I < OutputVectorCounts.size(); ++I) {
338+
// The vector mask is one bit per component and 4 components per vector.
339+
// We can compute the number of dwords required by rounding up to the next
340+
// multiple of 8.
341+
uint32_t NumDwords =
342+
maskDwordSize(static_cast<uint32_t>(OutputVectorCounts[I]));
343+
size_t NumBytes = NumDwords * sizeof(uint32_t);
344+
OutputVectorMasks[I].Data = Data.substr(Current - Data.begin(), NumBytes);
345+
Current += NumBytes;
346+
}
347+
348+
if (ShaderStage == Triple::Hull && PatchConstOrPrimVectorCount > 0) {
349+
uint32_t NumDwords = maskDwordSize(PatchConstOrPrimVectorCount);
350+
size_t NumBytes = NumDwords * sizeof(uint32_t);
351+
PatchOrPrimMasks.Data = Data.substr(Current - Data.begin(), NumBytes);
352+
Current += NumBytes;
353+
}
354+
}
355+
356+
// Input/Output mapping table
357+
for (uint32_t I = 0; I < OutputVectorCounts.size(); ++I) {
358+
if (InputVectorCount == 0 || OutputVectorCounts[I] == 0)
359+
continue;
360+
uint32_t NumDwords = mapTableSize(InputVectorCount, OutputVectorCounts[I]);
361+
size_t NumBytes = NumDwords * sizeof(uint32_t);
362+
InputOutputMap[I].Data = Data.substr(Current - Data.begin(), NumBytes);
363+
Current += NumBytes;
364+
}
365+
366+
// Hull shader: Input/Patch mapping table
367+
if (ShaderStage == Triple::Hull && PatchConstOrPrimVectorCount > 0 &&
368+
InputVectorCount > 0) {
369+
uint32_t NumDwords =
370+
mapTableSize(InputVectorCount, PatchConstOrPrimVectorCount);
371+
size_t NumBytes = NumDwords * sizeof(uint32_t);
372+
InputPatchMap.Data = Data.substr(Current - Data.begin(), NumBytes);
373+
Current += NumBytes;
374+
}
375+
376+
// Domain Shader: Patch/Output mapping table
377+
if (ShaderStage == Triple::Domain && PatchConstOrPrimVectorCount > 0 &&
378+
OutputVectorCounts[0] > 0) {
379+
uint32_t NumDwords =
380+
mapTableSize(PatchConstOrPrimVectorCount, OutputVectorCounts[0]);
381+
size_t NumBytes = NumDwords * sizeof(uint32_t);
382+
PatchOutputMap.Data = Data.substr(Current - Data.begin(), NumBytes);
383+
Current += NumBytes;
384+
}
385+
324386
return Error::success();
325387
}
326388

llvm/lib/ObjectYAML/DXContainerEmitter.cpp

+20
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,26 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
219219
El.Allocated, El.Kind, El.Type, El.Mode, El.DynamicMask,
220220
El.Stream});
221221

222+
static_assert(PSV.OutputVectorMasks.size() == PSV.InputOutputMap.size());
223+
for (unsigned I = 0; I < PSV.OutputVectorMasks.size(); ++I) {
224+
PSV.OutputVectorMasks[I].insert(PSV.OutputVectorMasks[I].begin(),
225+
P.Info->OutputVectorMasks[I].begin(),
226+
P.Info->OutputVectorMasks[I].end());
227+
PSV.InputOutputMap[I].insert(PSV.InputOutputMap[I].begin(),
228+
P.Info->InputOutputMap[I].begin(),
229+
P.Info->InputOutputMap[I].end());
230+
}
231+
232+
PSV.PatchOrPrimMasks.insert(PSV.PatchOrPrimMasks.begin(),
233+
P.Info->PatchOrPrimMasks.begin(),
234+
P.Info->PatchOrPrimMasks.end());
235+
PSV.InputPatchMap.insert(PSV.InputPatchMap.begin(),
236+
P.Info->InputPatchMap.begin(),
237+
P.Info->InputPatchMap.end());
238+
PSV.PatchOutputMap.insert(PSV.PatchOutputMap.begin(),
239+
P.Info->PatchOutputMap.begin(),
240+
P.Info->PatchOutputMap.end());
241+
222242
PSV.finalize(static_cast<Triple::EnvironmentType>(
223243
Triple::Pixel + P.Info->Info.ShaderStage));
224244
PSV.write(OS, P.Info->Version);

llvm/lib/ObjectYAML/DXContainerYAML.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,24 @@ void MappingTraits<DXContainerYAML::PSVInfo>::mapping(
139139
IO.mapRequired("SigInputElements", PSV.SigInputElements);
140140
IO.mapRequired("SigOutputElements", PSV.SigOutputElements);
141141
IO.mapRequired("SigPatchOrPrimElements", PSV.SigPatchOrPrimElements);
142+
143+
Triple::EnvironmentType Stage = dxbc::getShaderStage(PSV.Info.ShaderStage);
144+
if (PSV.Info.UsesViewID) {
145+
MutableArrayRef<SmallVector<llvm::yaml::Hex32>> MutableOutMasks(
146+
PSV.OutputVectorMasks);
147+
IO.mapRequired("OutputVectorMasks", MutableOutMasks);
148+
if (Stage == Triple::EnvironmentType::Hull)
149+
IO.mapRequired("PatchOrPrimMasks", PSV.PatchOrPrimMasks);
150+
}
151+
MutableArrayRef<SmallVector<llvm::yaml::Hex32>> MutableIOMap(
152+
PSV.InputOutputMap);
153+
IO.mapRequired("InputOutputMap", MutableIOMap);
154+
155+
if (Stage == Triple::EnvironmentType::Hull)
156+
IO.mapRequired("InputPatchMap", PSV.InputPatchMap);
157+
158+
if (Stage == Triple::EnvironmentType::Domain)
159+
IO.mapRequired("PatchOutputMap", PSV.PatchOutputMap);
142160
}
143161

144162
void MappingTraits<DXContainerYAML::Part>::mapping(IO &IO,

0 commit comments

Comments
 (0)