Skip to content

Commit cde54df

Browse files
cooppCooper Partin
and
Cooper Partin
authored
Add support for PSV EntryFunctionName (#84409)
This change introduces a version 3 of the PSV data that includes support for the name of the entry function as an offset into StringTable data to a null-terminated utf-8 string. Additional tests were added to ensure that the new value was properly serialized/deserialized from object data. Fixes #80175 --------- Co-authored-by: Cooper Partin <[email protected]>
1 parent 6280681 commit cde54df

18 files changed

+940
-52
lines changed

llvm/include/llvm/BinaryFormat/DXContainer.h

+13
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,19 @@ struct ResourceBindInfo : public v0::ResourceBindInfo {
424424
};
425425

426426
} // namespace v2
427+
428+
namespace v3 {
429+
struct RuntimeInfo : public v2::RuntimeInfo {
430+
uint32_t EntryNameOffset;
431+
432+
void swapBytes() { sys::swapByteOrder(EntryNameOffset); }
433+
434+
void swapBytes(Triple::EnvironmentType Stage) {
435+
v2::RuntimeInfo::swapBytes(Stage);
436+
}
437+
};
438+
439+
} // namespace v3
427440
} // namespace PSV
428441

429442
#define COMPONENT_PRECISION(Val, Enum) Enum = Val,

llvm/include/llvm/MC/DXContainerPSVInfo.h

+11-14
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
#ifndef LLVM_MC_DXCONTAINERPSVINFO_H
1010
#define LLVM_MC_DXCONTAINERPSVINFO_H
1111

12+
#include "llvm/ADT/ArrayRef.h"
1213
#include "llvm/ADT/SmallVector.h"
1314
#include "llvm/ADT/StringRef.h"
1415
#include "llvm/BinaryFormat/DXContainer.h"
16+
#include "llvm/MC/StringTableBuilder.h"
1517
#include "llvm/TargetParser/Triple.h"
1618

