Skip to content

Commit ccaf625

Browse files
[AArch64] Optimise test of the LSB of a paired whileCC insntruction
1 parent bd6adbb commit ccaf625

File tree

7 files changed

+123
-82
lines changed

7 files changed

+123
-82
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2682,6 +2682,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
26822682
MAKE_CASE(AArch64ISD::INSR)
26832683
MAKE_CASE(AArch64ISD::PTEST)
26842684
MAKE_CASE(AArch64ISD::PTEST_ANY)
2685+
MAKE_CASE(AArch64ISD::PTEST_FIRST)
26852686
MAKE_CASE(AArch64ISD::PTRUE)
26862687
MAKE_CASE(AArch64ISD::LD1_MERGE_ZERO)
26872688
MAKE_CASE(AArch64ISD::LD1S_MERGE_ZERO)
@@ -18515,21 +18516,41 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
1851518516
AArch64CC::CondCode Cond);
1851618517

1851718518
static bool isPredicateCCSettingOp(SDValue N) {
18518-
if ((N.getOpcode() == ISD::SETCC) ||
18519-
(N.getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
18520-
(N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilege ||
18521-
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilegt ||
18522-
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehi ||
18523-
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehs ||
18524-
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilele ||
18525-
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelo ||
18526-
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels ||
18527-
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt ||
18528-
// get_active_lane_mask is lowered to a whilelo instruction.
18529-
N.getConstantOperandVal(0) == Intrinsic::get_active_lane_mask)))
18519+
if (N.getOpcode() == ISD::SETCC)
1853018520
return true;
1853118521

18532-
return false;
18522+
if (N.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
18523+
isNullConstant(N.getOperand(1)))
18524+
N = N.getOperand(0);
18525+
18526+
if (N.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
18527+
return false;
18528+
18529+
switch (N.getConstantOperandVal(0)) {
18530+
default:
18531+
return false;
18532+
case Intrinsic::aarch64_sve_whilege_x2:
18533+
case Intrinsic::aarch64_sve_whilegt_x2:
18534+
case Intrinsic::aarch64_sve_whilehi_x2:
18535+
case Intrinsic::aarch64_sve_whilehs_x2:
18536+
case Intrinsic::aarch64_sve_whilele_x2:
18537+
case Intrinsic::aarch64_sve_whilelo_x2:
18538+
case Intrinsic::aarch64_sve_whilels_x2:
18539+
case Intrinsic::aarch64_sve_whilelt_x2:
18540+
if (N.getResNo() != 0)
18541+
return false;
18542+
[[fallthrough]];
18543+
case Intrinsic::aarch64_sve_whilege:
18544+
case Intrinsic::aarch64_sve_whilegt:
18545+
case Intrinsic::aarch64_sve_whilehi:
18546+
case Intrinsic::aarch64_sve_whilehs:
18547+
case Intrinsic::aarch64_sve_whilele:
18548+
case Intrinsic::aarch64_sve_whilelo:
18549+
case Intrinsic::aarch64_sve_whilels:
18550+
case Intrinsic::aarch64_sve_whilelt:
18551+
case Intrinsic::get_active_lane_mask:
18552+
return true;
18553+
}
1853318554
}
1853418555

1853518556
// Materialize : i1 = extract_vector_elt t37, Constant:i64<0>
@@ -20483,9 +20504,19 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
2048320504
}
2048420505

2048520506
// Set condition code (CC) flags.
20486-
SDValue Test = DAG.getNode(
20487-
Cond == AArch64CC::ANY_ACTIVE ? AArch64ISD::PTEST_ANY : AArch64ISD::PTEST,
20488-
DL, MVT::Other, Pg, Op);
20507+
AArch64ISD::NodeType NT;
20508+
switch (Cond) {
20509+
default:
20510+
NT = AArch64ISD::PTEST;
20511+
break;
20512+
case AArch64CC::ANY_ACTIVE:
20513+
NT = AArch64ISD::PTEST_ANY;
20514+
break;
20515+
case AArch64CC::FIRST_ACTIVE:
20516+
NT = AArch64ISD::PTEST_FIRST;
20517+
break;
20518+
}
20519+
SDValue Test = DAG.getNode(NT, DL, MVT::Other, Pg, Op);
2048920520

