Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_llvm_target(SPIRVCodeGen
SPIRVInstrInfo.cpp
SPIRVInstructionSelector.cpp
SPIRVStripConvergentIntrinsics.cpp
SPIRVMergeRegionExitTargets.cpp
SPIRVISelLowering.cpp
SPIRVLegalizerInfo.cpp
SPIRVMCInstLower.cpp
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/SPIRV.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class InstructionSelector;
class RegisterBankInfo;

ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
FunctionPass *createSPIRVMergeRegionExitTargetsPass();
FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
FunctionPass *createSPIRVRegularizerPass();
FunctionPass *createSPIRVPreLegalizerPass();
Expand Down
23 changes: 23 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,16 @@ class SPIRVEmitIntrinsics
ModulePass::getAnalysisUsage(AU);
}
};

bool isConvergenceIntrinsic(const Instruction *I) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wasn't necessary, because now we can simply do isa<ConvergenceControlInst>(I)

const auto *II = dyn_cast<IntrinsicInst>(I);
if (!II)
return false;

return II->getIntrinsicID() == Intrinsic::experimental_convergence_entry ||
II->getIntrinsicID() == Intrinsic::experimental_convergence_loop ||
II->getIntrinsicID() == Intrinsic::experimental_convergence_anchor;
}
} // namespace

char SPIRVEmitIntrinsics::ID = 0;
Expand Down Expand Up @@ -1074,6 +1084,10 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV,

void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
IRBuilder<> &B) {
// Don't assign types to LLVM tokens.
if (isConvergenceIntrinsic(I))
return;

reportFatalOnTokenType(I);
if (!isPointerTy(I->getType()) || !requireAssignType(I) ||
isa<BitCastInst>(I))
Expand All @@ -1092,6 +1106,10 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,

void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
IRBuilder<> &B) {
// Don't assign types to LLVM tokens.
if (isConvergenceIntrinsic(I))
return;

reportFatalOnTokenType(I);
Type *Ty = I->getType();
if (!Ty->isVoidTy() && !isPointerTy(Ty) && requireAssignType(I)) {
Expand Down Expand Up @@ -1319,6 +1337,11 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
I = visit(*I);
if (!I)
continue;

// Don't emit intrinsics for convergence operations.
if (isConvergenceIntrinsic(I))
continue;

processInstrAfterVisit(I, B);
}

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def OpFwidthCoarse: UnOp<"OpFwidthCoarse", 215>;
def OpPhi: Op<245, (outs ID:$res), (ins TYPE:$type, ID:$var0, ID:$block0, variable_ops),
"$res = OpPhi $type $var0 $block0">;
def OpLoopMerge: Op<246, (outs), (ins ID:$merge, ID:$continue, LoopControl:$lc, variable_ops),
"OpLoopMerge $merge $merge $continue $lc">;
"OpLoopMerge $merge $continue $lc">;
def OpSelectionMerge: Op<247, (outs), (ins ID:$merge, SelectionControl:$sc),
"OpSelectionMerge $merge $sc">;
def OpLabel: Op<248, (outs ID:$label), (ins), "$label = OpLabel">;
Expand Down
290 changes: 290 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
//===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Merge the multiple exit targets of a convergence region into a single block.
// Each exit target will be assigned a constant value, and a phi node + switch
// will allow the new exit target to re-route to the correct basic block.
//
//===----------------------------------------------------------------------===//

#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/CodeGen/IntrinsicLowering.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"

using namespace llvm;

namespace llvm {
// Forward declaration of the registration hook for the pass defined in this
// file. The pass constructor calls it with the global PassRegistry.
// NOTE(review): presumably the definition is generated by the
// INITIALIZE_PASS_* macros at the bottom of this file -- confirm.
void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry &);
} // namespace llvm

