Skip to content

Commit f9d7e3e

Browse files
committed
[SPIR-V] Fix structurizer issues
The "topological" sorting was behaving incorrectly in some cases: the exit of a loop could have a lower rank than a node in the loop. This causes issues when structurizing some patterns, and also codegen issues as we could generate BBs in the incorrect order in regard to the SPIR-V spec. Fixing this ordering alone broke other parts of the structurizer, which by luck worked. Had to fix those. Added more test cases, especially to test basic patterns. Signed-off-by: Nathan Gauër <[email protected]>
1 parent 817fd98 commit f9d7e3e

39 files changed

+1048
-606
lines changed

llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,13 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
130130
assert(false && "Unhandled terminator type.");
131131
}
132132

133+
AllocaInst *CreateVariable(Function &F, Type *Type,
134+
BasicBlock::iterator Position) {
135+
const DataLayout &DL = F.getDataLayout();
136+
return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg",
137+
Position);
138+
}
139+
133140
// Run the pass on the given convergence region, ignoring the sub-regions.
134141
// Returns true if the CFG changed, false otherwise.
135142
bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
@@ -152,6 +159,9 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
152159
auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
153160
IRBuilder<> Builder(NewExitTarget);
154161

162+
AllocaInst *Variable = CreateVariable(*F, Builder.getInt32Ty(),
163+
F->begin()->getFirstInsertionPt());
164+
155165
// CodeGen output needs to be stable. Using the set as-is would order
156166
// the targets differently depending on the allocation pattern.
157167
// Sorting per basic-block ordering in the function.
@@ -176,18 +186,16 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
176186
std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
177187
for (auto Exit : SortedExits) {
178188
llvm::Value *Value = createExitVariable(Exit, TargetToValue);
189+
IRBuilder<> B2(Exit);
190+
B2.SetInsertPoint(Exit->getFirstInsertionPt());
191+
B2.CreateStore(Value, Variable);
179192
ExitToVariable.emplace_back(std::make_pair(Exit, Value));
180193
}
181194

182-
// Gather the correct value depending on the exit we came from.
183-
llvm::PHINode *node =
184-
Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size());
185-
for (auto [BB, Value] : ExitToVariable) {
186-
node->addIncoming(Value, BB);
187-
}
195+
llvm::Value *Load = Builder.CreateLoad(Builder.getInt32Ty(), Variable);
188196