2049020521
// Convert CC to integer based on requested condition.
2049120522
// NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare.

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ enum NodeType : unsigned {
346346
INSR,
347347
PTEST,
348348
PTEST_ANY,
349+
PTEST_FIRST,
349350
PTRUE,
350351

351352
CTTZ_ELTS,

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,7 @@ bool AArch64InstrInfo::analyzeCompare(const MachineInstr &MI, Register &SrcReg,
11831183
break;
11841184
case AArch64::PTEST_PP:
11851185
case AArch64::PTEST_PP_ANY:
1186+
case AArch64::PTEST_PP_FIRST:
11861187
SrcReg = MI.getOperand(0).getReg();
11871188
SrcReg2 = MI.getOperand(1).getReg();
11881189
// Not sure about the mask and value for now...
@@ -1354,12 +1355,25 @@ static bool areCFlagsAccessedBetweenInstrs(
13541355
return false;
13551356
}
13561357

1357-
std::pair<bool, unsigned>
1358+
std::tuple<bool, unsigned, MachineInstr *>
13581359
AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
13591360
MachineInstr *Pred,
13601361
const MachineRegisterInfo *MRI) const {
13611362
unsigned MaskOpcode = Mask->getOpcode();
13621363
unsigned PredOpcode = Pred->getOpcode();
1364+
1365+
// Handle a COPY from the LSB of the results of paired WHILEcc instruction.
1366+
if ((PredOpcode == TargetOpcode::COPY &&
1367+
Pred->getOperand(1).getSubReg() == AArch64::psub0) ||
1368+
// Handle unpack of the LSB of the result of a WHILEcc instruction.
1369+
PredOpcode == AArch64::PUNPKLO_PP) {
1370+
MachineInstr *MI = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
1371+
if (MI && isWhileOpcode(MI->getOpcode())) {
1372+
Pred = MI;
1373+
PredOpcode = MI->getOpcode();
1374+
}
1375+
}
1376+
13631377
bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode);
13641378
bool PredIsWhileLike = isWhileOpcode(PredOpcode);
13651379

@@ -1368,17 +1382,18 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
13681382
// instruction and the condition is "any" since WHILcc does an implicit
13691383
// PTEST(ALL, PG) check and PG is always a subset of ALL.
13701384
if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
1371-
return {true, 0};
1385+
return {true, 0, Pred};
13721386

1373-
// For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
1374-
// redundant since WHILE performs an implicit PTEST with an all active
1375-
// mask.
1387+
// For PTEST(PTRUE_ALL, WHILE), since WHILE performs an implicit PTEST
1388+
// with an all active mask, the PTEST is redundant if ether the element
1389+
// size matches or the PTEST condition is "first".
13761390
if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
1377-
getElementSizeForOpcode(MaskOpcode) ==
1378-
getElementSizeForOpcode(PredOpcode))
1379-
return {true, 0};
1391+
(PTest->getOpcode() == AArch64::PTEST_PP_FIRST ||
1392+
getElementSizeForOpcode(MaskOpcode) ==
1393+
getElementSizeForOpcode(PredOpcode)))
1394+
return {true, 0, Pred};
13801395

1381-
return {false, 0};
1396+
return {false, 0, nullptr};
13821397
}
13831398

13841399
if (PredIsPTestLike) {
@@ -1387,7 +1402,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
13871402
// "any" since PG is always a subset of the governing predicate of the
13881403
// ptest-like instruction.
13891404
if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
1390-
return {true, 0};
1405+
return {true, 0, Pred};
13911406

13921407
// For PTEST(PTRUE_ALL, PTEST_LIKE), the PTEST is redundant if the
13931408
// the element size matches and either the PTEST_LIKE instruction uses
@@ -1397,7 +1412,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
13971412
getElementSizeForOpcode(PredOpcode)) {
13981413
auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
13991414
if (Mask == PTestLikeMask || PTest->getOpcode() == AArch64::PTEST_PP_ANY)
1400-
return {true, 0};
1415+
return {true, 0, Pred};
14011416
}
14021417

