Skip to content

[CodeGen] change prototype of regalloc filter function #93525

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions llvm/include/llvm/CodeGen/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,20 +205,20 @@ namespace llvm {
/// possible. It is best suited for debug code where live ranges are short.
///
FunctionPass *createFastRegisterAllocator();
FunctionPass *createFastRegisterAllocator(RegClassFilterFunc F,
FunctionPass *createFastRegisterAllocator(RegAllocFilterFunc F,
bool ClearVirtRegs);

/// BasicRegisterAllocation Pass - This pass implements a degenerate global
/// register allocator using the basic regalloc framework.
///
FunctionPass *createBasicRegisterAllocator();
FunctionPass *createBasicRegisterAllocator(RegClassFilterFunc F);
FunctionPass *createBasicRegisterAllocator(RegAllocFilterFunc F);

/// Greedy register allocation pass - This pass implements a global register
/// allocator for optimized builds.
///
FunctionPass *createGreedyRegisterAllocator();
FunctionPass *createGreedyRegisterAllocator(RegClassFilterFunc F);
FunctionPass *createGreedyRegisterAllocator(RegAllocFilterFunc F);

/// PBQPRegisterAllocation Pass - This pass implements the Partitioned Boolean
/// Quadratic Prograaming (PBQP) based register allocator.
Expand Down
6 changes: 4 additions & 2 deletions llvm/include/llvm/CodeGen/RegAllocCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,20 @@
#ifndef LLVM_CODEGEN_REGALLOCCOMMON_H
#define LLVM_CODEGEN_REGALLOCCOMMON_H

#include "llvm/CodeGen/Register.h"
#include <functional>

namespace llvm {

class TargetRegisterClass;
class TargetRegisterInfo;
class MachineRegisterInfo;

/// Filter function for register classes during regalloc. Default register class
/// filter is nullptr, where all registers should be allocated.
typedef std::function<bool(const TargetRegisterInfo &TRI,
const TargetRegisterClass &RC)>
RegClassFilterFunc;
const MachineRegisterInfo &MRI, const Register Reg)>
RegAllocFilterFunc;
}

#endif // LLVM_CODEGEN_REGALLOCCOMMON_H
2 changes: 1 addition & 1 deletion llvm/include/llvm/CodeGen/RegAllocFast.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
namespace llvm {

struct RegAllocFastPassOptions {
RegClassFilterFunc Filter = nullptr;
RegAllocFilterFunc Filter = nullptr;
StringRef FilterName = "all";
bool ClearVRegs = true;
};
Expand Down
10 changes: 5 additions & 5 deletions llvm/include/llvm/Passes/PassBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,9 @@ class PassBuilder {
/// returns false.
Error parseAAPipeline(AAManager &AA, StringRef PipelineText);

/// Parse RegClassFilterName to get RegClassFilterFunc.
std::optional<RegClassFilterFunc>
parseRegAllocFilter(StringRef RegClassFilterName);
/// Parse RegAllocFilterName to get RegAllocFilterFunc.
std::optional<RegAllocFilterFunc>
parseRegAllocFilter(StringRef RegAllocFilterName);

/// Print pass names.
void printPassNames(raw_ostream &OS);
Expand Down Expand Up @@ -586,7 +586,7 @@ class PassBuilder {
/// needs it. E.g. AMDGPU requires regalloc passes can handle sgpr and vgpr
/// separately.
void registerRegClassFilterParsingCallback(
const std::function<RegClassFilterFunc(StringRef)> &C) {
const std::function<RegAllocFilterFunc(StringRef)> &C) {
RegClassFilterParsingCallbacks.push_back(C);
}

Expand Down Expand Up @@ -807,7 +807,7 @@ class PassBuilder {
2>
MachineFunctionPipelineParsingCallbacks;
// Callbacks to parse `filter` parameter in register allocation passes
SmallVector<std::function<RegClassFilterFunc(StringRef)>, 2>
SmallVector<std::function<RegAllocFilterFunc(StringRef)>, 2>
RegClassFilterParsingCallbacks;
};

Expand Down
9 changes: 5 additions & 4 deletions llvm/lib/CodeGen/RegAllocBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class RegAllocBase {

private:
/// Private, callees should go through shouldAllocateRegister
const RegClassFilterFunc ShouldAllocateClass;
const RegAllocFilterFunc shouldAllocateRegisterImpl;

protected:
/// Inst which is a def of an original reg and whose defs are already all
Expand All @@ -81,7 +81,8 @@ class RegAllocBase {
/// always available for the remat of all the siblings of the original reg.
SmallPtrSet<MachineInstr *, 32> DeadRemats;

RegAllocBase(const RegClassFilterFunc F = nullptr) : ShouldAllocateClass(F) {}
RegAllocBase(const RegAllocFilterFunc F = nullptr)
: shouldAllocateRegisterImpl(F) {}

virtual ~RegAllocBase() = default;

Expand All @@ -90,9 +91,9 @@ class RegAllocBase {

/// Get whether a given register should be allocated
bool shouldAllocateRegister(Register Reg) {
if (!ShouldAllocateClass)
if (!shouldAllocateRegisterImpl)
return true;
return ShouldAllocateClass(*TRI, *MRI->getRegClass(Reg));
return shouldAllocateRegisterImpl(*TRI, *MRI, Reg);
}

// The top-level driver. The output is a VirtRegMap that us updated with
Expand Down
10 changes: 4 additions & 6 deletions llvm/lib/CodeGen/RegAllocBasic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class RABasic : public MachineFunctionPass,
void LRE_WillShrinkVirtReg(Register) override;

public:
RABasic(const RegClassFilterFunc F = nullptr);
RABasic(const RegAllocFilterFunc F = nullptr);

/// Return the pass name.
StringRef getPassName() const override { return "Basic Register Allocator"; }
Expand Down Expand Up @@ -168,10 +168,8 @@ void RABasic::LRE_WillShrinkVirtReg(Register VirtReg) {
enqueue(&LI);
}

RABasic::RABasic(RegClassFilterFunc F):
MachineFunctionPass(ID),
RegAllocBase(F) {
}
RABasic::RABasic(RegAllocFilterFunc F)
: MachineFunctionPass(ID), RegAllocBase(F) {}

void RABasic::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesCFG();
Expand Down Expand Up @@ -333,6 +331,6 @@ FunctionPass* llvm::createBasicRegisterAllocator() {
return new RABasic();
}

FunctionPass* llvm::createBasicRegisterAllocator(RegClassFilterFunc F) {
FunctionPass *llvm::createBasicRegisterAllocator(RegAllocFilterFunc F) {
return new RABasic(F);
}
16 changes: 8 additions & 8 deletions llvm/lib/CodeGen/RegAllocFast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,9 @@ class InstrPosIndexes {

class RegAllocFastImpl {
public:
RegAllocFastImpl(const RegClassFilterFunc F = nullptr,
RegAllocFastImpl(const RegAllocFilterFunc F = nullptr,
bool ClearVirtRegs_ = true)
: ShouldAllocateClass(F), StackSlotForVirtReg(-1),
: ShouldAllocateRegisterImpl(F), StackSlotForVirtReg(-1),
ClearVirtRegs(ClearVirtRegs_) {}

private:
Expand All @@ -188,7 +188,7 @@ class RegAllocFastImpl {
const TargetRegisterInfo *TRI = nullptr;
const TargetInstrInfo *TII = nullptr;
RegisterClassInfo RegClassInfo;
const RegClassFilterFunc ShouldAllocateClass;
const RegAllocFilterFunc ShouldAllocateRegisterImpl;

/// Basic block currently being allocated.
MachineBasicBlock *MBB = nullptr;
Expand Down Expand Up @@ -397,7 +397,7 @@ class RegAllocFast : public MachineFunctionPass {
public:
static char ID;

RegAllocFast(const RegClassFilterFunc F = nullptr, bool ClearVirtRegs_ = true)
RegAllocFast(const RegAllocFilterFunc F = nullptr, bool ClearVirtRegs_ = true)
: MachineFunctionPass(ID), Impl(F, ClearVirtRegs_) {}

bool runOnMachineFunction(MachineFunction &MF) override {
Expand Down Expand Up @@ -440,10 +440,10 @@ INITIALIZE_PASS(RegAllocFast, "regallocfast", "Fast Register Allocator", false,

bool RegAllocFastImpl::shouldAllocateRegister(const Register Reg) const {
assert(Reg.isVirtual());
if (!ShouldAllocateClass)
if (!ShouldAllocateRegisterImpl)
return true;
const TargetRegisterClass &RC = *MRI->getRegClass(Reg);
return ShouldAllocateClass(*TRI, RC);

return ShouldAllocateRegisterImpl(*TRI, *MRI, Reg);
}

void RegAllocFastImpl::setPhysRegState(MCPhysReg PhysReg, unsigned NewState) {
Expand Down Expand Up @@ -1841,7 +1841,7 @@ void RegAllocFastPass::printPipeline(

FunctionPass *llvm::createFastRegisterAllocator() { return new RegAllocFast(); }

FunctionPass *llvm::createFastRegisterAllocator(RegClassFilterFunc Ftor,
FunctionPass *llvm::createFastRegisterAllocator(RegAllocFilterFunc Ftor,
bool ClearVirtRegs) {
return new RegAllocFast(Ftor, ClearVirtRegs);
}
10 changes: 4 additions & 6 deletions llvm/lib/CodeGen/RegAllocGreedy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,14 +192,12 @@ FunctionPass* llvm::createGreedyRegisterAllocator() {
return new RAGreedy();
}

FunctionPass *llvm::createGreedyRegisterAllocator(RegClassFilterFunc Ftor) {
FunctionPass *llvm::createGreedyRegisterAllocator(RegAllocFilterFunc Ftor) {
return new RAGreedy(Ftor);
}

RAGreedy::RAGreedy(RegClassFilterFunc F):
MachineFunctionPass(ID),
RegAllocBase(F) {
}
RAGreedy::RAGreedy(RegAllocFilterFunc F)
: MachineFunctionPass(ID), RegAllocBase(F) {}

void RAGreedy::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesCFG();
Expand Down Expand Up @@ -2306,7 +2304,7 @@ void RAGreedy::tryHintRecoloring(const LiveInterval &VirtReg) {
if (Reg.isPhysical())
continue;

// This may be a skipped class
// This may be a skipped register.
if (!VRM->hasPhys(Reg)) {
assert(!shouldAllocateRegister(Reg) &&
"We have an unallocated variable which should have been handled");
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/RegAllocGreedy.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ class LLVM_LIBRARY_VISIBILITY RAGreedy : public MachineFunctionPass,
bool ReverseLocalAssignment = false;

public:
RAGreedy(const RegClassFilterFunc F = nullptr);
RAGreedy(const RegAllocFilterFunc F = nullptr);

/// Return the pass name.
StringRef getPassName() const override { return "Greedy Register Allocator"; }
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Passes/PassBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1180,7 +1180,7 @@ parseRegAllocFastPassOptions(PassBuilder &PB, StringRef Params) {
std::tie(ParamName, Params) = Params.split(';');

if (ParamName.consume_front("filter=")) {
std::optional<RegClassFilterFunc> Filter =
std::optional<RegAllocFilterFunc> Filter =
PB.parseRegAllocFilter(ParamName);
if (!Filter) {
return make_error<StringError>(
Expand Down Expand Up @@ -2190,7 +2190,7 @@ Error PassBuilder::parseAAPipeline(AAManager &AA, StringRef PipelineText) {
return Error::success();
}

std::optional<RegClassFilterFunc>
std::optional<RegAllocFilterFunc>
PassBuilder::parseRegAllocFilter(StringRef FilterName) {
if (FilterName == "all")
return nullptr;
Expand Down
15 changes: 9 additions & 6 deletions llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,19 @@ class VGPRRegisterRegAlloc : public RegisterRegAllocBase<VGPRRegisterRegAlloc> {
};

static bool onlyAllocateSGPRs(const TargetRegisterInfo &TRI,
const TargetRegisterClass &RC) {
return static_cast<const SIRegisterInfo &>(TRI).isSGPRClass(&RC);
const MachineRegisterInfo &MRI,
const Register Reg) {
const TargetRegisterClass *RC = MRI.getRegClass(Reg);
return static_cast<const SIRegisterInfo &>(TRI).isSGPRClass(RC);
}

static bool onlyAllocateVGPRs(const TargetRegisterInfo &TRI,
const TargetRegisterClass &RC) {
return !static_cast<const SIRegisterInfo &>(TRI).isSGPRClass(&RC);
const MachineRegisterInfo &MRI,
const Register Reg) {
const TargetRegisterClass *RC = MRI.getRegClass(Reg);
return !static_cast<const SIRegisterInfo &>(TRI).isSGPRClass(RC);
}


/// -{sgpr|vgpr}-regalloc=... command line option.
static FunctionPass *useDefaultRegisterAllocator() { return nullptr; }

Expand Down Expand Up @@ -749,7 +752,7 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
});

PB.registerRegClassFilterParsingCallback(
[](StringRef FilterName) -> RegClassFilterFunc {
[](StringRef FilterName) -> RegAllocFilterFunc {
if (FilterName == "sgpr")
return onlyAllocateSGPRs;
if (FilterName == "vgpr")
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,10 @@ class RVVRegisterRegAlloc : public RegisterRegAllocBase<RVVRegisterRegAlloc> {
};

static bool onlyAllocateRVVReg(const TargetRegisterInfo &TRI,
const TargetRegisterClass &RC) {
return RISCVRegisterInfo::isRVVRegClass(&RC);
const MachineRegisterInfo &MRI,
const Register Reg) {
const TargetRegisterClass *RC = MRI.getRegClass(Reg);
return RISCVRegisterInfo::isRVVRegClass(RC);
}

static FunctionPass *useDefaultRegisterAllocator() { return nullptr; }
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/Target/X86/X86TargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -676,8 +676,10 @@ std::unique_ptr<CSEConfigBase> X86PassConfig::getCSEConfig() const {
}

static bool onlyAllocateTileRegisters(const TargetRegisterInfo &TRI,
const TargetRegisterClass &RC) {
return static_cast<const X86RegisterInfo &>(TRI).isTileRegisterClass(&RC);
const MachineRegisterInfo &MRI,
const Register Reg) {
const TargetRegisterClass *RC = MRI.getRegClass(Reg);
return static_cast<const X86RegisterInfo &>(TRI).isTileRegisterClass(RC);
}

bool X86PassConfig::addRegAssignAndRewriteOptimized() {
Expand Down
Loading