189197
// Creating the switch to jump to the correct exit target.
190-
llvm::SwitchInst *Sw = Builder.CreateSwitch(node, SortedExitTargets[0],
198+
llvm::SwitchInst *Sw = Builder.CreateSwitch(Load, SortedExitTargets[0],
191199
SortedExitTargets.size() - 1);
192200
for (size_t i = 1; i < SortedExitTargets.size(); i++) {
193201
BasicBlock *BB = SortedExitTargets[i];

llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp

Lines changed: 63 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ BasicBlock *getExitFor(const ConvergenceRegion *CR) {
8787
// Returns the merge block designated by I if I is a merge instruction, nullptr
8888
// otherwise.
8989
BasicBlock *getDesignatedMergeBlock(Instruction *I) {
90+
if (I == nullptr)
91+
return nullptr;
9092
IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
9193
if (II == nullptr)
9294
return nullptr;
@@ -102,6 +104,8 @@ BasicBlock *getDesignatedMergeBlock(Instruction *I) {
102104
// Returns the continue block designated by I if I is an OpLoopMerge, nullptr
103105
// otherwise.
104106
BasicBlock *getDesignatedContinueBlock(Instruction *I) {
107+
if (I == nullptr)
108+
return nullptr;
105109
IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
106110
if (II == nullptr)
107111
return nullptr;
@@ -447,55 +451,66 @@ class SPIRVStructurizer : public FunctionPass {
447451
// clang-format on
448452
std::vector<Edge>
449453
createAliasBlocksForComplexEdges(std::vector<Edge> Edges) {
450-
std::unordered_map<BasicBlock *, BasicBlock *> Seen;
454+
std::unordered_set<BasicBlock *> Seen;
451455
std::vector<Edge> Output;
452456
Output.reserve(Edges.size());
453457

454458
for (auto &[Src, Dst] : Edges) {
455-
auto [iterator, inserted] = Seen.insert({Src, Dst});
456-
if (inserted) {
457-
Output.emplace_back(Src, Dst);
458-
continue;
459+
auto [iterator, inserted] = Seen.insert(Src);
460+
if (!inserted) {
461+
// Src already a source node. Cannot have 2 edges from A to B.
462+
// Creating alias source block.
463+
BasicBlock *NewSrc =
464+
BasicBlock::Create(F.getContext(), "new.src", &F);
465+
replaceBranchTargets(Src, Dst, NewSrc);
466+
// replacePhiTargets(Dst, Src, NewSrc);
467+
IRBuilder<> Builder(NewSrc);
468+
Builder.CreateBr(Dst);
469+
Src = NewSrc;
459470
}
460471

461-
// The exact same edge was already seen. Ignoring.
462-
if (iterator->second == Dst)
472+
// Dst has a PHI node. We also need to create an alias output block.
473+
if (!hasPhiNode(Dst)) {
474+
Output.emplace_back(Src, Dst);
463475
continue;
476+
}
464477

465-
// The same Src block branches to 2 distinct blocks. This will be an
466-
// issue for the generated OpPhi. Creating alias block.
478+
// Dst already targeted AND contains a PHI node. We'll need alias
479+
// blocks.
467480
BasicBlock *NewSrc =
468-
BasicBlock::Create(F.getContext(), "new.exit.src", &F);
481+
BasicBlock::Create(F.getContext(), "phi.alias", &F);
469482
replaceBranchTargets(Src, Dst, NewSrc);
470-
replacePhiTargets(Dst, Src, NewSrc);
471-
483+
// replacePhiTargets(Dst, Src, NewSrc);
472484
IRBuilder<> Builder(NewSrc);
473485
Builder.CreateBr(Dst);
474-
475-
Seen.emplace(NewSrc, Dst);
476-
Output.emplace_back(NewSrc, Dst);
486+
Output.emplace_back(Src, NewSrc);
477487
}
478488

479489
return Output;
480490
}
481491

492+
AllocaInst *CreateVariable(Function &F, Type *Type,
493+
BasicBlock::iterator Position) {
494+
const DataLayout &DL = F.getDataLayout();
495+
return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg",
496+
Position);
497+
}
498+
482499
// Given a construct defined by |Header|, and a list of exiting edges
483500
// |Edges|, creates a new single exit node, fixing up those edges.
484501
BasicBlock *createSingleExitNode(BasicBlock *Header,
485502
std::vector<Edge> &Edges) {
486-
auto NewExit = BasicBlock::Create(F.getContext(), "new.exit", &F);
487-
IRBuilder<> ExitBuilder(NewExit);
488-
489-
std::vector<BasicBlock *> Dsts;
490-
std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;
491-
492503
// Given 2 edges: Src1 -> Dst, Src2 -> Dst:
493504
// If Dst has an PHI node, and Src1 and Src2 are both operands, both Src1
494505
// and Src2 cannot be hidden by NewExit. Create 2 new nodes: Alias1,
495506
// Alias2 to which NewExit will branch before going to Dst. Then, patchup
496507
// Dst PHI node to look for Alias1 and Alias2.
497508
std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges);
498509

510+
std::vector<BasicBlock *> Dsts;
511+
std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;
512+
auto NewExit = BasicBlock::Create(F.getContext(), "new.exit", &F);
513+
IRBuilder<> ExitBuilder(NewExit);
499514
for (auto &[Src, Dst] : FixedEdges) {
500515
if (DstToIndex.count(Dst) != 0)
501516
continue;
@@ -506,33 +521,38 @@ class SPIRVStructurizer : public FunctionPass {
506521
if (Dsts.size() == 1) {
507522
for (auto &[Src, Dst] : FixedEdges) {
508523
replaceBranchTargets(Src, Dst, NewExit);
509-
replacePhiTargets(Dst, Src, NewExit);
524+
// replacePhiTargets(Dst, Src, NewExit);
510525
}
511526
ExitBuilder.CreateBr(Dsts[0]);
512527
return NewExit;
513528
}
514529

515-
PHINode *PhiNode =
516-
ExitBuilder.CreatePHI(ExitBuilder.getInt32Ty(), FixedEdges.size());
530+
AllocaInst *Variable = CreateVariable(F, ExitBuilder.getInt32Ty(),
531+
F.begin()->getFirstInsertionPt());
532+
// PHINode *PhiNode = ExitBuilder.CreatePHI(ExitBuilder.getInt32Ty(),
533+
// FixedEdges.size());
517534

518535
for (auto &[Src, Dst] : FixedEdges) {
519-
PhiNode->addIncoming(DstToIndex[Dst], Src);
536+
IRBuilder<> B2(Src);
537+
B2.SetInsertPoint(Src->getFirstInsertionPt());
538+
B2.CreateStore(DstToIndex[Dst], Variable);
520539
replaceBranchTargets(Src, Dst, NewExit);
521-
replacePhiTargets(Dst, Src, NewExit);
522540
}
523541

542+
llvm::Value *Load =
543+
ExitBuilder.CreateLoad(ExitBuilder.getInt32Ty(), Variable);
544+
524545
// If we can avoid an OpSwitch, generate an OpBranch. Reason is some
525546
// OpBranch are allowed to exist without a new OpSelectionMerge if one of
526547
// the branch is the parent's merge node, while OpSwitches are not.
527548
if (Dsts.size() == 2) {
528-
Value *Condition = ExitBuilder.CreateCmp(CmpInst::ICMP_EQ,
529-
DstToIndex[Dsts[0]], PhiNode);
549+
Value *Condition =
550+
ExitBuilder.CreateCmp(CmpInst::ICMP_EQ, DstToIndex[Dsts[0]], Load);
530551
ExitBuilder.CreateCondBr(Condition, Dsts[0], Dsts[1]);
531552
return NewExit;
532553
}
533554

534-
SwitchInst *Sw =
535-
ExitBuilder.CreateSwitch(PhiNode, Dsts[0], Dsts.size() - 1);
555+
SwitchInst *Sw = ExitBuilder.CreateSwitch(Load, Dsts[0], Dsts.size() - 1);
536556
for (auto It = Dsts.begin() + 1; It != Dsts.end(); ++It) {
537557
Sw->addCase(DstToIndex[*It], *It);
538558
}
@@ -576,7 +596,7 @@ class SPIRVStructurizer : public FunctionPass {
576596

577597
// Creates a new basic block in F with a single OpUnreachable instruction.
578598
BasicBlock *CreateUnreachable(Function &F) {
579-
BasicBlock *BB = BasicBlock::Create(F.getContext(), "new.exit", &F);
599+
BasicBlock *BB = BasicBlock::Create(F.getContext(), "unreachable", &F);
580600
IRBuilder<> Builder(BB);
581601
Builder.CreateUnreachable();
582602
return BB;
@@ -1127,6 +1147,18 @@ class SPIRVStructurizer : public FunctionPass {
11271147
continue;
11281148

11291149
Modified = true;
1150+
1151+
if (Merge == nullptr) {
1152+
Merge = *successors(Header).begin();
1153+
IRBuilder<> Builder(Header);
1154+
Builder.SetInsertPoint(Header->getTerminator());
1155+
1156+
auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
1157+
SmallVector<Value *, 1> Args = {MergeAddress};
1158+
Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
1159+
continue;
1160+
}
1161+
11301162
Instruction *SplitInstruction = Merge->getTerminator();
11311163
if (isMergeInstruction(SplitInstruction->getPrevNode()))
11321164
SplitInstruction = SplitInstruction->getPrevNode();

llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "llvm/MC/TargetRegistry.h"
3030
#include "llvm/Pass.h"
3131
#include "llvm/Target/TargetOptions.h"
32+
#include "llvm/Transforms/Scalar/Reg2Mem.h"
3233
#include "llvm/Transforms/Utils.h"
3334
#include <optional>
3435

@@ -162,20 +163,30 @@ void SPIRVPassConfig::addIRPasses() {
162163
TargetPassConfig::addIRPasses();
163164

164165
if (TM.getSubtargetImpl()->isVulkanEnv()) {
166+
addPass(createRegToMemWrapperPass());
167+
165168
// 1. Simplify loop for subsequent transformations. After this steps, loops
166169
// have the following properties:
167170
// - loops have a single entry edge (pre-header to loop header).
168171
// - all loop exits are dominated by the loop pre-header.
169172
// - loops have a single back-edge.
170173
addPass(createLoopSimplifyPass());
171174

172-
// 2. Merge the convergence region exit nodes into one. After this step,
175+
// 2. Removes registers whose lifetime spans across basic blocks. Also
176+
// removes phi nodes. This will greatly simplify the next steps.
177+
addPass(createRegToMemWrapperPass());
178+
179+
// 3. Merge the convergence region exit nodes into one. After this step,
173180
// regions are single-entry, single-exit. This will help determine the
174181
// correct merge block.
175182
addPass(createSPIRVMergeRegionExitTargetsPass());
176183

177-
// 3. Structurize.
184+
// 4. Structurize.
178185
addPass(createSPIRVStructurizerPass());
186+
187+
// 5. Reduce the amount of variables required by pushing some operations
188+
// back to virtual registers.
189+
addPass(createPromoteMemoryToRegisterPass());
179190
}
180191

181192
addPass(createSPIRVRegularizerPass());

0 commit comments

Comments
 (0)