14031418
// For PTEST(PG, PTEST_LIKE(PG, ...)), the PTEST is redundant since the
@@ -1426,9 +1441,9 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
14261441
uint64_t PredElementSize = getElementSizeForOpcode(PredOpcode);
14271442
if (Mask == PTestLikeMask && (PredElementSize == AArch64::ElementSizeB ||
14281443
PTest->getOpcode() == AArch64::PTEST_PP_ANY))
1429-
return {true, 0};
1444+
return {true, 0, Pred};
14301445

1431-
return {false, 0};
1446+
return {false, 0, nullptr};
14321447
}
14331448

14341449
// If OP in PTEST(PG, OP(PG, ...)) has a flag-setting variant change the
@@ -1450,7 +1465,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
14501465
// may be different and we can't remove the ptest.
14511466
auto *PredMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
14521467
if (Mask != PredMask)
1453-
return {false, 0};
1468+
return {false, 0, nullptr};
14541469
break;
14551470
}
14561471
case AArch64::BRKN_PPzP: {
@@ -1459,18 +1474,18 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
14591474
// PTEST(PTRUE_B(31), BRKN(PG, A, B)) -> BRKNS(PG, A, B).
14601475
if ((MaskOpcode != AArch64::PTRUE_B) ||
14611476
(Mask->getOperand(1).getImm() != 31))
1462-
return {false, 0};
1477+
return {false, 0, nullptr};
14631478
break;
14641479
}
14651480
case AArch64::PTRUE_B:
14661481
// PTEST(OP=PTRUE_B(A), OP) -> PTRUES_B(A)
14671482
break;
14681483
default:
14691484
// Bail out if we don't recognize the input
1470-
return {false, 0};
1485+
return {false, 0, nullptr};
14711486
}
14721487

1473-
return {true, convertToFlagSettingOpc(PredOpcode)};
1488+
return {true, convertToFlagSettingOpc(PredOpcode), Pred};
14741489
}
14751490

14761491
/// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating
@@ -1480,7 +1495,10 @@ bool AArch64InstrInfo::optimizePTestInstr(
14801495
const MachineRegisterInfo *MRI) const {
14811496
auto *Mask = MRI->getUniqueVRegDef(MaskReg);
14821497
auto *Pred = MRI->getUniqueVRegDef(PredReg);
1483-
auto [canRemove, NewOp] = canRemovePTestInstr(PTest, Mask, Pred, MRI);
1498+
bool canRemove;
1499+
unsigned NewOp;
1500+
std::tie(canRemove, NewOp, Pred) =
1501+
canRemovePTestInstr(PTest, Mask, Pred, MRI);
14841502
if (!canRemove)
14851503
return false;
14861504

@@ -1558,7 +1576,8 @@ bool AArch64InstrInfo::optimizeCompareInstr(
15581576
}
15591577

15601578
if (CmpInstr.getOpcode() == AArch64::PTEST_PP ||
1561-
CmpInstr.getOpcode() == AArch64::PTEST_PP_ANY)
1579+
CmpInstr.getOpcode() == AArch64::PTEST_PP_ANY ||
1580+
CmpInstr.getOpcode() == AArch64::PTEST_PP_FIRST)
15621581
return optimizePTestInstr(&CmpInstr, SrcReg, SrcReg2, MRI);
15631582

15641583
if (SrcReg2 != 0)

llvm/lib/Target/AArch64/AArch64InstrInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
572572
bool optimizePTestInstr(MachineInstr *PTest, unsigned MaskReg,
573573
unsigned PredReg,
574574
const MachineRegisterInfo *MRI) const;
575-
std::pair<bool, unsigned>
575+
std::tuple<bool, unsigned, MachineInstr *>
576576
canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
577577
MachineInstr *Pred, const MachineRegisterInfo *MRI) const;
578578
};

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,9 +373,10 @@ def AArch64fadda_p : PatFrags<(ops node:$op1, node:$op2, node:$op3),
373373
(AArch64fadda_p_node (SVEAllActive), node:$op2,
374374
(vselect node:$op1, node:$op3, (splat_vector (f64 fpimm_minus0))))]>;
375375

