Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ struct BPFOperand : public MCParsedAsmOperand {
.Case("callx", true)
.Case("goto", true)
.Case("gotol", true)
.Case("gotox", true)
.Case("may_goto", true)
.Case("*", true)
.Case("exit", true)
Expand Down
114 changes: 90 additions & 24 deletions llvm/lib/Target/BPF/BPFAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,52 +11,35 @@
//
//===----------------------------------------------------------------------===//

#include "BPFAsmPrinter.h"
#include "BPF.h"
#include "BPFInstrInfo.h"
#include "BPFMCInstLower.h"
#include "BTFDebug.h"
#include "MCTargetDesc/BPFInstPrinter.h"
#include "TargetInfo/BPFTargetInfo.h"
#include "llvm/BinaryFormat/ELF.h"
#include "llvm/CodeGen/AsmPrinter.h"
#include "llvm/CodeGen/MachineConstantPool.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/IR/Module.h"
#include "llvm/MC/MCAsmInfo.h"
#include "llvm/MC/MCExpr.h"
#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCStreamer.h"
#include "llvm/MC/MCSymbol.h"
#include "llvm/MC/MCSymbolELF.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetLoweringObjectFile.h"
using namespace llvm;

#define DEBUG_TYPE "asm-printer"

namespace {
class BPFAsmPrinter : public AsmPrinter {
public:
explicit BPFAsmPrinter(TargetMachine &TM,
std::unique_ptr<MCStreamer> Streamer)
: AsmPrinter(TM, std::move(Streamer), ID), BTF(nullptr) {}

StringRef getPassName() const override { return "BPF Assembly Printer"; }
bool doInitialization(Module &M) override;
void printOperand(const MachineInstr *MI, int OpNum, raw_ostream &O);
bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
const char *ExtraCode, raw_ostream &O) override;
bool PrintAsmMemoryOperand(const MachineInstr *MI, unsigned OpNum,
const char *ExtraCode, raw_ostream &O) override;

void emitInstruction(const MachineInstr *MI) override;

static char ID;

private:
BTFDebug *BTF;
};
} // namespace

bool BPFAsmPrinter::doInitialization(Module &M) {
AsmPrinter::doInitialization(M);

Expand All @@ -69,6 +52,45 @@ bool BPFAsmPrinter::doInitialization(Module &M) {
return false;
}

const BPFTargetMachine &BPFAsmPrinter::getBTM() const {
return static_cast<const BPFTargetMachine &>(TM);
}

bool BPFAsmPrinter::doFinalization(Module &M) {
// Remove unused globals which are previously used for jump table.
const BPFSubtarget *Subtarget = getBTM().getSubtargetImpl();
if (Subtarget->hasGotox()) {
std::vector<GlobalVariable *> Targets;
for (GlobalVariable &Global : M.globals()) {
if (Global.getLinkage() != GlobalValue::PrivateLinkage)
continue;
if (!Global.isConstant() || !Global.hasInitializer())
continue;

Constant *CV = dyn_cast<Constant>(Global.getInitializer());
if (!CV)
continue;
ConstantArray *CA = dyn_cast<ConstantArray>(CV);
if (!CA)
continue;

for (unsigned i = 1, e = CA->getNumOperands(); i != e; ++i) {
if (!dyn_cast<BlockAddress>(CA->getOperand(i)))
continue;
}
Targets.push_back(&Global);
}

for (GlobalVariable *GV : Targets) {
GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
GV->dropAllReferences();
GV->eraseFromParent();
}
}

return AsmPrinter::doFinalization(M);
}

void BPFAsmPrinter::printOperand(const MachineInstr *MI, int OpNum,
raw_ostream &O) {
const MachineOperand &MO = MI->getOperand(OpNum);
Expand Down Expand Up @@ -150,6 +172,50 @@ void BPFAsmPrinter::emitInstruction(const MachineInstr *MI) {
EmitToStreamer(*OutStreamer, TmpInst);
}

MCSymbol *BPFAsmPrinter::getJTPublicSymbol(unsigned JTI) {
SmallString<60> Name;
raw_svector_ostream(Name)
<< "BPF.JT." << MF->getFunctionNumber() << '.' << JTI;
MCSymbol *S = OutContext.getOrCreateSymbol(Name);
if (auto *ES = static_cast<MCSymbolELF *>(S)) {
ES->setBinding(ELF::STB_GLOBAL);
ES->setType(ELF::STT_OBJECT);
}
return S;
}

void BPFAsmPrinter::emitJumpTableInfo() {
const MachineJumpTableInfo *MJTI = MF->getJumpTableInfo();
if (!MJTI)
return;

const std::vector<MachineJumpTableEntry> &JT = MJTI->getJumpTables();
if (JT.empty())
return;

const TargetLoweringObjectFile &TLOF = getObjFileLowering();
const Function &F = MF->getFunction();
MCSection *JTS = TLOF.getSectionForJumpTable(F, TM);
assert(MJTI->getEntryKind() == MachineJumpTableInfo::EK_BlockAddress);
unsigned EntrySize = MJTI->getEntrySize(getDataLayout());
OutStreamer->switchSection(JTS);
for (unsigned JTI = 0; JTI < JT.size(); JTI++) {
ArrayRef<MachineBasicBlock *> JTBBs = JT[JTI].MBBs;
if (JTBBs.empty())
continue;

MCSymbol *JTStart = getJTPublicSymbol(JTI);
OutStreamer->emitLabel(JTStart);
for (const MachineBasicBlock *MBB : JTBBs) {
const MCExpr *LHS = MCSymbolRefExpr::create(MBB->getSymbol(), OutContext);
OutStreamer->emitValue(LHS, EntrySize);
}
const MCExpr *JTSize =
MCConstantExpr::create(JTBBs.size() * EntrySize, OutContext);
OutStreamer->emitELFSize(JTStart, JTSize);
}
}

char BPFAsmPrinter::ID = 0;

INITIALIZE_PASS(BPFAsmPrinter, "bpf-asm-printer", "BPF Assembly Printer", false,
Expand Down
48 changes: 48 additions & 0 deletions llvm/lib/Target/BPF/BPFAsmPrinter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
//===-- BPFFrameLowering.h - Define frame lowering for BPF -----*- C++ -*--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_BPF_BPFASMPRINTER_H
#define LLVM_LIB_TARGET_BPF_BPFASMPRINTER_H

#include "BPFTargetMachine.h"
#include "BTFDebug.h"
#include "llvm/CodeGen/AsmPrinter.h"

namespace llvm {

class BPFAsmPrinter : public AsmPrinter {
public:
explicit BPFAsmPrinter(TargetMachine &TM,
std::unique_ptr<MCStreamer> Streamer)
: AsmPrinter(TM, std::move(Streamer), ID), BTF(nullptr), TM(TM) {}

StringRef getPassName() const override { return "BPF Assembly Printer"; }
bool doInitialization(Module &M) override;
bool doFinalization(Module &M) override;
void printOperand(const MachineInstr *MI, int OpNum, raw_ostream &O);
bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
const char *ExtraCode, raw_ostream &O) override;
bool PrintAsmMemoryOperand(const MachineInstr *MI, unsigned OpNum,
const char *ExtraCode, raw_ostream &O) override;

void emitInstruction(const MachineInstr *MI) override;
MCSymbol *getJTPublicSymbol(unsigned JTI);
virtual void emitJumpTableInfo() override;

static char ID;

private:
BTFDebug *BTF;
TargetMachine &TM;

const BPFTargetMachine &getBTM() const;
};

} // namespace llvm

#endif /* LLVM_LIB_TARGET_BPF_BPFASMPRINTER_H */
Loading