Skip to content

Commit 7b91475

Browse files
[AArch64] Optimise test of the LSB of a paired whileCC insntruction
1 parent 17fe202 commit 7b91475

File tree

7 files changed

+123
-82
lines changed

7 files changed

+123
-82
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2663,6 +2663,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
26632663
MAKE_CASE(AArch64ISD::INSR)
26642664
MAKE_CASE(AArch64ISD::PTEST)
26652665
MAKE_CASE(AArch64ISD::PTEST_ANY)
2666+
MAKE_CASE(AArch64ISD::PTEST_FIRST)
26662667
MAKE_CASE(AArch64ISD::PTRUE)
26672668
MAKE_CASE(AArch64ISD::LD1_MERGE_ZERO)
26682669
MAKE_CASE(AArch64ISD::LD1S_MERGE_ZERO)
@@ -18445,21 +18446,41 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
1844518446
AArch64CC::CondCode Cond);
1844618447

1844718448
static bool isPredicateCCSettingOp(SDValue N) {
18448-
if ((N.getOpcode() == ISD::SETCC) ||
18449-
(N.getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
18450-
(N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilege ||
18451-
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilegt ||
18452-
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehi ||
18453-
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehs ||
18454-
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilele ||
18455-
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelo ||
18456-
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels ||
18457-
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt ||
18458-
// get_active_lane_mask is lowered to a whilelo instruction.
18459-
N.getConstantOperandVal(0) == Intrinsic::get_active_lane_mask)))
18449+
if (N.getOpcode() == ISD::SETCC)
1846018450
return true;
1846118451

18462-
return false;
18452+
if (N.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
18453+
isNullConstant(N.getOperand(1)))
18454+
N = N.getOperand(0);
18455+
18456+
if (N.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
18457+
return false;
18458+
18459+
switch (N.getConstantOperandVal(0)) {
18460+
default:
18461+
return false;
18462+
case Intrinsic::aarch64_sve_whilege_x2:
18463+
case Intrinsic::aarch64_sve_whilegt_x2:
18464+
case Intrinsic::aarch64_sve_whilehi_x2:
18465+
case Intrinsic::aarch64_sve_whilehs_x2:
18466+
case Intrinsic::aarch64_sve_whilele_x2:
18467+
case Intrinsic::aarch64_sve_whilelo_x2:
18468+
case Intrinsic::aarch64_sve_whilels_x2:
18469+
case Intrinsic::aarch64_sve_whilelt_x2:
18470+
if (N.getResNo() != 0)
18471+
return false;
18472+
[[fallthrough]];
18473+
case Intrinsic::aarch64_sve_whilege:
18474+
case Intrinsic::aarch64_sve_whilegt:
18475+
case Intrinsic::aarch64_sve_whilehi:
18476+
case Intrinsic::aarch64_sve_whilehs:
18477+
case Intrinsic::aarch64_sve_whilele:
18478+
case Intrinsic::aarch64_sve_whilelo:
18479+
case Intrinsic::aarch64_sve_whilels:
18480+
case Intrinsic::aarch64_sve_whilelt:
18481+
case Intrinsic::get_active_lane_mask:
18482+
return true;
18483+
}
1846318484
}
1846418485

1846518486
// Materialize : i1 = extract_vector_elt t37, Constant:i64<0>
@@ -20413,9 +20434,19 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
2041320434
}
2041420435

2041520436
// Set condition code (CC) flags.
20416-
SDValue Test = DAG.getNode(
20417-
Cond == AArch64CC::ANY_ACTIVE ? AArch64ISD::PTEST_ANY : AArch64ISD::PTEST,
20418-
DL, MVT::Other, Pg, Op);
20437+
AArch64ISD::NodeType NT;
20438+
switch (Cond) {
20439+
default:
20440+
NT = AArch64ISD::PTEST;
20441+
break;
20442+
case AArch64CC::ANY_ACTIVE:
20443+
NT = AArch64ISD::PTEST_ANY;
20444+
break;
20445+
case AArch64CC::FIRST_ACTIVE:
20446+
NT = AArch64ISD::PTEST_FIRST;
20447+
break;
20448+
}
20449+
SDValue Test = DAG.getNode(NT, DL, MVT::Other, Pg, Op);
2041920450