376-
def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
377-
def AArch64ptest : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;
378-
def AArch64ptest_any : SDNode<"AArch64ISD::PTEST_ANY", SDT_AArch64PTest>;
376+
def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
377+
def AArch64ptest : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;
378+
def AArch64ptest_any : SDNode<"AArch64ISD::PTEST_ANY", SDT_AArch64PTest>;
379+
def AArch64ptest_first : SDNode<"AArch64ISD::PTEST_FIRST", SDT_AArch64PTest>;
379380

380381
def SDT_AArch64DUP_PRED : SDTypeProfile<1, 3,
381382
[SDTCisVec<0>, SDTCisSameAs<0, 3>, SDTCisVec<1>, SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0, 1>]>;
@@ -948,7 +949,7 @@ let Predicates = [HasSVEorSME] in {
948949
defm BRKB_PPmP : sve_int_break_m<0b101, "brkb", int_aarch64_sve_brkb>;
949950
defm BRKBS_PPzP : sve_int_break_z<0b110, "brkbs", null_frag>;
950951

951-
defm PTEST_PP : sve_int_ptest<0b010000, "ptest", AArch64ptest, AArch64ptest_any>;
952+
defm PTEST_PP : sve_int_ptest<0b010000, "ptest", AArch64ptest, AArch64ptest_any, AArch64ptest_first>;
952953
defm PFALSE : sve_int_pfalse<0b000000, "pfalse">;
953954
defm PFIRST : sve_int_pfirst<0b00000, "pfirst", int_aarch64_sve_pfirst>;
954955
defm PNEXT : sve_int_pnext<0b00110, "pnext", int_aarch64_sve_pnext>;

llvm/lib/Target/AArch64/SVEInstrFormats.td

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -784,13 +784,16 @@ class sve_int_ptest<bits<6> opc, string asm, SDPatternOperator op>
784784
}
785785

786786
multiclass sve_int_ptest<bits<6> opc, string asm, SDPatternOperator op,
787-
SDPatternOperator op_any> {
787+
SDPatternOperator op_any, SDPatternOperator op_first> {
788788
def NAME : sve_int_ptest<opc, asm, op>;
789789

790790
let hasNoSchedulingInfo = 1, isCompare = 1, Defs = [NZCV] in {
791791
def _ANY : Pseudo<(outs), (ins PPRAny:$Pg, PPR8:$Pn),
792792
[(op_any (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>,
793793
PseudoInstExpansion<(!cast<Instruction>(NAME) PPRAny:$Pg, PPR8:$Pn)>;
794+
def _FIRST : Pseudo<(outs), (ins PPRAny:$Pg, PPR8:$Pn),
795+
[(op_first (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>,
796+
PseudoInstExpansion<(!cast<Instruction>(NAME) PPRAny:$Pg, PPR8:$Pn)>;
794797
}
795798
}
796799

@@ -9669,7 +9672,7 @@ multiclass sve2p1_int_while_rr_pn<string mnemonic, bits<3> opc> {
96699672

96709673
// SVE integer compare scalar count and limit (predicate pair)
96719674
class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
9672-
RegisterOperand ppr_ty>
9675+
RegisterOperand ppr_ty, ElementSizeEnum EltSz>
96739676
: I<(outs ppr_ty:$Pd), (ins GPR64:$Rn, GPR64:$Rm),
96749677
mnemonic, "\t$Pd, $Rn, $Rm",
96759678
"", []>, Sched<[]> {
@@ -9687,16 +9690,18 @@ class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
96879690
let Inst{3-1} = Pd;
96889691
let Inst{0} = opc{0};
96899692

9693+
let ElementSize = EltSz;
96909694
let Defs = [NZCV];
96919695
let hasSideEffects = 0;
9696+
let isWhile = 1;
96929697
}
96939698

96949699

96959700
multiclass sve2p1_int_while_rr_pair<string mnemonic, bits<3> opc> {
9696-
def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r>;
9697-
def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r>;
9698-
def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r>;
9699-
def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r>;
9701+
def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r, ElementSizeB>;
9702+
def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r, ElementSizeH>;
9703+
def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r, ElementSizeS>;
9704+
def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r, ElementSizeD>;
97009705
}
97019706

97029707

0 commit comments

Comments
 (0)