diff --git a/bolt/include/bolt/Core/MCPlusBuilder.h b/bolt/include/bolt/Core/MCPlusBuilder.h index b233452985502..a8a3a58dba836 100644 --- a/bolt/include/bolt/Core/MCPlusBuilder.h +++ b/bolt/include/bolt/Core/MCPlusBuilder.h @@ -511,6 +511,11 @@ class MCPlusBuilder { llvm_unreachable("not implemented"); } + virtual void createDirectBranch(MCInst &Inst, const MCSymbol *Target, + MCContext *Ctx) { + llvm_unreachable("not implemented"); + } + virtual MCPhysReg getX86R11() const { llvm_unreachable("not implemented"); } virtual unsigned getShortBranchOpcode(unsigned Opcode) const { diff --git a/bolt/lib/Passes/Instrumentation.cpp b/bolt/lib/Passes/Instrumentation.cpp index fbf889279f1c0..e84acd00da369 100644 --- a/bolt/lib/Passes/Instrumentation.cpp +++ b/bolt/lib/Passes/Instrumentation.cpp @@ -293,9 +293,12 @@ void Instrumentation::instrumentIndirectTarget(BinaryBasicBlock &BB, BinaryBasicBlock::iterator &Iter, BinaryFunction &FromFunction, uint32_t From) { - auto L = FromFunction.getBinaryContext().scopeLock(); - const size_t IndCallSiteID = Summary->IndCallDescriptions.size(); - createIndCallDescription(FromFunction, From); + size_t IndCallSiteID; + { + auto L = FromFunction.getBinaryContext().scopeLock(); + IndCallSiteID = Summary->IndCallDescriptions.size(); + createIndCallDescription(FromFunction, From); + } BinaryContext &BC = FromFunction.getBinaryContext(); bool IsTailCall = BC.MIB->isTailCall(*Iter); @@ -305,9 +308,12 @@ void Instrumentation::instrumentIndirectTarget(BinaryBasicBlock &BB, : IndCallHandlerExitBBFunction->getSymbol(), IndCallSiteID, &*BC.Ctx); - Iter = BB.eraseInstruction(Iter); - Iter = insertInstructions(CounterInstrs, BB, Iter); - --Iter; + if (!BC.isAArch64()) { + Iter = BB.eraseInstruction(Iter); + Iter = insertInstructions(CounterInstrs, BB, Iter); + --Iter; + } else + Iter = insertInstructions(CounterInstrs, BB, Iter); } bool Instrumentation::instrumentOneTarget( diff --git a/bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp
b/bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp index 9d5a578cfbdff..4895df33bec81 100644 --- a/bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp +++ b/bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp @@ -1966,6 +1966,15 @@ class AArch64MCPlusBuilder : public MCPlusBuilder { convertJmpToTailCall(Inst); } + void createDirectBranch(MCInst &Inst, const MCSymbol *Target, + MCContext *Ctx) override { + Inst.setOpcode(AArch64::B); + Inst.clear(); + Inst.addOperand(MCOperand::createExpr(getTargetExprFor( + Inst, MCSymbolRefExpr::create(Target, MCSymbolRefExpr::VK_None, *Ctx), + *Ctx, 0))); + } + bool analyzeBranch(InstructionIterator Begin, InstructionIterator End, const MCSymbol *&TBB, const MCSymbol *&FBB, MCInst *&CondBranch, @@ -2328,21 +2337,26 @@ class AArch64MCPlusBuilder : public MCPlusBuilder { } InstructionListType createInstrumentedIndCallHandlerExitBB() const override { - InstructionListType Insts(5); // Code sequence for instrumented indirect call handler: + // ldr x1, [sp], #16 // msr nzcv, x1 // ldp x0, x1, [sp], #16 - // ldr x16, [sp], #16 - // ldp x0, x1, [sp], #16 - // br x16 - setSystemFlag(Insts[0], AArch64::X1); - createPopRegisters(Insts[1], AArch64::X0, AArch64::X1); - // Here we load address of the next function which should be called in the - // original binary to X16 register. Writing to X16 is permitted without - // needing to restore.
- loadReg(Insts[2], AArch64::X16, AArch64::SP); - createPopRegisters(Insts[3], AArch64::X0, AArch64::X1); - createIndirectBranch(Insts[4], AArch64::X16, 0); + // ret + + InstructionListType Insts; + + Insts.emplace_back(); + loadReg(Insts.back(), AArch64::X1, AArch64::SP); + + Insts.emplace_back(); + setSystemFlag(Insts.back(), AArch64::X1); + + Insts.emplace_back(); + createPopRegisters(Insts.back(), AArch64::X0, AArch64::X1); + + Insts.emplace_back(); + createReturn(Insts.back()); + return Insts; } @@ -2418,39 +2432,69 @@ class AArch64MCPlusBuilder : public MCPlusBuilder { MCSymbol *HandlerFuncAddr, int CallSiteID, MCContext *Ctx) override { - InstructionListType Insts; // Code sequence used to enter indirect call instrumentation helper: - // stp x0, x1, [sp, #-16]! createPushRegisters + // stp x0, x1, [sp, #-16]! createPushRegisters (1) // mov target x0 convertIndirectCallToLoad -> orr x0 target xzr // mov x1 CallSiteID createLoadImmediate -> // movk x1, #0x0, lsl #48 // movk x1, #0x0, lsl #32 // movk x1, #0x0, lsl #16 // movk x1, #0x0 - // stp x0, x1, [sp, #-16]! - // bl *HandlerFuncAddr createIndirectCall -> + // stp x0, x1, [sp, #-16]! (2) // adr x0 *HandlerFuncAddr -> adrp + add - // blr x0 + // str x30, [sp, #-16]!
(3) + // blr target (__bolt_instr_ind_call_handler_func) + // ldr x30, [sp], #16 (3) + // ldp x0, x1, [sp], #16 (2) + // mov target, x0 ; move target address to used register + // ldp x0, x1, [sp], #16 (1) + + InstructionListType Insts; Insts.emplace_back(); - createPushRegisters(Insts.back(), AArch64::X0, AArch64::X1); + createPushRegisters(Insts.back(), getIntArgRegister(0), + getIntArgRegister(1)); Insts.emplace_back(CallInst); - convertIndirectCallToLoad(Insts.back(), AArch64::X0); + convertIndirectCallToLoad(Insts.back(), getIntArgRegister(0)); InstructionListType LoadImm = createLoadImmediate(getIntArgRegister(1), CallSiteID); Insts.insert(Insts.end(), LoadImm.begin(), LoadImm.end()); Insts.emplace_back(); - createPushRegisters(Insts.back(), AArch64::X0, AArch64::X1); + createPushRegisters(Insts.back(), getIntArgRegister(0), + getIntArgRegister(1)); Insts.resize(Insts.size() + 2); - InstructionListType Addr = - materializeAddress(HandlerFuncAddr, Ctx, AArch64::X0); + InstructionListType Addr = materializeAddress( + HandlerFuncAddr, Ctx, CallInst.getOperand(0).getReg()); assert(Addr.size() == 2 && "Invalid Addr size"); std::copy(Addr.begin(), Addr.end(), Insts.end() - Addr.size()); + + Insts.emplace_back(); + storeReg(Insts.back(), AArch64::LR, getSpRegister(/*Size*/ 8)); + + Insts.emplace_back(); + createIndirectCallInst(Insts.back(), false, + CallInst.getOperand(0).getReg()); + Insts.emplace_back(); - createIndirectCallInst(Insts.back(), isTailCall(CallInst), AArch64::X0); + loadReg(Insts.back(), AArch64::LR, getSpRegister(/*Size*/ 8)); - // Carry over metadata including tail call marker if present.
- stripAnnotations(Insts.back()); - moveAnnotations(std::move(CallInst), Insts.back()); + Insts.emplace_back(); + createPopRegisters(Insts.back(), getIntArgRegister(0), + getIntArgRegister(1)); + + // move x0 to indirect call register + Insts.emplace_back(); + Insts.back().setOpcode(AArch64::ORRXrs); + Insts.back().insert(Insts.back().begin(), + MCOperand::createReg(CallInst.getOperand(0).getReg())); + Insts.back().insert(Insts.back().begin() + 1, + MCOperand::createReg(AArch64::XZR)); + Insts.back().insert(Insts.back().begin() + 2, + MCOperand::createReg(getIntArgRegister(0))); + Insts.back().insert(Insts.back().begin() + 3, MCOperand::createImm(0)); + + Insts.emplace_back(); + createPopRegisters(Insts.back(), getIntArgRegister(0), + getIntArgRegister(1)); return Insts; } @@ -2472,30 +2516,44 @@ class AArch64MCPlusBuilder : public MCPlusBuilder { // ldr x30, [sp], #16 // b IndCallHandler InstructionListType Insts; + Insts.emplace_back(); - createPushRegisters(Insts.back(), AArch64::X0, AArch64::X1); + createPushRegisters(Insts.back(), getIntArgRegister(0), + getIntArgRegister(1)); + Insts.emplace_back(); getSystemFlag(Insts.back(), getIntArgRegister(1)); + + Insts.emplace_back(); + storeReg(Insts.back(), getIntArgRegister(1), getSpRegister(/*Size*/ 8)); + Insts.emplace_back(); Insts.emplace_back(); InstructionListType Addr = - materializeAddress(InstrTrampoline, Ctx, AArch64::X0); + materializeAddress(InstrTrampoline, Ctx, getIntArgRegister(0)); std::copy(Addr.begin(), Addr.end(), Insts.end() - Addr.size()); assert(Addr.size() == 2 && "Invalid Addr size"); + Insts.emplace_back(); - loadReg(Insts.back(), AArch64::X0, AArch64::X0); + loadReg(Insts.back(), getIntArgRegister(0), getIntArgRegister(0)); + InstructionListType cmpJmp = - createCmpJE(AArch64::X0, 0, IndCallHandler, Ctx); + createCmpJE(getIntArgRegister(0), 0, IndCallHandler, Ctx); Insts.insert(Insts.end(), cmpJmp.begin(), cmpJmp.end()); + Insts.emplace_back(); - storeReg(Insts.back(), AArch64::LR,
AArch64::SP); + storeReg(Insts.back(), AArch64::LR, getSpRegister(/*Size*/ 8)); + Insts.emplace_back(); Insts.back().setOpcode(AArch64::BLR); - Insts.back().addOperand(MCOperand::createReg(AArch64::X0)); + Insts.back().addOperand(MCOperand::createReg(getIntArgRegister(0))); + Insts.emplace_back(); - loadReg(Insts.back(), AArch64::LR, AArch64::SP); + loadReg(Insts.back(), AArch64::LR, getSpRegister(/*Size*/ 8)); + Insts.emplace_back(); - createDirectCall(Insts.back(), IndCallHandler, Ctx, /*IsTailCall*/ true); + createDirectBranch(Insts.back(), IndCallHandler, Ctx); + return Insts; } diff --git a/bolt/runtime/instr.cpp b/bolt/runtime/instr.cpp index ae356e71cbe41..a174b982cbb84 100644 --- a/bolt/runtime/instr.cpp +++ b/bolt/runtime/instr.cpp @@ -1668,7 +1668,7 @@ extern "C" __attribute((naked)) void __bolt_instr_indirect_call() #if defined(__aarch64__) // clang-format off __asm__ __volatile__(SAVE_ALL - "ldp x0, x1, [sp, #288]\n" + "ldp x0, x1, [sp, #320]\n" "bl instrumentIndirectCall\n" RESTORE_ALL "ret\n" @@ -1705,7 +1705,7 @@ extern "C" __attribute((naked)) void __bolt_instr_indirect_tailcall() #if defined(__aarch64__) // clang-format off __asm__ __volatile__(SAVE_ALL - "ldp x0, x1, [sp, #288]\n" + "ldp x0, x1, [sp, #320]\n" "bl instrumentIndirectCall\n" RESTORE_ALL "ret\n"