diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 9f6f66e9e0c70..e8cf16e28b437 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -2727,6 +2727,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::INSR) MAKE_CASE(AArch64ISD::PTEST) MAKE_CASE(AArch64ISD::PTEST_ANY) + MAKE_CASE(AArch64ISD::PTEST_FIRST) MAKE_CASE(AArch64ISD::PTRUE) MAKE_CASE(AArch64ISD::LD1_MERGE_ZERO) MAKE_CASE(AArch64ISD::LD1S_MERGE_ZERO) @@ -18733,21 +18734,41 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op, AArch64CC::CondCode Cond); static bool isPredicateCCSettingOp(SDValue N) { - if ((N.getOpcode() == ISD::SETCC) || - (N.getOpcode() == ISD::INTRINSIC_WO_CHAIN && - (N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilege || - N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilegt || - N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehi || - N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehs || - N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilele || - N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelo || - N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels || - N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt || - // get_active_lane_mask is lowered to a whilelo instruction. 
- N.getConstantOperandVal(0) == Intrinsic::get_active_lane_mask))) + if (N.getOpcode() == ISD::SETCC) return true; - return false; + if (N.getOpcode() == ISD::EXTRACT_SUBVECTOR && + isNullConstant(N.getOperand(1))) + N = N.getOperand(0); + + if (N.getOpcode() != ISD::INTRINSIC_WO_CHAIN) + return false; + + switch (N.getConstantOperandVal(0)) { + default: + return false; + case Intrinsic::aarch64_sve_whilege_x2: + case Intrinsic::aarch64_sve_whilegt_x2: + case Intrinsic::aarch64_sve_whilehi_x2: + case Intrinsic::aarch64_sve_whilehs_x2: + case Intrinsic::aarch64_sve_whilele_x2: + case Intrinsic::aarch64_sve_whilelo_x2: + case Intrinsic::aarch64_sve_whilels_x2: + case Intrinsic::aarch64_sve_whilelt_x2: + if (N.getResNo() != 0) + return false; + [[fallthrough]]; + case Intrinsic::aarch64_sve_whilege: + case Intrinsic::aarch64_sve_whilegt: + case Intrinsic::aarch64_sve_whilehi: + case Intrinsic::aarch64_sve_whilehs: + case Intrinsic::aarch64_sve_whilele: + case Intrinsic::aarch64_sve_whilelo: + case Intrinsic::aarch64_sve_whilels: + case Intrinsic::aarch64_sve_whilelt: + case Intrinsic::get_active_lane_mask: + return true; + } } // Materialize : i1 = extract_vector_elt t37, Constant:i64<0> @@ -20666,9 +20687,19 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op, } // Set condition code (CC) flags. - SDValue Test = DAG.getNode( - Cond == AArch64CC::ANY_ACTIVE ? AArch64ISD::PTEST_ANY : AArch64ISD::PTEST, - DL, MVT::Other, Pg, Op); + AArch64ISD::NodeType NT; + switch (Cond) { + default: + NT = AArch64ISD::PTEST; + break; + case AArch64CC::ANY_ACTIVE: + NT = AArch64ISD::PTEST_ANY; + break; + case AArch64CC::FIRST_ACTIVE: + NT = AArch64ISD::PTEST_FIRST; + break; + } + SDValue Test = DAG.getNode(NT, DL, MVT::Other, Pg, Op); // Convert CC to integer based on requested condition. // NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare. 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 986f1b67ee513..cb1774e193aad 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -359,6 +359,7 @@ enum NodeType : unsigned { INSR, PTEST, PTEST_ANY, + PTEST_FIRST, PTRUE, CTTZ_ELTS, diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp index 949e7699d070d..bad1f63c83da4 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -1184,6 +1184,7 @@ bool AArch64InstrInfo::analyzeCompare(const MachineInstr &MI, Register &SrcReg, break; case AArch64::PTEST_PP: case AArch64::PTEST_PP_ANY: + case AArch64::PTEST_PP_FIRST: SrcReg = MI.getOperand(0).getReg(); SrcReg2 = MI.getOperand(1).getReg(); // Not sure about the mask and value for now... @@ -1355,12 +1356,25 @@ static bool areCFlagsAccessedBetweenInstrs( return false; } -std::optional<unsigned> +std::optional<std::pair<unsigned, MachineInstr *>> AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask, MachineInstr *Pred, const MachineRegisterInfo *MRI) const { unsigned MaskOpcode = Mask->getOpcode(); unsigned PredOpcode = Pred->getOpcode(); + + // Handle a COPY from the LSB of the results of paired WHILEcc instruction. + if ((PredOpcode == TargetOpcode::COPY && + Pred->getOperand(1).getSubReg() == AArch64::psub0) || + // Handle unpack of the LSB of the result of a WHILEcc instruction. 
+ PredOpcode == AArch64::PUNPKLO_PP) { + MachineInstr *MI = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg()); + if (MI && isWhileOpcode(MI->getOpcode())) { + Pred = MI; + PredOpcode = MI->getOpcode(); + } + } + bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode); bool PredIsWhileLike = isWhileOpcode(PredOpcode); @@ -1369,15 +1383,16 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask, // instruction and the condition is "any" since WHILcc does an implicit // PTEST(ALL, PG) check and PG is always a subset of ALL. if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY) - return PredOpcode; + return std::make_pair(PredOpcode, Pred); - // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is - // redundant since WHILE performs an implicit PTEST with an all active - // mask. + // For PTEST(PTRUE_ALL, WHILE), since WHILE performs an implicit PTEST + // with an all active mask, the PTEST is redundant if either the element + // size matches or the PTEST condition is "first". if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 && - getElementSizeForOpcode(MaskOpcode) == - getElementSizeForOpcode(PredOpcode)) - return PredOpcode; + (PTest->getOpcode() == AArch64::PTEST_PP_FIRST || + getElementSizeForOpcode(MaskOpcode) == + getElementSizeForOpcode(PredOpcode))) + return std::make_pair(PredOpcode, Pred); return {}; } @@ -1388,7 +1403,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask, // "any" since PG is always a subset of the governing predicate of the // ptest-like instruction. 
if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY) - return PredOpcode; + return std::make_pair(PredOpcode, Pred); // For PTEST(PTRUE_ALL, PTEST_LIKE), the PTEST is redundant if the // the element size matches and either the PTEST_LIKE instruction uses @@ -1398,7 +1413,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask, getElementSizeForOpcode(PredOpcode)) { auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg()); if (Mask == PTestLikeMask || PTest->getOpcode() == AArch64::PTEST_PP_ANY) - return PredOpcode; + return std::make_pair(PredOpcode, Pred); } // For PTEST(PG, PTEST_LIKE(PG, ...)), the PTEST is redundant since the @@ -1427,7 +1442,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask, uint64_t PredElementSize = getElementSizeForOpcode(PredOpcode); if (Mask == PTestLikeMask && (PredElementSize == AArch64::ElementSizeB || PTest->getOpcode() == AArch64::PTEST_PP_ANY)) - return PredOpcode; + return std::make_pair(PredOpcode, Pred); return {}; } @@ -1471,7 +1486,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask, return {}; } - return convertToFlagSettingOpc(PredOpcode); + return std::make_pair(convertToFlagSettingOpc(PredOpcode), Pred); } /// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating @@ -1481,10 +1496,12 @@ bool AArch64InstrInfo::optimizePTestInstr( const MachineRegisterInfo *MRI) const { auto *Mask = MRI->getUniqueVRegDef(MaskReg); auto *Pred = MRI->getUniqueVRegDef(PredReg); + unsigned NewOp; unsigned PredOpcode = Pred->getOpcode(); - auto NewOp = canRemovePTestInstr(PTest, Mask, Pred, MRI); - if (!NewOp) + auto canRemove = canRemovePTestInstr(PTest, Mask, Pred, MRI); + if (!canRemove) return false; + std::tie(NewOp, Pred) = *canRemove; const TargetRegisterInfo *TRI = &getRegisterInfo(); @@ -1498,8 +1515,8 @@ bool AArch64InstrInfo::optimizePTestInstr( // operand to be replaced with an equivalent 
instruction that also sets the // flags. PTest->eraseFromParent(); - if (*NewOp != PredOpcode) { - Pred->setDesc(get(*NewOp)); + if (NewOp != PredOpcode) { + Pred->setDesc(get(NewOp)); bool succeeded = UpdateOperandRegClass(*Pred); (void)succeeded; assert(succeeded && "Operands have incompatible register classes!"); @@ -1560,7 +1577,8 @@ bool AArch64InstrInfo::optimizeCompareInstr( } if (CmpInstr.getOpcode() == AArch64::PTEST_PP || - CmpInstr.getOpcode() == AArch64::PTEST_PP_ANY) + CmpInstr.getOpcode() == AArch64::PTEST_PP_ANY || + CmpInstr.getOpcode() == AArch64::PTEST_PP_FIRST) return optimizePTestInstr(&CmpInstr, SrcReg, SrcReg2, MRI); if (SrcReg2 != 0) diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h index 792e0c3063b10..d722f433a150b 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h @@ -572,7 +572,8 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo { bool optimizePTestInstr(MachineInstr *PTest, unsigned MaskReg, unsigned PredReg, const MachineRegisterInfo *MRI) const; - std::optional<unsigned> + + std::optional<std::pair<unsigned, MachineInstr *>> canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask, MachineInstr *Pred, const MachineRegisterInfo *MRI) const; }; diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index bd5de628d8529..3cee3e92fae08 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -373,9 +373,10 @@ def AArch64fadda_p : PatFrags<(ops node:$op1, node:$op2, node:$op3), (AArch64fadda_p_node (SVEAllActive), node:$op2, (vselect node:$op1, node:$op3, (splat_vector (f64 fpimm_minus0))))]>; -def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>; -def AArch64ptest : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>; -def AArch64ptest_any : SDNode<"AArch64ISD::PTEST_ANY", SDT_AArch64PTest>; +def SDT_AArch64PTest : SDTypeProfile<0, 2, 
[SDTCisVec<0>, SDTCisSameAs<0,1>]>; +def AArch64ptest : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>; +def AArch64ptest_any : SDNode<"AArch64ISD::PTEST_ANY", SDT_AArch64PTest>; +def AArch64ptest_first : SDNode<"AArch64ISD::PTEST_FIRST", SDT_AArch64PTest>; def SDT_AArch64DUP_PRED : SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisSameAs<0, 3>, SDTCisVec<1>, SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0, 1>]>; @@ -948,7 +949,7 @@ let Predicates = [HasSVEorSME] in { defm BRKB_PPmP : sve_int_break_m<0b101, "brkb", int_aarch64_sve_brkb>; defm BRKBS_PPzP : sve_int_break_z<0b110, "brkbs", null_frag>; - defm PTEST_PP : sve_int_ptest<0b010000, "ptest", AArch64ptest, AArch64ptest_any>; + defm PTEST_PP : sve_int_ptest<0b010000, "ptest", AArch64ptest, AArch64ptest_any, AArch64ptest_first>; defm PFALSE : sve_int_pfalse<0b000000, "pfalse">; defm PFIRST : sve_int_pfirst<0b00000, "pfirst", int_aarch64_sve_pfirst>; defm PNEXT : sve_int_pnext<0b00110, "pnext", int_aarch64_sve_pnext>; diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td index fc7d3cdda4acd..1c3528bed08c4 100644 --- a/llvm/lib/Target/AArch64/SVEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td @@ -784,13 +784,16 @@ class sve_int_ptest opc, string asm, SDPatternOperator op> } multiclass sve_int_ptest opc, string asm, SDPatternOperator op, - SDPatternOperator op_any> { + SDPatternOperator op_any, SDPatternOperator op_first> { def NAME : sve_int_ptest; let hasNoSchedulingInfo = 1, isCompare = 1, Defs = [NZCV] in { def _ANY : Pseudo<(outs), (ins PPRAny:$Pg, PPR8:$Pn), [(op_any (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>, PseudoInstExpansion<(!cast(NAME) PPRAny:$Pg, PPR8:$Pn)>; + def _FIRST : Pseudo<(outs), (ins PPRAny:$Pg, PPR8:$Pn), + [(op_first (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>, + PseudoInstExpansion<(!cast(NAME) PPRAny:$Pg, PPR8:$Pn)>; } } @@ -9669,7 +9672,7 @@ multiclass sve2p1_int_while_rr_pn opc> { // SVE integer compare scalar count and limit 
(predicate pair) class sve2p1_int_while_rr_pair sz, bits<3> opc, - RegisterOperand ppr_ty> + RegisterOperand ppr_ty, ElementSizeEnum EltSz> : I<(outs ppr_ty:$Pd), (ins GPR64:$Rn, GPR64:$Rm), mnemonic, "\t$Pd, $Rn, $Rm", "", []>, Sched<[]> { @@ -9687,16 +9690,18 @@ class sve2p1_int_while_rr_pair sz, bits<3> opc, let Inst{3-1} = Pd; let Inst{0} = opc{0}; + let ElementSize = EltSz; let Defs = [NZCV]; let hasSideEffects = 0; + let isWhile = 1; } multiclass sve2p1_int_while_rr_pair opc> { - def _B : sve2p1_int_while_rr_pair; - def _H : sve2p1_int_while_rr_pair; - def _S : sve2p1_int_while_rr_pair; - def _D : sve2p1_int_while_rr_pair; + def _B : sve2p1_int_while_rr_pair; + def _H : sve2p1_int_while_rr_pair; + def _S : sve2p1_int_while_rr_pair; + def _D : sve2p1_int_while_rr_pair; } diff --git a/llvm/test/CodeGen/AArch64/opt-while-test.ll b/llvm/test/CodeGen/AArch64/opt-while-test.ll new file mode 100644 index 0000000000000..a022f4d8c9e23 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/opt-while-test.ll @@ -0,0 +1,97 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc < %s | FileCheck %s +; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s --check-prefix=CHECK-SVE2p1 +target triple = "aarch64-linux" + +define void @f_while(i32 %i, i32 %n) #0 { +; CHECK-LABEL: f_while: +; CHECK: // %bb.0: // %E +; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill +; CHECK-NEXT: whilelo p0.b, w0, w1 +; CHECK-NEXT: b.pl .LBB0_2 +; CHECK-NEXT: // %bb.1: // %A +; CHECK-NEXT: bl g0 +; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload +; CHECK-NEXT: ret +; CHECK-NEXT: .LBB0_2: // %B +; CHECK-NEXT: bl g1 +; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload +; CHECK-NEXT: ret +; +; CHECK-SVE2p1-LABEL: f_while: +; CHECK-SVE2p1: // %bb.0: // %E +; CHECK-SVE2p1-NEXT: str x30, [sp, #-16]! 
// 8-byte Folded Spill +; CHECK-SVE2p1-NEXT: whilelo p0.b, w0, w1 +; CHECK-SVE2p1-NEXT: b.pl .LBB0_2 +; CHECK-SVE2p1-NEXT: // %bb.1: // %A +; CHECK-SVE2p1-NEXT: bl g0 +; CHECK-SVE2p1-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload +; CHECK-SVE2p1-NEXT: ret +; CHECK-SVE2p1-NEXT: .LBB0_2: // %B +; CHECK-SVE2p1-NEXT: bl g1 +; CHECK-SVE2p1-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload +; CHECK-SVE2p1-NEXT: ret +E: + %wide.mask = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i32 %i, i32 %n) + %mask = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %wide.mask, i64 0) + %elt = extractelement <vscale x 8 x i1> %mask, i64 0 + br i1 %elt, label %A, label %B +A: + call void @g0() + ret void +B: + call void @g1() + ret void +} + +define void @f_while_x2(i32 %i, i32 %n) #0 { +; CHECK-LABEL: f_while_x2: +; CHECK: // %bb.0: // %E +; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill +; CHECK-NEXT: whilelo p0.b, w0, w1 +; CHECK-NEXT: punpkhi p0.h, p0.b +; CHECK-NEXT: b.pl .LBB1_2 +; CHECK-NEXT: // %bb.1: // %A +; CHECK-NEXT: bl g0 +; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload +; CHECK-NEXT: ret +; CHECK-NEXT: .LBB1_2: // %B +; CHECK-NEXT: bl g1 +; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload +; CHECK-NEXT: ret +; +; CHECK-SVE2p1-LABEL: f_while_x2: +; CHECK-SVE2p1: // %bb.0: // %E +; CHECK-SVE2p1-NEXT: str x30, [sp, #-16]! 
// 8-byte Folded Spill +; CHECK-SVE2p1-NEXT: mov w8, w1 +; CHECK-SVE2p1-NEXT: mov w9, w0 +; CHECK-SVE2p1-NEXT: whilelo { p0.h, p1.h }, x9, x8 +; CHECK-SVE2p1-NEXT: b.pl .LBB1_2 +; CHECK-SVE2p1-NEXT: // %bb.1: // %A +; CHECK-SVE2p1-NEXT: mov p0.b, p1.b +; CHECK-SVE2p1-NEXT: bl g0 +; CHECK-SVE2p1-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload +; CHECK-SVE2p1-NEXT: ret +; CHECK-SVE2p1-NEXT: .LBB1_2: // %B +; CHECK-SVE2p1-NEXT: mov p0.b, p1.b +; CHECK-SVE2p1-NEXT: bl g1 +; CHECK-SVE2p1-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload +; CHECK-SVE2p1-NEXT: ret +E: + %wide.mask = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i32 %i, i32 %n) + %mask.hi = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %wide.mask, i64 8) + %mask = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %wide.mask, i64 0) + %elt = extractelement <vscale x 8 x i1> %mask, i64 0 + br i1 %elt, label %A, label %B +A: + call void @g0(<vscale x 8 x i1> %mask.hi) + ret void +B: + call void @g1(<vscale x 8 x i1> %mask.hi) + ret void +} + +declare void @g0(...) +declare void @g1(...) + +attributes #0 = { nounwind vscale_range(1,16) "target-cpu"="neoverse-v1" }