namespace llvm {

class SPIRVMergeRegionExitTargets : public FunctionPass {
public:
static char ID;

/// Constructs the pass and invokes its registration hook on the global
/// PassRegistry.
SPIRVMergeRegionExitTargets() : FunctionPass(ID) {
  initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry());
}

// Gather all the successors of |BB|.
// Returns the set of distinct blocks the terminator of |BB| can jump to; a
// return terminator yields an empty set.
// This function asserts if the terminator is neither a branch, a switch nor a
// return.
std::unordered_set<BasicBlock *> gatherSuccessors(BasicBlock *BB) {
  std::unordered_set<BasicBlock *> output;
  auto *T = BB->getTerminator();

  if (auto *BI = dyn_cast<BranchInst>(T)) {
    output.insert(BI->getSuccessor(0));
    if (BI->isConditional())
      output.insert(BI->getSuccessor(1));
    return output;
  }

  if (auto *SI = dyn_cast<SwitchInst>(T)) {
    output.insert(SI->getDefaultDest());
    for (auto &Case : SI->cases())
      output.insert(Case.getCaseSuccessor());
    return output;
  }

  // Only the kind of the terminator matters here, so test it with isa<>
  // instead of binding an unused dyn_cast<> result.
  if (isa<ReturnInst>(T))
    return output;

  assert(false && "Unhandled terminator type.");
  return output;
}

/// Create a value in BB set to the value associated with the branch the block
/// terminator will take.
///
/// |TargetToValue| maps a branch target to the constant identifying it.
/// Returns:
///  - nullptr if |BB| ends with a return, or if none of its branch targets is
///    a key of |TargetToValue|;
///  - the constant of the only mapped target when a single target is mapped;
///  - a select on the branch condition, inserted right before the terminator,
///    when both targets of a conditional branch are mapped.
/// Switch terminators are not supported yet.
llvm::Value *createExitVariable(
    BasicBlock *BB,
    const std::unordered_map<BasicBlock *, ConstantInt *> &TargetToValue) {
  auto *T = BB->getTerminator();
  if (isa<ReturnInst>(T))
    return nullptr;

  IRBuilder<> Builder(BB);
  Builder.SetInsertPoint(T);

  if (auto *BI = dyn_cast<BranchInst>(T)) {
    BasicBlock *LHSTarget = BI->getSuccessor(0);
    // An unconditional branch has no second target. The nullptr key is then
    // simply absent from the map and RHS ends up null.
    BasicBlock *RHSTarget =
        BI->isConditional() ? BI->getSuccessor(1) : nullptr;

    Value *LHS = TargetToValue.count(LHSTarget) != 0
                     ? TargetToValue.at(LHSTarget)
                     : nullptr;
    Value *RHS = TargetToValue.count(RHSTarget) != 0
                     ? TargetToValue.at(RHSTarget)
                     : nullptr;

    if (LHS == nullptr || RHS == nullptr)
      return LHS == nullptr ? RHS : LHS;
    return Builder.CreateSelect(BI->getCondition(), LHS, RHS);
  }

  // TODO: add support for switch cases.
  // llvm_unreachable never returns, so control can no longer fall off the end
  // of this non-void function when assertions are compiled out.
  llvm_unreachable("Unhandled terminator type.");
}

/// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
/// Blocks ending with a return are left untouched. Only branch and switch
/// terminators are rewritten; any other terminator kind asserts.
/// |ToReplace| is taken by reference so the set is not copied on every call.
void replaceBranchTargets(BasicBlock *BB,
                          const std::unordered_set<BasicBlock *> &ToReplace,
                          BasicBlock *NewTarget) {
  auto *T = BB->getTerminator();
  if (isa<ReturnInst>(T))
    return;

  if (auto *BI = dyn_cast<BranchInst>(T)) {
    for (unsigned i = 0, e = BI->getNumSuccessors(); i < e; ++i) {
      if (ToReplace.count(BI->getSuccessor(i)) != 0)
        BI->setSuccessor(i, NewTarget);
    }
    return;
  }

  if (auto *SI = dyn_cast<SwitchInst>(T)) {
    for (unsigned i = 0, e = SI->getNumSuccessors(); i < e; ++i) {
      if (ToReplace.count(SI->getSuccessor(i)) != 0)
        SI->setSuccessor(i, NewTarget);
    }
    return;
  }

  assert(false && "Unhandled terminator type.");
}

// Run the pass on the given convergence region, ignoring the sub-regions.
// Returns true if the CFG changed, false otherwise.
//
// When the region's exits branch to more than one block outside the region, a
// new block "new.exit" is appended to the function. Every exit is redirected
// to it, and it dispatches to the original targets through a phi + switch.
// |LI| is currently unused.
bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
                                     const SPIRV::ConvergenceRegion *CR) {
  // Gather all the exit targets for this region.
  std::unordered_set<BasicBlock *> ExitTargets;
  for (BasicBlock *Exit : CR->Exits) {
    for (BasicBlock *Target : gatherSuccessors(Exit)) {
      if (CR->Blocks.count(Target) == 0)
        ExitTargets.insert(Target);
    }
  }

  // If we have zero or one exit target, nothing to do.
  if (ExitTargets.size() <= 1)
    return false;

  // Create the new single exit target.
  auto *F = CR->Entry->getParent();
  auto *NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
  IRBuilder<> Builder(NewExitTarget);

  // CodeGen output needs to be stable. Using the sets as-is would order
  // the targets differently depending on the allocation pattern.
  // Sorting per basic-block ordering in the function.
  std::vector<BasicBlock *> SortedExitTargets;
  std::vector<BasicBlock *> SortedExits;
  for (BasicBlock &BB : *F) {
    if (ExitTargets.count(&BB) != 0)
      SortedExitTargets.push_back(&BB);
    if (CR->Exits.count(&BB) != 0)
      SortedExits.push_back(&BB);
  }

  // Creating one constant per distinct exit target. The switch below routes
  // each constant to the matching target.
  std::unordered_map<BasicBlock *, ConstantInt *> TargetToValue;
  for (size_t i = 0; i < SortedExitTargets.size(); i++)
    TargetToValue.emplace(SortedExitTargets[i], Builder.getInt32(i));

  // Creating one variable per exit node, set to the constant matching the
  // targeted external block.
  std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
  for (BasicBlock *Exit : SortedExits) {
    llvm::Value *ExitValue = createExitVariable(Exit, TargetToValue);
    // A null value means none of this exit's successors is an exit target, so
    // no edge of it gets redirected to the new block below: it will not be a
    // predecessor and must not get a phi entry.
    if (ExitValue == nullptr)
      continue;
    ExitToVariable.emplace_back(Exit, ExitValue);
  }

  // Gather the correct value depending on the exit we came from.
  // NOTE(review): an exit whose conditional branch has both successors among
  // the exit targets ends up with two edges to the new block but a single
  // phi entry -- confirm this is accepted by the IR verifier.
  llvm::PHINode *node =
      Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size());
  for (const auto &[BB, IncomingValue] : ExitToVariable)
    node->addIncoming(IncomingValue, BB);

  // Creating the switch to jump to the correct exit target. The cases are
  // emitted by walking the sorted vector rather than the unordered map, whose
  // iteration order depends on pointer hashing and would make the default
  // destination and the case order vary from run to run.
  llvm::SwitchInst *Sw = Builder.CreateSwitch(node, SortedExitTargets[0],
                                              SortedExitTargets.size() - 1);
  for (size_t i = 1; i < SortedExitTargets.size(); i++) {
    BasicBlock *Target = SortedExitTargets[i];
    Sw->addCase(TargetToValue.at(Target), Target);
  }

  // Fix exit branches to redirect to the new exit.
  for (BasicBlock *Exit : CR->Exits)
    replaceBranchTargets(Exit, ExitTargets, NewExitTarget);

  return true;
}