2042020451
// Convert CC to integer based on requested condition.
2042120452
// NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare.

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ enum NodeType : unsigned {
346346
INSR,
347347
PTEST,
348348
PTEST_ANY,
349+
PTEST_FIRST,
349350
PTRUE,
350351

351352
CTTZ_ELTS,

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,6 +1181,7 @@ bool AArch64InstrInfo::analyzeCompare(const MachineInstr &MI, Register &SrcReg,
11811181
break;
11821182
case AArch64::PTEST_PP:
11831183
case AArch64::PTEST_PP_ANY:
1184+
case AArch64::PTEST_PP_FIRST:
11841185
SrcReg = MI.getOperand(0).getReg();
11851186
SrcReg2 = MI.getOperand(1).getReg();
11861187
// Not sure about the mask and value for now...
@@ -1351,12 +1352,25 @@ static bool areCFlagsAccessedBetweenInstrs(
13511352
return false;
13521353
}
13531354

1354-
std::pair<bool, unsigned>
1355+
std::tuple<bool, unsigned, MachineInstr *>
13551356
AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
13561357
MachineInstr *Pred,
13571358
const MachineRegisterInfo *MRI) const {
13581359
unsigned MaskOpcode = Mask->getOpcode();
13591360
unsigned PredOpcode = Pred->getOpcode();
1361+
1362+
// Handle a COPY from the LSB of the results of paired WHILEcc instruction.
1363+
if ((PredOpcode == TargetOpcode::COPY &&
1364+
Pred->getOperand(1).getSubReg() == AArch64::psub0) ||
1365+
// Handle unpack of the LSB of the result of a WHILEcc instruction.
1366+
PredOpcode == AArch64::PUNPKLO_PP) {
1367+
MachineInstr *MI = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
1368+
if (MI && isWhileOpcode(MI->getOpcode())) {
1369+
Pred = MI;
1370+
PredOpcode = MI->getOpcode();
1371+
}
1372+
}
1373+
13601374
bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode);
13611375
bool PredIsWhileLike = isWhileOpcode(PredOpcode);
13621376

@@ -1365,17 +1379,18 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
13651379
// instruction and the condition is "any" since WHILcc does an implicit
13661380
// PTEST(ALL, PG) check and PG is always a subset of ALL.
13671381
if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
1368-
return {true, 0};
1382+
return {true, 0, Pred};
13691383

1370-
// For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
1371-
// redundant since WHILE performs an implicit PTEST with an all active
1372-
// mask.
1384+
// For PTEST(PTRUE_ALL, WHILE), since WHILE performs an implicit PTEST
1385+
// with an all active mask, the PTEST is redundant if ether the element
1386+
// size matches or the PTEST condition is "first".
13731387
if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
1374-
getElementSizeForOpcode(MaskOpcode) ==
1375-
getElementSizeForOpcode(PredOpcode))
1376-
return {true, 0};
1388+
(PTest->getOpcode() == AArch64::PTEST_PP_FIRST ||
1389+
getElementSizeForOpcode(MaskOpcode) ==
1390+
getElementSizeForOpcode(PredOpcode)))
1391+
return {true, 0, Pred};
13771392

1378-
return {false, 0};
1393+
return {false, 0, nullptr};
13791394
}
13801395

13811396
if (PredIsPTestLike) {
@@ -1384,7 +1399,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
13841399
// "any" since PG is always a subset of the governing predicate of the
13851400
// ptest-like instruction.
13861401
if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
1387-
return {true, 0};
1402+
return {true, 0, Pred};
13881403

13891404
// For PTEST(PTRUE_ALL, PTEST_LIKE), the PTEST is redundant if the
13901405
// the element size matches and either the PTEST_LIKE instruction uses
@@ -1394,7 +1409,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
13941409
getElementSizeForOpcode(PredOpcode)) {
13951410
auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
13961411
if (Mask == PTestLikeMask || PTest->getOpcode() == AArch64::PTEST_PP_ANY)
1397-
return {true, 0};
1412+
return {true, 0, Pred};
13981413
}
13991414