1719
#include <array>
@@ -45,8 +47,9 @@ struct PSVSignatureElement {
4547
// modifiable format, and can be used to serialize the data back into valid PSV
4648
// RuntimeInfo.
4749
struct PSVRuntimeInfo {
50+
PSVRuntimeInfo() : DXConStrTabBuilder(StringTableBuilder::DXContainer) {}
4851
bool IsFinalized = false;
49-
dxbc::PSV::v2::RuntimeInfo BaseData;
52+
dxbc::PSV::v3::RuntimeInfo BaseData;
5053
SmallVector<dxbc::PSV::v2::ResourceBindInfo> Resources;
5154
SmallVector<PSVSignatureElement> InputElements;
5255
SmallVector<PSVSignatureElement> OutputElements;
@@ -64,26 +67,20 @@ struct PSVRuntimeInfo {
6467
std::array<SmallVector<uint32_t>, 4> InputOutputMap;
6568
SmallVector<uint32_t> InputPatchMap;
6669
SmallVector<uint32_t> PatchOutputMap;
70+
llvm::StringRef EntryName;
6771

6872
// Serialize PSVInfo into the provided raw_ostream. The version field
6973
// specifies the data version to encode, the default value specifies encoding
7074
// the highest supported version.
7175
void write(raw_ostream &OS,
7276
uint32_t Version = std::numeric_limits<uint32_t>::max()) const;
7377

74-
void finalize(Triple::EnvironmentType Stage) {
75-
IsFinalized = true;
76-
BaseData.SigInputElements = static_cast<uint32_t>(InputElements.size());
77-
BaseData.SigOutputElements = static_cast<uint32_t>(OutputElements.size());
78-
BaseData.SigPatchOrPrimElements =
79-
static_cast<uint32_t>(PatchOrPrimElements.size());
80-
if (!sys::IsBigEndianHost)
81-
return;
82-
BaseData.swapBytes();
83-
BaseData.swapBytes(Stage);
84-
for (auto &Res : Resources)
85-
Res.swapBytes();
86-
}
78+
void finalize(Triple::EnvironmentType Stage);
79+
80+
private:
81+
SmallVector<uint32_t, 64> IndexBuffer;
82+
SmallVector<llvm::dxbc::PSV::v0::SignatureElement, 32> SignatureElements;
83+
StringTableBuilder DXConStrTabBuilder;
8784
};
8885

8986
class Signature {

llvm/include/llvm/MC/StringTableBuilder.h

+2-6
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,8 @@ class StringTableBuilder {
7474
/// Check if a string is contained in the string table. Since this class
7575
/// doesn't store the string values, this function can be used to check if
7676
/// storage needs to be done prior to adding the string.
77-
bool contains(StringRef S) const {
78-
return contains(CachedHashStringRef(S));
79-
}
80-
bool contains(CachedHashStringRef S) const {
81-
return StringIndexMap.count(S);
82-
}
77+
bool contains(StringRef S) const { return contains(CachedHashStringRef(S)); }
78+
bool contains(CachedHashStringRef S) const { return StringIndexMap.count(S); }
8379

8480
size_t getSize() const { return Size; }
8581
void clear();

llvm/include/llvm/Object/DXContainer.h

+12-4
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ class PSVRuntimeInfo {
125125
uint32_t Size;
126126
using InfoStruct =
127127
std::variant<std::monostate, dxbc::PSV::v0::RuntimeInfo,
128-
dxbc::PSV::v1::RuntimeInfo, dxbc::PSV::v2::RuntimeInfo>;
128+
dxbc::PSV::v1::RuntimeInfo, dxbc::PSV::v2::RuntimeInfo,
129+
dxbc::PSV::v3::RuntimeInfo>;
129130
InfoStruct BasicInfo;
130131
ResourceArray Resources;
131132
StringRef StringTable;
@@ -151,16 +152,23 @@ class PSVRuntimeInfo {
151152
ResourceArray getResources() const { return Resources; }
152153

153154
uint32_t getVersion() const {
154-
return Size >= sizeof(dxbc::PSV::v2::RuntimeInfo)
155-
? 2
156-
: (Size >= sizeof(dxbc::PSV::v1::RuntimeInfo) ? 1 : 0);
155+
return Size >= sizeof(dxbc::PSV::v3::RuntimeInfo)
156+
? 3
157+
: (Size >= sizeof(dxbc::PSV::v2::RuntimeInfo) ? 2
158+
: (Size >= sizeof(dxbc::PSV::v1::RuntimeInfo)) ? 1
159+
: 0);
157160
}
158161

159162
uint32_t getResourceStride() const { return Resources.Stride; }
160163

161164
const InfoStruct &getInfo() const { return BasicInfo; }
162165

163166
template <typename T> const T *getInfoAs() const {
167+
if (const auto *P = std::get_if<dxbc::PSV::v3::RuntimeInfo>(&BasicInfo))
168+
return static_cast<const T *>(P);
169+
if (std::is_same<T, dxbc::PSV::v3::RuntimeInfo>::value)
170+
return nullptr;
171+
164172
if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo))
165173
return static_cast<const T *>(P);
166174
if (std::is_same<T, dxbc::PSV::v2::RuntimeInfo>::value)

llvm/include/llvm/ObjectYAML/DXContainerYAML.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ struct PSVInfo {
107107
// the format.
108108
uint32_t Version;
109109

110-
dxbc::PSV::v2::RuntimeInfo Info;
110+
dxbc::PSV::v3::RuntimeInfo Info;
111111
uint32_t ResourceStride;
112112
SmallVector<ResourceBindInfo> Resources;
113113
SmallVector<SignatureElement> SigInputElements;
@@ -121,12 +121,15 @@ struct PSVInfo {
121121
MaskVector InputPatchMap;
122122
MaskVector PatchOutputMap;
123123

124+
StringRef EntryName;
125+
124126
void mapInfoForVersion(yaml::IO &IO);
125127

126128
PSVInfo();
127129
PSVInfo(const dxbc::PSV::v0::RuntimeInfo *P, uint16_t Stage);
128130
PSVInfo(const dxbc::PSV::v1::RuntimeInfo *P);
129131
PSVInfo(const dxbc::PSV::v2::RuntimeInfo *P);
132+
PSVInfo(const dxbc::PSV::v3::RuntimeInfo *P, StringRef StringTable);
130133
};
131134

132135
struct SignatureParameter {

llvm/lib/MC/DXContainerPSVInfo.cpp

+50-25
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,18 @@ void PSVRuntimeInfo::write(raw_ostream &OS, uint32_t Version) const {
8181
BindingSize = sizeof(dxbc::PSV::v0::ResourceBindInfo);
8282
break;
8383
case 2:
84-
default:
8584
InfoSize = sizeof(dxbc::PSV::v2::RuntimeInfo);
8685
BindingSize = sizeof(dxbc::PSV::v2::ResourceBindInfo);
86+
break;
87+
case 3:
88+
default:
89+
InfoSize = sizeof(dxbc::PSV::v3::RuntimeInfo);
90+
BindingSize = sizeof(dxbc::PSV::v2::ResourceBindInfo);
8791
}
88-
// Write the size of the info.
8992

93+
// Write the size of the info.
9094
support::endian::write(OS, InfoSize, llvm::endianness::little);
95+
9196
// Write the info itself.
9297
OS.write(reinterpret_cast<const char *>(&BaseData), InfoSize);
9398

@@ -104,32 +109,12 @@ void PSVRuntimeInfo::write(raw_ostream &OS, uint32_t Version) const {
104109
if (Version == 0)
105110
return;
106111

107-
StringTableBuilder StrTabBuilder((StringTableBuilder::DXContainer));
108-
SmallVector<uint32_t, 64> IndexBuffer;
109-
SmallVector<v0::SignatureElement, 32> SignatureElements;
110-
SmallVector<StringRef, 32> SemanticNames;
111-
112-
ProcessElementList(StrTabBuilder, IndexBuffer, SignatureElements,
113-
SemanticNames, InputElements);
114-
ProcessElementList(StrTabBuilder, IndexBuffer, SignatureElements,
115-
SemanticNames, OutputElements);
116-
ProcessElementList(StrTabBuilder, IndexBuffer, SignatureElements,
117-
SemanticNames, PatchOrPrimElements);
118-
119-
StrTabBuilder.finalize();
120-
for (auto ElAndName : zip(SignatureElements, SemanticNames)) {
121-
v0::SignatureElement &El = std::get<0>(ElAndName);
122-
StringRef Name = std::get<1>(ElAndName);
123-
El.NameOffset = static_cast<uint32_t>(StrTabBuilder.getOffset(Name));
124-
if (sys::IsBigEndianHost)
125-
El.swapBytes();
126-
}
127-
128-
support::endian::write(OS, static_cast<uint32_t>(StrTabBuilder.getSize()),
112+
support::endian::write(OS,
113+
static_cast<uint32_t>(DXConStrTabBuilder.getSize()),
129114
llvm::endianness::little);
130115

131116
// Write the string table.
132-
StrTabBuilder.write(OS);
117+
DXConStrTabBuilder.write(OS);
133118

134119
// Write the index table size, then table.
135120
support::endian::write(OS, static_cast<uint32_t>(IndexBuffer.size()),
@@ -162,6 +147,46 @@ void PSVRuntimeInfo::write(raw_ostream &OS, uint32_t Version) const {
162147
llvm::endianness::little);
163148
}
164149

150+
void PSVRuntimeInfo::finalize(Triple::EnvironmentType Stage) {
151+
IsFinalized = true;
152+
BaseData.SigInputElements = static_cast<uint32_t>(InputElements.size());
153+
BaseData.SigOutputElements = static_cast<uint32_t>(OutputElements.size());
154+
BaseData.SigPatchOrPrimElements =
155+
static_cast<uint32_t>(PatchOrPrimElements.size());
156+
157+
SmallVector<StringRef, 32> SemanticNames;
158+
159+
// Build a string table and set associated offsets to be written when
160+
// write() is called
161+
ProcessElementList(DXConStrTabBuilder, IndexBuffer, SignatureElements,
162+
SemanticNames, InputElements);
163+
ProcessElementList(DXConStrTabBuilder, IndexBuffer, SignatureElements,
164+
SemanticNames, OutputElements);
165+
ProcessElementList(DXConStrTabBuilder, IndexBuffer, SignatureElements,
166+
SemanticNames, PatchOrPrimElements);
167+
168+
DXConStrTabBuilder.add(EntryName);
169+
170+
DXConStrTabBuilder.finalize();
171+
for (auto ElAndName : zip(SignatureElements, SemanticNames)) {
172+
llvm::dxbc::PSV::v0::SignatureElement &El = std::get<0>(ElAndName);
173+
StringRef Name = std::get<1>(ElAndName);
174+
El.NameOffset = static_cast<uint32_t>(DXConStrTabBuilder.getOffset(Name));
175+
if (sys::IsBigEndianHost)
176+
El.swapBytes();
177+
}
178+
179+
BaseData.EntryNameOffset =
180+
static_cast<uint32_t>(DXConStrTabBuilder.getOffset(EntryName));
181+
182+
if (!sys::IsBigEndianHost)
183+
return;
184+
BaseData.swapBytes();
185+
BaseData.swapBytes(Stage);
186+
for (auto &Res : Resources)
187+
Res.swapBytes();
188+
}
189+
165190
void Signature::write(raw_ostream &OS) {
166191
SmallVector<dxbc::ProgramSignatureElement> SigParams;
167192
SigParams.reserve(Params.size());

llvm/lib/Object/DXContainer.cpp

+14-1
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,14 @@ Error DirectX::PSVRuntimeInfo::parse(uint16_t ShaderKind) {
247247
const uint32_t PSVVersion = getVersion();
248248

249249
// Detect the PSVVersion by looking at the size field.
250-
if (PSVVersion == 2) {
250+
if (PSVVersion == 3) {
251+
v3::RuntimeInfo Info;
252+
if (Error Err = readStruct(PSVInfoData, Current, Info))
253+
return Err;
254+
if (sys::IsBigEndianHost)
255+
Info.swapBytes(ShaderStage);
256+
BasicInfo = Info;
257+
} else if (PSVVersion == 2) {
251258
v2::RuntimeInfo Info;
252259
if (Error Err = readStruct(PSVInfoData, Current, Info))
253260
return Err;
@@ -425,6 +432,8 @@ Error DirectX::PSVRuntimeInfo::parse(uint16_t ShaderKind) {
425432
}
426433

427434
uint8_t DirectX::PSVRuntimeInfo::getSigInputCount() const {
435+
if (const auto *P = std::get_if<dxbc::PSV::v3::RuntimeInfo>(&BasicInfo))
436+
return P->SigInputElements;
428437
if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo))
429438
return P->SigInputElements;
430439
if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(&BasicInfo))
@@ -433,6 +442,8 @@ uint8_t DirectX::PSVRuntimeInfo::getSigInputCount() const {
433442
}
434443

435444
uint8_t DirectX::PSVRuntimeInfo::getSigOutputCount() const {
445+
if (const auto *P = std::get_if<dxbc::PSV::v3::RuntimeInfo>(&BasicInfo))
446+
return P->SigOutputElements;
436447
if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo))
437448
return P->SigOutputElements;
438449
if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(&BasicInfo))
@@ -441,6 +452,8 @@ uint8_t DirectX::PSVRuntimeInfo::getSigOutputCount() const {
441452
}
442453

443454
uint8_t DirectX::PSVRuntimeInfo::getSigPatchOrPrimCount() const {
455+
if (const auto *P = std::get_if<dxbc::PSV::v3::RuntimeInfo>(&BasicInfo))
456+
return P->SigPatchOrPrimElements;
444457
if (const auto *P = std::get_if<dxbc::PSV::v2::RuntimeInfo>(&BasicInfo))
445458
return P->SigPatchOrPrimElements;
446459
if (const auto *P = std::get_if<dxbc::PSV::v1::RuntimeInfo>(&BasicInfo))

llvm/lib/ObjectYAML/DXContainerEmitter.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,9 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
198198
if (!P.Info.has_value())
199199
continue;
200200
mcdxbc::PSVRuntimeInfo PSV;
201-
memcpy(&PSV.BaseData, &P.Info->Info, sizeof(dxbc::PSV::v2::RuntimeInfo));
201+
memcpy(&PSV.BaseData, &P.Info->Info, sizeof(dxbc::PSV::v3::RuntimeInfo));
202202
PSV.Resources = P.Info->Resources;
203+
PSV.EntryName = P.Info->EntryName;
203204

204205
for (auto El : P.Info->SigInputElements)
205206
PSV.InputElements.push_back(mcdxbc::PSVSignatureElement{

llvm/lib/ObjectYAML/DXContainerYAML.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,16 @@ DXContainerYAML::PSVInfo::PSVInfo(const dxbc::PSV::v2::RuntimeInfo *P)
7474
memcpy(&Info, P, sizeof(dxbc::PSV::v2::RuntimeInfo));
7575
}
7676

77+
DXContainerYAML::PSVInfo::PSVInfo(const dxbc::PSV::v3::RuntimeInfo *P,
78+
StringRef StringTable)
79+
: Version(3),
80+
EntryName(StringTable.substr(P->EntryNameOffset,
81+
StringTable.find('\0', P->EntryNameOffset) -
82+
P->EntryNameOffset)) {
83+
memset(&Info, 0, sizeof(Info));
84+
memcpy(&Info, P, sizeof(dxbc::PSV::v3::RuntimeInfo));
85+
}
86+
7787
namespace yaml {
7888

7989
void MappingTraits<DXContainerYAML::VersionTuple>::mapping(
@@ -348,6 +358,11 @@ void DXContainerYAML::PSVInfo::mapInfoForVersion(yaml::IO &IO) {
348358
IO.mapRequired("NumThreadsX", Info.NumThreadsX);
349359
IO.mapRequired("NumThreadsY", Info.NumThreadsY);
350360
IO.mapRequired("NumThreadsZ", Info.NumThreadsZ);
361+
362+
if (Version == 2)
363+
return;
364+
365+
IO.mapRequired("EntryName", EntryName);
351366
}
352367

353368
} // namespace llvm

0 commit comments

Comments
 (0)