/// Depth-first driver: processes the sub-regions of |CR| before |CR| itself.
/// Stops at the first region or sub-region that gets modified and returns
/// true in that case. Returns false when nothing was changed.
bool runOnConvergenceRegion(LoopInfo &LI,
                            const SPIRV::ConvergenceRegion *CR) {
  for (auto *SubRegion : CR->Children) {
    const bool SubRegionChanged = runOnConvergenceRegion(LI, SubRegion);
    if (SubRegionChanged)
      return true;
  }

  // No sub-region was touched: handle this region on its own.
  const bool RegionChanged = runOnConvergenceRegionNoRecurse(LI, CR);
  return RegionChanged;
}

#ifndef NDEBUG
/// Validates each edge exiting the region has the same destination basic
/// block. Sub-regions are checked first, recursively.
/// Only compiled when assertions are enabled. The guard is #ifndef so it
/// stays well-formed even when NDEBUG is defined without a value.
void validateRegionExits(const SPIRV::ConvergenceRegion *CR) {
  for (auto *Child : CR->Children)
    validateRegionExits(Child);

  // Collect every block outside the region that one of its exits jumps to.
  std::unordered_set<BasicBlock *> ExitTargets;
  for (auto *Exit : CR->Exits) {
    for (auto *BB : gatherSuccessors(Exit)) {
      if (CR->Blocks.count(BB) == 0)
        ExitTargets.insert(BB);
    }
  }

  assert(ExitTargets.size() <= 1 &&
         "Convergence region has more than one exit target.");
}
#endif