14001415
// For PTEST(PG, PTEST_LIKE(PG, ...)), the PTEST is redundant since the
@@ -1423,9 +1438,9 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
14231438
uint64_t PredElementSize = getElementSizeForOpcode(PredOpcode);
14241439
if (Mask == PTestLikeMask && (PredElementSize == AArch64::ElementSizeB ||
14251440
PTest->getOpcode() == AArch64::PTEST_PP_ANY))
1426-
return {true, 0};
1441+
return {true, 0, Pred};
14271442

1428-
return {false, 0};
1443+
return {false, 0, nullptr};
14291444
}
14301445

14311446
// If OP in PTEST(PG, OP(PG, ...)) has a flag-setting variant change the
@@ -1447,7 +1462,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
14471462
// may be different and we can't remove the ptest.
14481463
auto *PredMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
14491464
if (Mask != PredMask)
1450-
return {false, 0};
1465+
return {false, 0, nullptr};
14511466
break;
14521467
}
14531468
case AArch64::BRKN_PPzP: {
@@ -1456,18 +1471,18 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
14561471
// PTEST(PTRUE_B(31), BRKN(PG, A, B)) -> BRKNS(PG, A, B).
14571472
if ((MaskOpcode != AArch64::PTRUE_B) ||
14581473
(Mask->getOperand(1).getImm() != 31))
1459-
return {false, 0};
1474+
return {false, 0, nullptr};
14601475
break;
14611476
}
14621477
case AArch64::PTRUE_B:
14631478
// PTEST(OP=PTRUE_B(A), OP) -> PTRUES_B(A)
14641479
break;
14651480
default:
14661481
// Bail out if we don't recognize the input
1467-
return {false, 0};
1482+
return {false, 0, nullptr};
14681483
}
14691484

1470-
return {true, convertToFlagSettingOpc(PredOpcode)};
1485+
return {true, convertToFlagSettingOpc(PredOpcode), Pred};
14711486
}
14721487

14731488
/// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating
@@ -1477,7 +1492,10 @@ bool AArch64InstrInfo::optimizePTestInstr(
14771492
const MachineRegisterInfo *MRI) const {
14781493
auto *Mask = MRI->getUniqueVRegDef(MaskReg);
14791494
auto *Pred = MRI->getUniqueVRegDef(PredReg);
1480-
auto [canRemove, NewOp] = canRemovePTestInstr(PTest, Mask, Pred, MRI);
1495+
bool canRemove;
1496+
unsigned NewOp;
1497+
std::tie(canRemove, NewOp, Pred) =
1498+
canRemovePTestInstr(PTest, Mask, Pred, MRI);
14811499
if (!canRemove)
14821500
return false;
14831501

@@ -1554,7 +1572,8 @@ bool AArch64InstrInfo::optimizeCompareInstr(
15541572
}
15551573

15561574
if (CmpInstr.getOpcode() == AArch64::PTEST_PP ||
1557-
CmpInstr.getOpcode() == AArch64::PTEST_PP_ANY)
1575+
CmpInstr.getOpcode() == AArch64::PTEST_PP_ANY ||
1576+
CmpInstr.getOpcode() == AArch64::PTEST_PP_FIRST)
15581577
return optimizePTestInstr(&CmpInstr, SrcReg, SrcReg2, MRI);
15591578

15601579
if (SrcReg2 != 0)

llvm/lib/Target/AArch64/AArch64InstrInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
432432
bool optimizePTestInstr(MachineInstr *PTest, unsigned MaskReg,
433433
unsigned PredReg,
434434
const MachineRegisterInfo *MRI) const;
435-
std::pair<bool, unsigned>
435+
std::tuple<bool, unsigned, MachineInstr *>
436436
canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
437437
MachineInstr *Pred, const MachineRegisterInfo *MRI) const;
438438
};

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,9 +373,10 @@ def AArch64fadda_p : PatFrags<(ops node:$op1, node:$op2, node:$op3),
373373
(AArch64fadda_p_node (SVEAllActive), node:$op2,
374374
(vselect node:$op1, node:$op3, (splat_vector (f64 fpimm_minus0))))]>;
375375

