Skip to content

[GlobalISel] Reduce KnownBits usage in matcher combines #92381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 67 additions & 35 deletions llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3220,8 +3220,15 @@ bool CombinerHelper::matchRedundantAnd(MachineInstr &MI,
Register AndDst = MI.getOperand(0).getReg();
Register LHS = MI.getOperand(1).getReg();
Register RHS = MI.getOperand(2).getReg();
KnownBits LHSBits = KB->getKnownBits(LHS);

// Check the RHS (maybe a constant) first, and if we have no KnownBits there,
// we can't do anything. If we do, then it depends on whether we have
// KnownBits on the LHS.
KnownBits RHSBits = KB->getKnownBits(RHS);
if (RHSBits.isUnknown())
return false;

KnownBits LHSBits = KB->getKnownBits(LHS);

// Check that x & Mask == x.
// x & 1 == x, always
Expand Down Expand Up @@ -3260,6 +3267,7 @@ bool CombinerHelper::matchRedundantOr(MachineInstr &MI, Register &Replacement) {
Register OrDst = MI.getOperand(0).getReg();
Register LHS = MI.getOperand(1).getReg();
Register RHS = MI.getOperand(2).getReg();

KnownBits LHSBits = KB->getKnownBits(LHS);
KnownBits RHSBits = KB->getKnownBits(RHS);

Expand Down Expand Up @@ -4253,43 +4261,67 @@ bool CombinerHelper::matchICmpToTrueFalseKnownBits(MachineInstr &MI,
int64_t &MatchInfo) {
assert(MI.getOpcode() == TargetOpcode::G_ICMP);
auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
auto KnownLHS = KB->getKnownBits(MI.getOperand(2).getReg());

// We want to avoid calling KnownBits on the LHS if possible, as this combine
// has no filter and runs on every G_ICMP instruction. We can avoid calling
// KnownBits on the LHS in two cases:
//
// - The RHS is unknown: Constants are always on RHS. If the RHS is unknown
// we cannot do any transforms so we can safely bail out early.
// - The RHS is zero: we don't need to know the LHS to do unsigned <0 and
// >=0.
auto KnownRHS = KB->getKnownBits(MI.getOperand(3).getReg());
if (KnownRHS.isUnknown())
return false;

std::optional<bool> KnownVal;
switch (Pred) {
default:
llvm_unreachable("Unexpected G_ICMP predicate?");
case CmpInst::ICMP_EQ:
KnownVal = KnownBits::eq(KnownLHS, KnownRHS);
break;
case CmpInst::ICMP_NE:
KnownVal = KnownBits::ne(KnownLHS, KnownRHS);
break;
case CmpInst::ICMP_SGE:
KnownVal = KnownBits::sge(KnownLHS, KnownRHS);
break;
case CmpInst::ICMP_SGT:
KnownVal = KnownBits::sgt(KnownLHS, KnownRHS);
break;
case CmpInst::ICMP_SLE:
KnownVal = KnownBits::sle(KnownLHS, KnownRHS);
break;
case CmpInst::ICMP_SLT:
KnownVal = KnownBits::slt(KnownLHS, KnownRHS);
break;
case CmpInst::ICMP_UGE:
KnownVal = KnownBits::uge(KnownLHS, KnownRHS);
break;
case CmpInst::ICMP_UGT:
KnownVal = KnownBits::ugt(KnownLHS, KnownRHS);
break;
case CmpInst::ICMP_ULE:
KnownVal = KnownBits::ule(KnownLHS, KnownRHS);
break;
case CmpInst::ICMP_ULT:
KnownVal = KnownBits::ult(KnownLHS, KnownRHS);
break;
if (KnownRHS.isZero()) {
// ? uge 0 -> always true
// ? ult 0 -> always false
if (Pred == CmpInst::ICMP_UGE)
KnownVal = true;
else if (Pred == CmpInst::ICMP_ULT)
KnownVal = false;
}

if (!KnownVal) {
auto KnownLHS = KB->getKnownBits(MI.getOperand(2).getReg());
switch (Pred) {
default:
llvm_unreachable("Unexpected G_ICMP predicate?");
case CmpInst::ICMP_EQ:
KnownVal = KnownBits::eq(KnownLHS, KnownRHS);
break;
case CmpInst::ICMP_NE:
KnownVal = KnownBits::ne(KnownLHS, KnownRHS);
break;
case CmpInst::ICMP_SGE:
KnownVal = KnownBits::sge(KnownLHS, KnownRHS);
break;
case CmpInst::ICMP_SGT:
KnownVal = KnownBits::sgt(KnownLHS, KnownRHS);
break;
case CmpInst::ICMP_SLE:
KnownVal = KnownBits::sle(KnownLHS, KnownRHS);
break;
case CmpInst::ICMP_SLT:
KnownVal = KnownBits::slt(KnownLHS, KnownRHS);
break;
case CmpInst::ICMP_UGE:
KnownVal = KnownBits::uge(KnownLHS, KnownRHS);
break;
case CmpInst::ICMP_UGT:
KnownVal = KnownBits::ugt(KnownLHS, KnownRHS);
break;
case CmpInst::ICMP_ULE:
KnownVal = KnownBits::ule(KnownLHS, KnownRHS);
break;
case CmpInst::ICMP_ULT:
KnownVal = KnownBits::ult(KnownLHS, KnownRHS);
break;
}
}

if (!KnownVal)
return false;
MatchInfo =
Expand Down
Loading