/// Pass entry point: merges the exit targets of the convergence regions,
/// starting from the top-level region, until no region is modified anymore.
/// Returns true if at least one region was modified.
bool runOnFunction(Function &F) override {
  LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  const auto *TopLevelRegion =
      getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
          .getRegionInfo()
          .getTopLevelRegion();

  // FIXME: very inefficient method: each time a region is modified, we bubble
  // back up, and recompute the whole convergence region tree. Once the
  // algorithm is completed and test coverage good enough, rewrite this pass
  // to be efficient instead of simple.
  // NOTE(review): the loop only re-queries getAnalysis<>(); confirm that this
  // really yields a recomputed region tree after the CFG has been modified.
  bool modified = false;
  while (runOnConvergenceRegion(LI, TopLevelRegion)) {
    TopLevelRegion = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
                         .getRegionInfo()
                         .getTopLevelRegion();
    modified = true;
  }

  // The unconditional F.dump() debugging leftover was removed: it printed the
  // whole function on every run of the pass.
#ifndef NDEBUG
  validateRegionExits(TopLevelRegion);
#endif
  return modified;
}

// Declares the analyses this pass requires before it runs: the dominator
// tree, the loop info and the SPIRV convergence region analysis. It then
// forwards to FunctionPass::getAnalysisUsage.
// No analysis is declared as preserved here.
// NOTE(review): the dominator tree is not queried directly by the code of
// this pass; presumably it is needed by the other required analyses --
// confirm.
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<LoopInfoWrapperPass>();
AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
FunctionPass::getAnalysisUsage(AU);
}
};
} // namespace llvm

// Definition of the static pass identifier handed to the FunctionPass
// constructor.
char SPIRVMergeRegionExitTargets::ID = 0;

// Pass registration, listing the passes it depends on.
// NOTE(review): the registered argument and description read "split region
// exit blocks" whereas the file header describes *merging* exit targets --
// confirm the naming is intended.
// NOTE(review): LoopSimplify is listed as a dependency here but is not added
// as required in getAnalysisUsage -- confirm this is intended.
INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
"SPIRV split region exit blocks", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass)

INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets, "split-region-exit-blocks",
"SPIRV split region exit blocks", false, false)

/// Factory for the region exit target merging pass: returns a newly allocated
/// pass instance as a FunctionPass pointer.
FunctionPass *llvm::createSPIRVMergeRegionExitTargetsPass() {
  FunctionPass *Pass = new SPIRVMergeRegionExitTargets();
  return Pass;
}
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ void SPIRVPassConfig::addIRPasses() {
// - all loop exits are dominated by the loop pre-header.
// - loops have a single back-edge.
addPass(createLoopSimplifyPass());
addPass(createSPIRVMergeRegionExitTargetsPass());
}

TargetPassConfig::addIRPasses();
Expand Down
Loading