376-
def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
377-
def AArch64ptest : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;
378-
def AArch64ptest_any : SDNode<"AArch64ISD::PTEST_ANY", SDT_AArch64PTest>;
376+
def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
377+
def AArch64ptest : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;
378+
def AArch64ptest_any : SDNode<"AArch64ISD::PTEST_ANY", SDT_AArch64PTest>;
379+
def AArch64ptest_first : SDNode<"AArch64ISD::PTEST_FIRST", SDT_AArch64PTest>;
379380

380381
def SDT_AArch64DUP_PRED : SDTypeProfile<1, 3,
381382
[SDTCisVec<0>, SDTCisSameAs<0, 3>, SDTCisVec<1>, SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0, 1>]>;
@@ -917,7 +918,7 @@ let Predicates = [HasSVEorSME] in {
917918
defm BRKB_PPmP : sve_int_break_m<0b101, "brkb", int_aarch64_sve_brkb>;
918919
defm BRKBS_PPzP : sve_int_break_z<0b110, "brkbs", null_frag>;
919920

920-
defm PTEST_PP : sve_int_ptest<0b010000, "ptest", AArch64ptest, AArch64ptest_any>;
921+
defm PTEST_PP : sve_int_ptest<0b010000, "ptest", AArch64ptest, AArch64ptest_any, AArch64ptest_first>;
921922
defm PFALSE : sve_int_pfalse<0b000000, "pfalse">;
922923
defm PFIRST : sve_int_pfirst<0b00000, "pfirst", int_aarch64_sve_pfirst>;
923924
defm PNEXT : sve_int_pnext<0b00110, "pnext", int_aarch64_sve_pnext>;

llvm/lib/Target/AArch64/SVEInstrFormats.td

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -792,13 +792,16 @@ class sve_int_ptest<bits<6> opc, string asm, SDPatternOperator op>
792792
}
793793

794794
multiclass sve_int_ptest<bits<6> opc, string asm, SDPatternOperator op,
795-
SDPatternOperator op_any> {
795+
SDPatternOperator op_any, SDPatternOperator op_first> {
796796
def NAME : sve_int_ptest<opc, asm, op>;
797797

798798
let hasNoSchedulingInfo = 1, isCompare = 1, Defs = [NZCV] in {
799799
def _ANY : Pseudo<(outs), (ins PPRAny:$Pg, PPR8:$Pn),
800800
[(op_any (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>,
801801
PseudoInstExpansion<(!cast<Instruction>(NAME) PPRAny:$Pg, PPR8:$Pn)>;
802+
def _FIRST : Pseudo<(outs), (ins PPRAny:$Pg, PPR8:$Pn),
803+
[(op_first (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>,
804+
PseudoInstExpansion<(!cast<Instruction>(NAME) PPRAny:$Pg, PPR8:$Pn)>;
802805
}
803806
}
804807

@@ -9657,7 +9660,7 @@ multiclass sve2p1_int_while_rr_pn<string mnemonic, bits<3> opc> {
96579660

96589661
// SVE integer compare scalar count and limit (predicate pair)
96599662
class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
9660-
RegisterOperand ppr_ty>
9663+
RegisterOperand ppr_ty, ElementSizeEnum EltSz>
96619664
: I<(outs ppr_ty:$Pd), (ins GPR64:$Rn, GPR64:$Rm),
96629665
mnemonic, "\t$Pd, $Rn, $Rm",
96639666
"", []>, Sched<[]> {
@@ -9675,16 +9678,18 @@ class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
96759678
let Inst{3-1} = Pd;
96769679
let Inst{0} = opc{0};
96779680

9681+
let ElementSize = EltSz;
96789682
let Defs = [NZCV];
96799683
let hasSideEffects = 0;
9684+
let isWhile = 1;
96809685
}
96819686

96829687

96839688
multiclass sve2p1_int_while_rr_pair<string mnemonic, bits<3> opc> {
9684-
def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r>;
9685-
def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r>;
9686-
def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r>;
9687-
def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r>;
9689+
def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r, ElementSizeB>;
9690+
def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r, ElementSizeH>;
9691+
def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r, ElementSizeS>;
9692+
def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r, ElementSizeD>;
96889693
}
96899694

96909695

0 commit comments

Comments
 (0)