Skip to content

[AArch64] Refactor redundant PTEST optimisations (NFC) #87802

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 18, 2024

Conversation

momchil-velikov
Copy link
Collaborator

This patch refactors AArch64InstrInfo::optimizePTestInstr to simplify the convoluted conditions and control flow
and make it easier to add the optimisation in #81141

@llvmbot
Copy link
Member

llvmbot commented Apr 5, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Momchil Velikov (momchil-velikov)

Changes

This patch refactors AArch64InstrInfo::optimizePTestInstr to simplify the convoluted conditions and control flow
and make it easier to add the optimisation in #81141


Patch is 474.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87802.diff

28 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+10)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+2)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+2)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+4)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+59-2)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+92-78)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.h (+3)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+9)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+2)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h (+8)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+6-1)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.cpp (+4-3)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+47-3)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h (+29-5)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+86-21)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+8-10)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+1)
  • (modified) llvm/test/CodeGen/AArch64/active_lane_mask.ll (+1)
  • (added) llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll (+395)
  • (added) llvm/test/CodeGen/AArch64/sve-wide-lane-mask.ll (+1069)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll (+40-55)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/pr73894.ll (+2-2)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll (+842-830)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/sve-tail-folding-unroll.ll (+165-157)
  • (added) llvm/test/Transforms/LoopVectorize/AArch64/sve-wide-lane-mask.ll (+656)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/uniform-args-call-variants.ll (+48-54)
  • (modified) llvm/test/Transforms/LoopVectorize/ARM/tail-folding-prefer-flag.ll (+3-3)
  • (modified) llvm/test/Transforms/LoopVectorize/strict-fadd-interleave-only.ll (+2-2)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index fa9392b86c15b9..eae993e60008e4 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1241,6 +1241,8 @@ class TargetTransformInfo {
   /// and the number of execution units in the CPU.
   unsigned getMaxInterleaveFactor(ElementCount VF) const;
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const;
+
   /// Collect properties of V used in cost analysis, e.g. OP_PowerOf2.
   static OperandValueInfo getOperandInfo(const Value *V);
 
@@ -1999,6 +2001,9 @@ class TargetTransformInfo::Concept {
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
 
   virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
+
+  virtual ElementCount getMaxPredicateLength(ElementCount VF) const = 0;
+
   virtual InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       OperandValueInfo Opd1Info, OperandValueInfo Opd2Info,
@@ -2622,6 +2627,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   unsigned getMaxInterleaveFactor(ElementCount VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
+
+  ElementCount getMaxPredicateLength(ElementCount VF) const override {
+    return Impl.getMaxPredicateLength(VF);
+  }
+
   unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
                                             unsigned &JTSize,
                                             ProfileSummaryInfo *PSI,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 63c2ef8912b29c..1c4f2f963e89f3 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -531,6 +531,8 @@ class TargetTransformInfoImplBase {
 
   unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const { return VF; }
+
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueInfo Opd1Info, TTI::OperandValueInfo Opd2Info,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 42d8f74fd427fb..297068943a63c6 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -890,6 +890,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF) { return 1; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const { return VF; }
+
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueInfo Opd1Info = {TTI::OK_AnyValue, TTI::OP_None},
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 5f933b4587843c..39bf57737890c8 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -816,6 +816,10 @@ unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
   return TTIImpl->getMaxInterleaveFactor(VF);
 }
 
+ElementCount TargetTransformInfo::getMaxPredicateLength(ElementCount VF) const {
+  return TTIImpl->getMaxPredicateLength(VF);
+}
+
 TargetTransformInfo::OperandValueInfo
 TargetTransformInfo::getOperandInfo(const Value *V) {
   OperandValueKind OpInfo = OK_AnyValue;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8218960406ec13..d0afa0a5c65330 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1873,8 +1873,8 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
 
 bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
                                                           EVT OpVT) const {
-  // Only SVE has a 1:1 mapping from intrinsic -> instruction (whilelo).
-  if (!Subtarget->hasSVE())
+  // Only SVE/SME has a 1:1 mapping from intrinsic -> instruction (whilelo).
+  if (!Subtarget->hasSVEorSME())
     return true;
 
   // We can only support legal predicate result types. We can use the SVE
@@ -20507,6 +20507,61 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
   return SDValue();
 }
 
+static SDValue tryCombineWhileLo(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 const AArch64Subtarget *Subtarget) {
+  if (DCI.isBeforeLegalize())
+    return SDValue();
+
+  if (!Subtarget->hasSVE2p1() && !Subtarget->hasSME2())
+    return SDValue();
+
+  if (!N->hasNUsesOfValue(2, 0))
+    return SDValue();
+
+  const uint64_t HalfSize = N->getValueType(0).getVectorMinNumElements() / 2;
+  if (HalfSize < 2)
+    return SDValue();
+
+  auto It = N->use_begin();
+  SDNode *Lo = *It++;
+  SDNode *Hi = *It;
+
+  uint64_t OffLo, OffHi;
+  if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Lo->getOperand(1).getNode(), OffLo) ||
+      Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Hi->getOperand(1).getNode(), OffHi))
+    return SDValue();
+
+  if (OffLo > OffHi) {
+    std::swap(Lo, Hi);
+    std::swap(OffLo, OffHi);
+  }
+
+  if (OffLo != 0 || OffHi != HalfSize)
+    return SDValue();
+
+  SelectionDAG &DAG = DCI.DAG;
+  SDLoc DL(N);
+  SDValue ID =
+      DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+  SDValue Idx = N->getOperand(1);
+  SDValue TC = N->getOperand(2);
+  if (Idx.getValueType() != MVT::i64) {
+    Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
+    TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+  }
+  auto R =
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
+                  {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
+
+  DCI.CombineTo(Lo, R.getValue(0));
+  DCI.CombineTo(Hi, R.getValue(1));
+
+  return SDValue(N, 0);
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -20837,6 +20892,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::aarch64_sve_ptest_last:
     return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
                     AArch64CC::LAST_ACTIVE);
+  case Intrinsic::aarch64_sve_whilelo:
+    return tryCombineWhileLo(N, DCI, Subtarget);
   }
   return SDValue();
 }
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 22687b0e31c284..1b422969379d25 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1351,48 +1351,52 @@ static bool areCFlagsAccessedBetweenInstrs(
   return false;
 }
 
-/// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating
-/// operation which could set the flags in an identical manner
-bool AArch64InstrInfo::optimizePTestInstr(
-    MachineInstr *PTest, unsigned MaskReg, unsigned PredReg,
-    const MachineRegisterInfo *MRI) const {
-  auto *Mask = MRI->getUniqueVRegDef(MaskReg);
-  auto *Pred = MRI->getUniqueVRegDef(PredReg);
-  auto NewOp = Pred->getOpcode();
-  bool OpChanged = false;
-
+std::pair<bool, unsigned>
+AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
+                                      MachineInstr *Pred,
+                                      const MachineRegisterInfo *MRI) const {
   unsigned MaskOpcode = Mask->getOpcode();
   unsigned PredOpcode = Pred->getOpcode();
   bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode);
   bool PredIsWhileLike = isWhileOpcode(PredOpcode);
 
-  if (isPTrueOpcode(MaskOpcode) && (PredIsPTestLike || PredIsWhileLike) &&
-      getElementSizeForOpcode(MaskOpcode) ==
-          getElementSizeForOpcode(PredOpcode) &&
-      Mask->getOperand(1).getImm() == 31) {
+  if (PredIsWhileLike) {
+    // For PTEST(PG, PG), PTEST is redundant when PG is the result of a WHILEcc
+    // instruction and the condition is "any" since WHILcc does an implicit
+    // PTEST(ALL, PG) check and PG is always a subset of ALL.
+    if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
+      return {true, 0};
+
     // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
     // redundant since WHILE performs an implicit PTEST with an all active
-    // mask. Must be an all active predicate of matching element size.
+    // mask.
+    if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
+        getElementSizeForOpcode(MaskOpcode) ==
+            getElementSizeForOpcode(PredOpcode))
+      return {true, 0};
+
+    return {false, 0};
+  }
+
+  if (PredIsPTestLike) {
+    // For PTEST(PG, PG), PTEST is redundant when PG is the result of an
+    // instruction that sets the flags as PTEST would and the condition is
+    // "any" since PG is always a subset of the governing predicate of the
+    // ptest-like instruction.
+    if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
+      return {true, 0};
 
     // For PTEST(PTRUE_ALL, PTEST_LIKE), the PTEST is redundant if the
-    // PTEST_LIKE instruction uses the same all active mask and the element
-    // size matches. If the PTEST has a condition of any then it is always
-    // redundant.
-    if (PredIsPTestLike) {
+    // the element size matches and either the PTEST_LIKE instruction uses
+    // the same all active mask or the condition is "any".
+    if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
+        getElementSizeForOpcode(MaskOpcode) ==
+            getElementSizeForOpcode(PredOpcode)) {
       auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
-      if (Mask != PTestLikeMask && PTest->getOpcode() != AArch64::PTEST_PP_ANY)
-        return false;
+      if (Mask == PTestLikeMask || PTest->getOpcode() == AArch64::PTEST_PP_ANY)
+        return {true, 0};
     }
 
-    // Fallthough to simply remove the PTEST.
-  } else if ((Mask == Pred) && (PredIsPTestLike || PredIsWhileLike) &&
-             PTest->getOpcode() == AArch64::PTEST_PP_ANY) {
-    // For PTEST(PG, PG), PTEST is redundant when PG is the result of an
-    // instruction that sets the flags as PTEST would. This is only valid when
-    // the condition is any.
-
-    // Fallthough to simply remove the PTEST.
-  } else if (PredIsPTestLike) {
     // For PTEST(PG, PTEST_LIKE(PG, ...)), the PTEST is redundant since the
     // flags are set based on the same mask 'PG', but PTEST_LIKE must operate
     // on 8-bit predicates like the PTEST.  Otherwise, for instructions like
@@ -1417,55 +1421,65 @@ bool AArch64InstrInfo::optimizePTestInstr(
     // identical regardless of element size.
     auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
     uint64_t PredElementSize = getElementSizeForOpcode(PredOpcode);
-    if ((Mask != PTestLikeMask) ||
-        (PredElementSize != AArch64::ElementSizeB &&
-         PTest->getOpcode() != AArch64::PTEST_PP_ANY))
-      return false;
+    if (Mask == PTestLikeMask && (PredElementSize == AArch64::ElementSizeB ||
+                                  PTest->getOpcode() == AArch64::PTEST_PP_ANY))
+      return {true, 0};
 
-    // Fallthough to simply remove the PTEST.
-  } else {
-    // If OP in PTEST(PG, OP(PG, ...)) has a flag-setting variant change the
-    // opcode so the PTEST becomes redundant.
-    switch (PredOpcode) {
-    case AArch64::AND_PPzPP:
-    case AArch64::BIC_PPzPP:
-    case AArch64::EOR_PPzPP:
-    case AArch64::NAND_PPzPP:
-    case AArch64::NOR_PPzPP:
-    case AArch64::ORN_PPzPP:
-    case AArch64::ORR_PPzPP:
-    case AArch64::BRKA_PPzP:
-    case AArch64::BRKPA_PPzPP:
-    case AArch64::BRKB_PPzP:
-    case AArch64::BRKPB_PPzPP:
-    case AArch64::RDFFR_PPz: {
-      // Check to see if our mask is the same. If not the resulting flag bits
-      // may be different and we can't remove the ptest.
-      auto *PredMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
-      if (Mask != PredMask)
-        return false;
-      break;
-    }
-    case AArch64::BRKN_PPzP: {
-      // BRKN uses an all active implicit mask to set flags unlike the other
-      // flag-setting instructions.
-      // PTEST(PTRUE_B(31), BRKN(PG, A, B)) -> BRKNS(PG, A, B).
-      if ((MaskOpcode != AArch64::PTRUE_B) ||
-          (Mask->getOperand(1).getImm() != 31))
-        return false;
-      break;
-    }
-    case AArch64::PTRUE_B:
-      // PTEST(OP=PTRUE_B(A), OP) -> PTRUES_B(A)
-      break;
-    default:
-      // Bail out if we don't recognize the input
-      return false;
-    }
+    return {false, 0};
+  }
 
-    NewOp = convertToFlagSettingOpc(PredOpcode);
-    OpChanged = true;
+  // If OP in PTEST(PG, OP(PG, ...)) has a flag-setting variant change the
+  // opcode so the PTEST becomes redundant.
+  switch (PredOpcode) {
+  case AArch64::AND_PPzPP:
+  case AArch64::BIC_PPzPP:
+  case AArch64::EOR_PPzPP:
+  case AArch64::NAND_PPzPP:
+  case AArch64::NOR_PPzPP:
+  case AArch64::ORN_PPzPP:
+  case AArch64::ORR_PPzPP:
+  case AArch64::BRKA_PPzP:
+  case AArch64::BRKPA_PPzPP:
+  case AArch64::BRKB_PPzP:
+  case AArch64::BRKPB_PPzPP:
+  case AArch64::RDFFR_PPz: {
+    // Check to see if our mask is the same. If not the resulting flag bits
+    // may be different and we can't remove the ptest.
+    auto *PredMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
+    if (Mask != PredMask)
+      return {false, 0};
+    break;
   }
+  case AArch64::BRKN_PPzP: {
+    // BRKN uses an all active implicit mask to set flags unlike the other
+    // flag-setting instructions.
+    // PTEST(PTRUE_B(31), BRKN(PG, A, B)) -> BRKNS(PG, A, B).
+    if ((MaskOpcode != AArch64::PTRUE_B) ||
+        (Mask->getOperand(1).getImm() != 31))
+      return {false, 0};
+    break;
+  }
+  case AArch64::PTRUE_B:
+    // PTEST(OP=PTRUE_B(A), OP) -> PTRUES_B(A)
+    break;
+  default:
+    // Bail out if we don't recognize the input
+    return {false, 0};
+  }
+
+  return {true, convertToFlagSettingOpc(PredOpcode)};
+}
+
+/// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating
+/// operation which could set the flags in an identical manner
+bool AArch64InstrInfo::optimizePTestInstr(
+    MachineInstr *PTest, unsigned MaskReg, unsigned PredReg,
+    const MachineRegisterInfo *MRI) const {
+  auto *Mask = MRI->getUniqueVRegDef(MaskReg);
+  auto *Pred = MRI->getUniqueVRegDef(PredReg);
+  auto [canRemove, NewOp] = canRemovePTestInstr(PTest, Mask, Pred, MRI);
+  if (!canRemove)
+    return false;
 
   const TargetRegisterInfo *TRI = &getRegisterInfo();
 
@@ -1478,9 +1492,9 @@ bool AArch64InstrInfo::optimizePTestInstr(
   // as they are prior to PTEST. Sometimes this requires the tested PTEST
   // operand to be replaced with an equivalent instruction that also sets the
   // flags.
-  Pred->setDesc(get(NewOp));
   PTest->eraseFromParent();
-  if (OpChanged) {
+  if (NewOp) {
+    Pred->setDesc(get(NewOp));
     bool succeeded = UpdateOperandRegClass(*Pred);
     (void)succeeded;
     assert(succeeded && "Operands have incompatible register classes!");
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
index 2f10f80f4bdf70..7cc770a5b4eb49 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -432,6 +432,9 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
   bool optimizePTestInstr(MachineInstr *PTest, unsigned MaskReg,
                           unsigned PredReg,
                           const MachineRegisterInfo *MRI) const;
+  std::pair<bool, unsigned>
+  canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
+                      MachineInstr *Pred, const MachineRegisterInfo *MRI) const;
 };
 
 struct UsedNZCV {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index ee7137b92445bb..66498817f55f73 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3362,6 +3362,15 @@ unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
   return ST->getMaxInterleaveFactor();
 }
 
+ElementCount AArch64TTIImpl::getMaxPredicateLength(ElementCount VF) const {
+  // Do not create masks bigger than `<vscale x 16 x i1>`.
+  unsigned N = ST->hasSVE() ? 16 : 0;
+  // Do not create masks that are more than twice the VF.
+  N = std::min(N, 2 * VF.getKnownMinValue());
+  return VF.isScalable() ? ElementCount::getScalable(N)
+                         : ElementCount::getFixed(N);
+}
+
 // For Falkor, we want to avoid having too many strided loads in a loop since
 // that can exhaust the HW prefetcher resources.  We adjust the unroller
 // MaxCount preference below to attempt to ensure unrolling doesn't create too
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index de39dea2be43e1..6501cc4a85e8d3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -157,6 +157,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF);
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const;
+
   bool prefersVectorizedAddressing() const;
 
   InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index ece2a34f180cb4..a4e4bd8c2bb4b2 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -196,6 +196,14 @@ class VPBuilder {
   VPValue *createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
                       DebugLoc DL = {}, const Twine &Name = "");
 
+  VPValue *createGetActiveLaneMask(VPValue *IV, VPValue *TC, DebugLoc DL,
+                                   const Twine &Name = "") {
+    auto *ALM = new VPActiveLaneMaskRecipe(IV, TC, DL, Name);
+    if (BB)
+      BB->insert(ALM, InsertPt);
+    return ALM;
+  }
+
   //===--------------------------------------------------------------------===//
   // RAII helpers.
   //===--------------------------------------------------------------------===//
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 49bacb5ae6cc4e..6cf139775475c6 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -597,6 +597,10 @@ class InnerLoopVectorizer {
   /// count of the original loop for both main loop and epilogue vectorization.
   void setTripCount(Value *TC) { TripCount = TC; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const {
+    return TTI->getMaxPredicateLength(VF);
+  }
+
 protected:
   friend class LoopVectorizationPlanner;
 
@@ -7525,7 +7529,8 @@ LoopVectorizationPlanner::executePlan(
   LLVM_DEBUG(BestVPlan.dump());
 
   // Perform the actual loop transformation.
-  VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan,
+  VPTransformState State(BestVF, BestUF, TTI.getMaxPredicateLength(BestVF), LI,
+                         DT, ILV.Builder, &ILV, &BestVPlan,
                          OrigLoop->getHeader()->getContext());
 
   // 0. Generate SCEV-dependent code into the preheader, including TripCount,
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 8ebd75da346546..27ff4c884cc566 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -214,12 +214,13 @@ VPBasicBlock::iterator VPBasicBlock::getFirstNonPhi() {
   return It;
 }
 
-VPTransformState::VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI,
+VPTransformState::VPTransformState(ElementCount VF, unsigned UF,
+                                   ElementCount MaxPred, LoopInfo *LI,
                                    DominatorTree *DT, IRBui...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Apr 5, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Momchil Velikov (momchil-velikov)

Changes

This patch refactors AArch64InstrInfo::optimizePTestInstr to simplify the convoluted conditions and control flow
and make it easier to add the optimisation in #81141


Patch is 474.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87802.diff

28 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+10)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+2)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+2)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+4)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+59-2)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+92-78)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.h (+3)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+9)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+2)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h (+8)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+6-1)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.cpp (+4-3)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+47-3)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h (+29-5)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+86-21)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+8-10)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+1)
  • (modified) llvm/test/CodeGen/AArch64/active_lane_mask.ll (+1)
  • (added) llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll (+395)
  • (added) llvm/test/CodeGen/AArch64/sve-wide-lane-mask.ll (+1069)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll (+40-55)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/pr73894.ll (+2-2)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll (+842-830)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/sve-tail-folding-unroll.ll (+165-157)
  • (added) llvm/test/Transforms/LoopVectorize/AArch64/sve-wide-lane-mask.ll (+656)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/uniform-args-call-variants.ll (+48-54)
  • (modified) llvm/test/Transforms/LoopVectorize/ARM/tail-folding-prefer-flag.ll (+3-3)
  • (modified) llvm/test/Transforms/LoopVectorize/strict-fadd-interleave-only.ll (+2-2)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index fa9392b86c15b9..eae993e60008e4 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1241,6 +1241,8 @@ class TargetTransformInfo {
   /// and the number of execution units in the CPU.
   unsigned getMaxInterleaveFactor(ElementCount VF) const;
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const;
+
   /// Collect properties of V used in cost analysis, e.g. OP_PowerOf2.
   static OperandValueInfo getOperandInfo(const Value *V);
 
@@ -1999,6 +2001,9 @@ class TargetTransformInfo::Concept {
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
 
   virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
+
+  virtual ElementCount getMaxPredicateLength(ElementCount VF) const = 0;
+
   virtual InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       OperandValueInfo Opd1Info, OperandValueInfo Opd2Info,
@@ -2622,6 +2627,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   unsigned getMaxInterleaveFactor(ElementCount VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
+
+  ElementCount getMaxPredicateLength(ElementCount VF) const override {
+    return Impl.getMaxPredicateLength(VF);
+  }
+
   unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
                                             unsigned &JTSize,
                                             ProfileSummaryInfo *PSI,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 63c2ef8912b29c..1c4f2f963e89f3 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -531,6 +531,8 @@ class TargetTransformInfoImplBase {
 
   unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const { return VF; }
+
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueInfo Opd1Info, TTI::OperandValueInfo Opd2Info,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 42d8f74fd427fb..297068943a63c6 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -890,6 +890,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF) { return 1; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const { return VF; }
+
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueInfo Opd1Info = {TTI::OK_AnyValue, TTI::OP_None},
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 5f933b4587843c..39bf57737890c8 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -816,6 +816,10 @@ unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
   return TTIImpl->getMaxInterleaveFactor(VF);
 }
 
+ElementCount TargetTransformInfo::getMaxPredicateLength(ElementCount VF) const {
+  return TTIImpl->getMaxPredicateLength(VF);
+}
+
 TargetTransformInfo::OperandValueInfo
 TargetTransformInfo::getOperandInfo(const Value *V) {
   OperandValueKind OpInfo = OK_AnyValue;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8218960406ec13..d0afa0a5c65330 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1873,8 +1873,8 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
 
 bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
                                                           EVT OpVT) const {
-  // Only SVE has a 1:1 mapping from intrinsic -> instruction (whilelo).
-  if (!Subtarget->hasSVE())
+  // Only SVE/SME has a 1:1 mapping from intrinsic -> instruction (whilelo).
+  if (!Subtarget->hasSVEorSME())
     return true;
 
   // We can only support legal predicate result types. We can use the SVE
@@ -20507,6 +20507,61 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
   return SDValue();
 }
 
+static SDValue tryCombineWhileLo(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 const AArch64Subtarget *Subtarget) {
+  if (DCI.isBeforeLegalize())
+    return SDValue();
+
+  if (!Subtarget->hasSVE2p1() && !Subtarget->hasSME2())
+    return SDValue();
+
+  if (!N->hasNUsesOfValue(2, 0))
+    return SDValue();
+
+  const uint64_t HalfSize = N->getValueType(0).getVectorMinNumElements() / 2;
+  if (HalfSize < 2)
+    return SDValue();
+
+  auto It = N->use_begin();
+  SDNode *Lo = *It++;
+  SDNode *Hi = *It;
+
+  uint64_t OffLo, OffHi;
+  if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Lo->getOperand(1).getNode(), OffLo) ||
+      Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Hi->getOperand(1).getNode(), OffHi))
+    return SDValue();
+
+  if (OffLo > OffHi) {
+    std::swap(Lo, Hi);
+    std::swap(OffLo, OffHi);
+  }
+
+  if (OffLo != 0 || OffHi != HalfSize)
+    return SDValue();
+
+  SelectionDAG &DAG = DCI.DAG;
+  SDLoc DL(N);
+  SDValue ID =
+      DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+  SDValue Idx = N->getOperand(1);
+  SDValue TC = N->getOperand(2);
+  if (Idx.getValueType() != MVT::i64) {
+    Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
+    TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+  }
+  auto R =
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
+                  {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
+
+  DCI.CombineTo(Lo, R.getValue(0));
+  DCI.CombineTo(Hi, R.getValue(1));
+
+  return SDValue(N, 0);
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -20837,6 +20892,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::aarch64_sve_ptest_last:
     return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
                     AArch64CC::LAST_ACTIVE);
+  case Intrinsic::aarch64_sve_whilelo:
+    return tryCombineWhileLo(N, DCI, Subtarget);
   }
   return SDValue();
 }
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 22687b0e31c284..1b422969379d25 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1351,48 +1351,52 @@ static bool areCFlagsAccessedBetweenInstrs(
   return false;
 }
 
-/// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating
-/// operation which could set the flags in an identical manner
-bool AArch64InstrInfo::optimizePTestInstr(
-    MachineInstr *PTest, unsigned MaskReg, unsigned PredReg,
-    const MachineRegisterInfo *MRI) const {
-  auto *Mask = MRI->getUniqueVRegDef(MaskReg);
-  auto *Pred = MRI->getUniqueVRegDef(PredReg);
-  auto NewOp = Pred->getOpcode();
-  bool OpChanged = false;
-
+std::pair<bool, unsigned>
+AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
+                                      MachineInstr *Pred,
+                                      const MachineRegisterInfo *MRI) const {
   unsigned MaskOpcode = Mask->getOpcode();
   unsigned PredOpcode = Pred->getOpcode();
   bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode);
   bool PredIsWhileLike = isWhileOpcode(PredOpcode);
 
-  if (isPTrueOpcode(MaskOpcode) && (PredIsPTestLike || PredIsWhileLike) &&
-      getElementSizeForOpcode(MaskOpcode) ==
-          getElementSizeForOpcode(PredOpcode) &&
-      Mask->getOperand(1).getImm() == 31) {
+  if (PredIsWhileLike) {
+    // For PTEST(PG, PG), PTEST is redundant when PG is the result of a WHILEcc
+    // instruction and the condition is "any" since WHILcc does an implicit
+    // PTEST(ALL, PG) check and PG is always a subset of ALL.
+    if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
+      return {true, 0};
+
     // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
     // redundant since WHILE performs an implicit PTEST with an all active
-    // mask. Must be an all active predicate of matching element size.
+    // mask.
+    if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
+        getElementSizeForOpcode(MaskOpcode) ==
+            getElementSizeForOpcode(PredOpcode))
+      return {true, 0};
+
+    return {false, 0};
+  }
+
+  if (PredIsPTestLike) {
+    // For PTEST(PG, PG), PTEST is redundant when PG is the result of an
+    // instruction that sets the flags as PTEST would and the condition is
+    // "any" since PG is always a subset of the governing predicate of the
+    // ptest-like instruction.
+    if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
+      return {true, 0};
 
     // For PTEST(PTRUE_ALL, PTEST_LIKE), the PTEST is redundant if the
-    // PTEST_LIKE instruction uses the same all active mask and the element
-    // size matches. If the PTEST has a condition of any then it is always
-    // redundant.
-    if (PredIsPTestLike) {
+    // the element size matches and either the PTEST_LIKE instruction uses
+    // the same all active mask or the condition is "any".
+    if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
+        getElementSizeForOpcode(MaskOpcode) ==
+            getElementSizeForOpcode(PredOpcode)) {
       auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
-      if (Mask != PTestLikeMask && PTest->getOpcode() != AArch64::PTEST_PP_ANY)
-        return false;
+      if (Mask == PTestLikeMask || PTest->getOpcode() == AArch64::PTEST_PP_ANY)
+        return {true, 0};
     }
 
-    // Fallthough to simply remove the PTEST.
-  } else if ((Mask == Pred) && (PredIsPTestLike || PredIsWhileLike) &&
-             PTest->getOpcode() == AArch64::PTEST_PP_ANY) {
-    // For PTEST(PG, PG), PTEST is redundant when PG is the result of an
-    // instruction that sets the flags as PTEST would. This is only valid when
-    // the condition is any.
-
-    // Fallthough to simply remove the PTEST.
-  } else if (PredIsPTestLike) {
     // For PTEST(PG, PTEST_LIKE(PG, ...)), the PTEST is redundant since the
     // flags are set based on the same mask 'PG', but PTEST_LIKE must operate
     // on 8-bit predicates like the PTEST.  Otherwise, for instructions like
@@ -1417,55 +1421,65 @@ bool AArch64InstrInfo::optimizePTestInstr(
     // identical regardless of element size.
     auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
     uint64_t PredElementSize = getElementSizeForOpcode(PredOpcode);
-    if ((Mask != PTestLikeMask) ||
-        (PredElementSize != AArch64::ElementSizeB &&
-         PTest->getOpcode() != AArch64::PTEST_PP_ANY))
-      return false;
+    if (Mask == PTestLikeMask && (PredElementSize == AArch64::ElementSizeB ||
+                                  PTest->getOpcode() == AArch64::PTEST_PP_ANY))
+      return {true, 0};
 
-    // Fallthough to simply remove the PTEST.
-  } else {
-    // If OP in PTEST(PG, OP(PG, ...)) has a flag-setting variant change the
-    // opcode so the PTEST becomes redundant.
-    switch (PredOpcode) {
-    case AArch64::AND_PPzPP:
-    case AArch64::BIC_PPzPP:
-    case AArch64::EOR_PPzPP:
-    case AArch64::NAND_PPzPP:
-    case AArch64::NOR_PPzPP:
-    case AArch64::ORN_PPzPP:
-    case AArch64::ORR_PPzPP:
-    case AArch64::BRKA_PPzP:
-    case AArch64::BRKPA_PPzPP:
-    case AArch64::BRKB_PPzP:
-    case AArch64::BRKPB_PPzPP:
-    case AArch64::RDFFR_PPz: {
-      // Check to see if our mask is the same. If not the resulting flag bits
-      // may be different and we can't remove the ptest.
-      auto *PredMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
-      if (Mask != PredMask)
-        return false;
-      break;
-    }
-    case AArch64::BRKN_PPzP: {
-      // BRKN uses an all active implicit mask to set flags unlike the other
-      // flag-setting instructions.
-      // PTEST(PTRUE_B(31), BRKN(PG, A, B)) -> BRKNS(PG, A, B).
-      if ((MaskOpcode != AArch64::PTRUE_B) ||
-          (Mask->getOperand(1).getImm() != 31))
-        return false;
-      break;
-    }
-    case AArch64::PTRUE_B:
-      // PTEST(OP=PTRUE_B(A), OP) -> PTRUES_B(A)
-      break;
-    default:
-      // Bail out if we don't recognize the input
-      return false;
-    }
+    return {false, 0};
+  }
 
-    NewOp = convertToFlagSettingOpc(PredOpcode);
-    OpChanged = true;
+  // If OP in PTEST(PG, OP(PG, ...)) has a flag-setting variant change the
+  // opcode so the PTEST becomes redundant.
+  switch (PredOpcode) {
+  case AArch64::AND_PPzPP:
+  case AArch64::BIC_PPzPP:
+  case AArch64::EOR_PPzPP:
+  case AArch64::NAND_PPzPP:
+  case AArch64::NOR_PPzPP:
+  case AArch64::ORN_PPzPP:
+  case AArch64::ORR_PPzPP:
+  case AArch64::BRKA_PPzP:
+  case AArch64::BRKPA_PPzPP:
+  case AArch64::BRKB_PPzP:
+  case AArch64::BRKPB_PPzPP:
+  case AArch64::RDFFR_PPz: {
+    // Check to see if our mask is the same. If not the resulting flag bits
+    // may be different and we can't remove the ptest.
+    auto *PredMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
+    if (Mask != PredMask)
+      return {false, 0};
+    break;
   }
+  case AArch64::BRKN_PPzP: {
+    // BRKN uses an all active implicit mask to set flags unlike the other
+    // flag-setting instructions.
+    // PTEST(PTRUE_B(31), BRKN(PG, A, B)) -> BRKNS(PG, A, B).
+    if ((MaskOpcode != AArch64::PTRUE_B) ||
+        (Mask->getOperand(1).getImm() != 31))
+      return {false, 0};
+    break;
+  }
+  case AArch64::PTRUE_B:
+    // PTEST(OP=PTRUE_B(A), OP) -> PTRUES_B(A)
+    break;
+  default:
+    // Bail out if we don't recognize the input
+    return {false, 0};
+  }
+
+  return {true, convertToFlagSettingOpc(PredOpcode)};
+}
+
+/// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating
+/// operation which could set the flags in an identical manner
+bool AArch64InstrInfo::optimizePTestInstr(
+    MachineInstr *PTest, unsigned MaskReg, unsigned PredReg,
+    const MachineRegisterInfo *MRI) const {
+  auto *Mask = MRI->getUniqueVRegDef(MaskReg);
+  auto *Pred = MRI->getUniqueVRegDef(PredReg);
+  auto [canRemove, NewOp] = canRemovePTestInstr(PTest, Mask, Pred, MRI);
+  if (!canRemove)
+    return false;
 
   const TargetRegisterInfo *TRI = &getRegisterInfo();
 
@@ -1478,9 +1492,9 @@ bool AArch64InstrInfo::optimizePTestInstr(
   // as they are prior to PTEST. Sometimes this requires the tested PTEST
   // operand to be replaced with an equivalent instruction that also sets the
   // flags.
-  Pred->setDesc(get(NewOp));
   PTest->eraseFromParent();
-  if (OpChanged) {
+  if (NewOp) {
+    Pred->setDesc(get(NewOp));
     bool succeeded = UpdateOperandRegClass(*Pred);
     (void)succeeded;
     assert(succeeded && "Operands have incompatible register classes!");
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
index 2f10f80f4bdf70..7cc770a5b4eb49 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -432,6 +432,9 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
   bool optimizePTestInstr(MachineInstr *PTest, unsigned MaskReg,
                           unsigned PredReg,
                           const MachineRegisterInfo *MRI) const;
+  std::pair<bool, unsigned>
+  canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
+                      MachineInstr *Pred, const MachineRegisterInfo *MRI) const;
 };
 
 struct UsedNZCV {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index ee7137b92445bb..66498817f55f73 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3362,6 +3362,15 @@ unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
   return ST->getMaxInterleaveFactor();
 }
 
+ElementCount AArch64TTIImpl::getMaxPredicateLength(ElementCount VF) const {
+  // Do not create masks bigger than `<vscale x 16 x i1>`.
+  unsigned N = ST->hasSVE() ? 16 : 0;
+  // Do not create masks that are more than twice the VF.
+  N = std::min(N, 2 * VF.getKnownMinValue());
+  return VF.isScalable() ? ElementCount::getScalable(N)
+                         : ElementCount::getFixed(N);
+}
+
 // For Falkor, we want to avoid having too many strided loads in a loop since
 // that can exhaust the HW prefetcher resources.  We adjust the unroller
 // MaxCount preference below to attempt to ensure unrolling doesn't create too
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index de39dea2be43e1..6501cc4a85e8d3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -157,6 +157,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF);
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const;
+
   bool prefersVectorizedAddressing() const;
 
   InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index ece2a34f180cb4..a4e4bd8c2bb4b2 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -196,6 +196,14 @@ class VPBuilder {
   VPValue *createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
                       DebugLoc DL = {}, const Twine &Name = "");
 
+  VPValue *createGetActiveLaneMask(VPValue *IV, VPValue *TC, DebugLoc DL,
+                                   const Twine &Name = "") {
+    auto *ALM = new VPActiveLaneMaskRecipe(IV, TC, DL, Name);
+    if (BB)
+      BB->insert(ALM, InsertPt);
+    return ALM;
+  }
+
   //===--------------------------------------------------------------------===//
   // RAII helpers.
   //===--------------------------------------------------------------------===//
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 49bacb5ae6cc4e..6cf139775475c6 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -597,6 +597,10 @@ class InnerLoopVectorizer {
   /// count of the original loop for both main loop and epilogue vectorization.
   void setTripCount(Value *TC) { TripCount = TC; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const {
+    return TTI->getMaxPredicateLength(VF);
+  }
+
 protected:
   friend class LoopVectorizationPlanner;
 
@@ -7525,7 +7529,8 @@ LoopVectorizationPlanner::executePlan(
   LLVM_DEBUG(BestVPlan.dump());
 
   // Perform the actual loop transformation.
-  VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan,
+  VPTransformState State(BestVF, BestUF, TTI.getMaxPredicateLength(BestVF), LI,
+                         DT, ILV.Builder, &ILV, &BestVPlan,
                          OrigLoop->getHeader()->getContext());
 
   // 0. Generate SCEV-dependent code into the preheader, including TripCount,
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 8ebd75da346546..27ff4c884cc566 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -214,12 +214,13 @@ VPBasicBlock::iterator VPBasicBlock::getFirstNonPhi() {
   return It;
 }
 
-VPTransformState::VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI,
+VPTransformState::VPTransformState(ElementCount VF, unsigned UF,
+                                   ElementCount MaxPred, LoopInfo *LI,
                                    DominatorTree *DT, IRBui...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Apr 5, 2024

@llvm/pr-subscribers-llvm-analysis

Author: Momchil Velikov (momchil-velikov)

Changes

This patch refactors AArch64InstrInfo::optimizePTestInstr to simplify the convoluted conditions and control flow
and make it easier to add the optimisation in #81141


Patch is 474.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87802.diff

28 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+10)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+2)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+2)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+4)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+59-2)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+92-78)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.h (+3)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+9)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+2)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h (+8)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+6-1)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.cpp (+4-3)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+47-3)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h (+29-5)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+86-21)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+8-10)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanValue.h (+1)
  • (modified) llvm/test/CodeGen/AArch64/active_lane_mask.ll (+1)
  • (added) llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll (+395)
  • (added) llvm/test/CodeGen/AArch64/sve-wide-lane-mask.ll (+1069)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll (+40-55)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/pr73894.ll (+2-2)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/scalable-strict-fadd.ll (+842-830)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/sve-tail-folding-unroll.ll (+165-157)
  • (added) llvm/test/Transforms/LoopVectorize/AArch64/sve-wide-lane-mask.ll (+656)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/uniform-args-call-variants.ll (+48-54)
  • (modified) llvm/test/Transforms/LoopVectorize/ARM/tail-folding-prefer-flag.ll (+3-3)
  • (modified) llvm/test/Transforms/LoopVectorize/strict-fadd-interleave-only.ll (+2-2)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index fa9392b86c15b9..eae993e60008e4 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1241,6 +1241,8 @@ class TargetTransformInfo {
   /// and the number of execution units in the CPU.
   unsigned getMaxInterleaveFactor(ElementCount VF) const;
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const;
+
   /// Collect properties of V used in cost analysis, e.g. OP_PowerOf2.
   static OperandValueInfo getOperandInfo(const Value *V);
 
@@ -1999,6 +2001,9 @@ class TargetTransformInfo::Concept {
   virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;
 
   virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
+
+  virtual ElementCount getMaxPredicateLength(ElementCount VF) const = 0;
+
   virtual InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       OperandValueInfo Opd1Info, OperandValueInfo Opd2Info,
@@ -2622,6 +2627,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   unsigned getMaxInterleaveFactor(ElementCount VF) override {
     return Impl.getMaxInterleaveFactor(VF);
   }
+
+  ElementCount getMaxPredicateLength(ElementCount VF) const override {
+    return Impl.getMaxPredicateLength(VF);
+  }
+
   unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
                                             unsigned &JTSize,
                                             ProfileSummaryInfo *PSI,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 63c2ef8912b29c..1c4f2f963e89f3 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -531,6 +531,8 @@ class TargetTransformInfoImplBase {
 
   unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const { return VF; }
+
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueInfo Opd1Info, TTI::OperandValueInfo Opd2Info,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 42d8f74fd427fb..297068943a63c6 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -890,6 +890,8 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF) { return 1; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const { return VF; }
+
   InstructionCost getArithmeticInstrCost(
       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
       TTI::OperandValueInfo Opd1Info = {TTI::OK_AnyValue, TTI::OP_None},
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 5f933b4587843c..39bf57737890c8 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -816,6 +816,10 @@ unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
   return TTIImpl->getMaxInterleaveFactor(VF);
 }
 
+ElementCount TargetTransformInfo::getMaxPredicateLength(ElementCount VF) const {
+  return TTIImpl->getMaxPredicateLength(VF);
+}
+
 TargetTransformInfo::OperandValueInfo
 TargetTransformInfo::getOperandInfo(const Value *V) {
   OperandValueKind OpInfo = OK_AnyValue;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8218960406ec13..d0afa0a5c65330 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1873,8 +1873,8 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
 
 bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
                                                           EVT OpVT) const {
-  // Only SVE has a 1:1 mapping from intrinsic -> instruction (whilelo).
-  if (!Subtarget->hasSVE())
+  // Only SVE/SME has a 1:1 mapping from intrinsic -> instruction (whilelo).
+  if (!Subtarget->hasSVEorSME())
     return true;
 
   // We can only support legal predicate result types. We can use the SVE
@@ -20507,6 +20507,61 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
   return SDValue();
 }
 
+static SDValue tryCombineWhileLo(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 const AArch64Subtarget *Subtarget) {
+  if (DCI.isBeforeLegalize())
+    return SDValue();
+
+  if (!Subtarget->hasSVE2p1() && !Subtarget->hasSME2())
+    return SDValue();
+
+  if (!N->hasNUsesOfValue(2, 0))
+    return SDValue();
+
+  const uint64_t HalfSize = N->getValueType(0).getVectorMinNumElements() / 2;
+  if (HalfSize < 2)
+    return SDValue();
+
+  auto It = N->use_begin();
+  SDNode *Lo = *It++;
+  SDNode *Hi = *It;
+
+  uint64_t OffLo, OffHi;
+  if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Lo->getOperand(1).getNode(), OffLo) ||
+      Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      !isIntImmediate(Hi->getOperand(1).getNode(), OffHi))
+    return SDValue();
+
+  if (OffLo > OffHi) {
+    std::swap(Lo, Hi);
+    std::swap(OffLo, OffHi);
+  }
+
+  if (OffLo != 0 || OffHi != HalfSize)
+    return SDValue();
+
+  SelectionDAG &DAG = DCI.DAG;
+  SDLoc DL(N);
+  SDValue ID =
+      DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+  SDValue Idx = N->getOperand(1);
+  SDValue TC = N->getOperand(2);
+  if (Idx.getValueType() != MVT::i64) {
+    Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
+    TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
+  }
+  auto R =
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
+                  {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
+
+  DCI.CombineTo(Lo, R.getValue(0));
+  DCI.CombineTo(Hi, R.getValue(1));
+
+  return SDValue(N, 0);
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -20837,6 +20892,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::aarch64_sve_ptest_last:
     return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
                     AArch64CC::LAST_ACTIVE);
+  case Intrinsic::aarch64_sve_whilelo:
+    return tryCombineWhileLo(N, DCI, Subtarget);
   }
   return SDValue();
 }
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 22687b0e31c284..1b422969379d25 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1351,48 +1351,52 @@ static bool areCFlagsAccessedBetweenInstrs(
   return false;
 }
 
-/// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating
-/// operation which could set the flags in an identical manner
-bool AArch64InstrInfo::optimizePTestInstr(
-    MachineInstr *PTest, unsigned MaskReg, unsigned PredReg,
-    const MachineRegisterInfo *MRI) const {
-  auto *Mask = MRI->getUniqueVRegDef(MaskReg);
-  auto *Pred = MRI->getUniqueVRegDef(PredReg);
-  auto NewOp = Pred->getOpcode();
-  bool OpChanged = false;
-
+std::pair<bool, unsigned>
+AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
+                                      MachineInstr *Pred,
+                                      const MachineRegisterInfo *MRI) const {
   unsigned MaskOpcode = Mask->getOpcode();
   unsigned PredOpcode = Pred->getOpcode();
   bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode);
   bool PredIsWhileLike = isWhileOpcode(PredOpcode);
 
-  if (isPTrueOpcode(MaskOpcode) && (PredIsPTestLike || PredIsWhileLike) &&
-      getElementSizeForOpcode(MaskOpcode) ==
-          getElementSizeForOpcode(PredOpcode) &&
-      Mask->getOperand(1).getImm() == 31) {
+  if (PredIsWhileLike) {
+    // For PTEST(PG, PG), PTEST is redundant when PG is the result of a WHILEcc
+    // instruction and the condition is "any" since WHILcc does an implicit
+    // PTEST(ALL, PG) check and PG is always a subset of ALL.
+    if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
+      return {true, 0};
+
     // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
     // redundant since WHILE performs an implicit PTEST with an all active
-    // mask. Must be an all active predicate of matching element size.
+    // mask.
+    if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
+        getElementSizeForOpcode(MaskOpcode) ==
+            getElementSizeForOpcode(PredOpcode))
+      return {true, 0};
+
+    return {false, 0};
+  }
+
+  if (PredIsPTestLike) {
+    // For PTEST(PG, PG), PTEST is redundant when PG is the result of an
+    // instruction that sets the flags as PTEST would and the condition is
+    // "any" since PG is always a subset of the governing predicate of the
+    // ptest-like instruction.
+    if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
+      return {true, 0};
 
     // For PTEST(PTRUE_ALL, PTEST_LIKE), the PTEST is redundant if the
-    // PTEST_LIKE instruction uses the same all active mask and the element
-    // size matches. If the PTEST has a condition of any then it is always
-    // redundant.
-    if (PredIsPTestLike) {
+    // the element size matches and either the PTEST_LIKE instruction uses
+    // the same all active mask or the condition is "any".
+    if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
+        getElementSizeForOpcode(MaskOpcode) ==
+            getElementSizeForOpcode(PredOpcode)) {
       auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
-      if (Mask != PTestLikeMask && PTest->getOpcode() != AArch64::PTEST_PP_ANY)
-        return false;
+      if (Mask == PTestLikeMask || PTest->getOpcode() == AArch64::PTEST_PP_ANY)
+        return {true, 0};
     }
 
-    // Fallthough to simply remove the PTEST.
-  } else if ((Mask == Pred) && (PredIsPTestLike || PredIsWhileLike) &&
-             PTest->getOpcode() == AArch64::PTEST_PP_ANY) {
-    // For PTEST(PG, PG), PTEST is redundant when PG is the result of an
-    // instruction that sets the flags as PTEST would. This is only valid when
-    // the condition is any.
-
-    // Fallthough to simply remove the PTEST.
-  } else if (PredIsPTestLike) {
     // For PTEST(PG, PTEST_LIKE(PG, ...)), the PTEST is redundant since the
     // flags are set based on the same mask 'PG', but PTEST_LIKE must operate
     // on 8-bit predicates like the PTEST.  Otherwise, for instructions like
@@ -1417,55 +1421,65 @@ bool AArch64InstrInfo::optimizePTestInstr(
     // identical regardless of element size.
     auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
     uint64_t PredElementSize = getElementSizeForOpcode(PredOpcode);
-    if ((Mask != PTestLikeMask) ||
-        (PredElementSize != AArch64::ElementSizeB &&
-         PTest->getOpcode() != AArch64::PTEST_PP_ANY))
-      return false;
+    if (Mask == PTestLikeMask && (PredElementSize == AArch64::ElementSizeB ||
+                                  PTest->getOpcode() == AArch64::PTEST_PP_ANY))
+      return {true, 0};
 
-    // Fallthough to simply remove the PTEST.
-  } else {
-    // If OP in PTEST(PG, OP(PG, ...)) has a flag-setting variant change the
-    // opcode so the PTEST becomes redundant.
-    switch (PredOpcode) {
-    case AArch64::AND_PPzPP:
-    case AArch64::BIC_PPzPP:
-    case AArch64::EOR_PPzPP:
-    case AArch64::NAND_PPzPP:
-    case AArch64::NOR_PPzPP:
-    case AArch64::ORN_PPzPP:
-    case AArch64::ORR_PPzPP:
-    case AArch64::BRKA_PPzP:
-    case AArch64::BRKPA_PPzPP:
-    case AArch64::BRKB_PPzP:
-    case AArch64::BRKPB_PPzPP:
-    case AArch64::RDFFR_PPz: {
-      // Check to see if our mask is the same. If not the resulting flag bits
-      // may be different and we can't remove the ptest.
-      auto *PredMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
-      if (Mask != PredMask)
-        return false;
-      break;
-    }
-    case AArch64::BRKN_PPzP: {
-      // BRKN uses an all active implicit mask to set flags unlike the other
-      // flag-setting instructions.
-      // PTEST(PTRUE_B(31), BRKN(PG, A, B)) -> BRKNS(PG, A, B).
-      if ((MaskOpcode != AArch64::PTRUE_B) ||
-          (Mask->getOperand(1).getImm() != 31))
-        return false;
-      break;
-    }
-    case AArch64::PTRUE_B:
-      // PTEST(OP=PTRUE_B(A), OP) -> PTRUES_B(A)
-      break;
-    default:
-      // Bail out if we don't recognize the input
-      return false;
-    }
+    return {false, 0};
+  }
 
-    NewOp = convertToFlagSettingOpc(PredOpcode);
-    OpChanged = true;
+  // If OP in PTEST(PG, OP(PG, ...)) has a flag-setting variant change the
+  // opcode so the PTEST becomes redundant.
+  switch (PredOpcode) {
+  case AArch64::AND_PPzPP:
+  case AArch64::BIC_PPzPP:
+  case AArch64::EOR_PPzPP:
+  case AArch64::NAND_PPzPP:
+  case AArch64::NOR_PPzPP:
+  case AArch64::ORN_PPzPP:
+  case AArch64::ORR_PPzPP:
+  case AArch64::BRKA_PPzP:
+  case AArch64::BRKPA_PPzPP:
+  case AArch64::BRKB_PPzP:
+  case AArch64::BRKPB_PPzPP:
+  case AArch64::RDFFR_PPz: {
+    // Check to see if our mask is the same. If not the resulting flag bits
+    // may be different and we can't remove the ptest.
+    auto *PredMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
+    if (Mask != PredMask)
+      return {false, 0};
+    break;
   }
+  case AArch64::BRKN_PPzP: {
+    // BRKN uses an all active implicit mask to set flags unlike the other
+    // flag-setting instructions.
+    // PTEST(PTRUE_B(31), BRKN(PG, A, B)) -> BRKNS(PG, A, B).
+    if ((MaskOpcode != AArch64::PTRUE_B) ||
+        (Mask->getOperand(1).getImm() != 31))
+      return {false, 0};
+    break;
+  }
+  case AArch64::PTRUE_B:
+    // PTEST(OP=PTRUE_B(A), OP) -> PTRUES_B(A)
+    break;
+  default:
+    // Bail out if we don't recognize the input
+    return {false, 0};
+  }
+
+  return {true, convertToFlagSettingOpc(PredOpcode)};
+}
+
+/// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating
+/// operation which could set the flags in an identical manner
+bool AArch64InstrInfo::optimizePTestInstr(
+    MachineInstr *PTest, unsigned MaskReg, unsigned PredReg,
+    const MachineRegisterInfo *MRI) const {
+  auto *Mask = MRI->getUniqueVRegDef(MaskReg);
+  auto *Pred = MRI->getUniqueVRegDef(PredReg);
+  auto [canRemove, NewOp] = canRemovePTestInstr(PTest, Mask, Pred, MRI);
+  if (!canRemove)
+    return false;
 
   const TargetRegisterInfo *TRI = &getRegisterInfo();
 
@@ -1478,9 +1492,9 @@ bool AArch64InstrInfo::optimizePTestInstr(
   // as they are prior to PTEST. Sometimes this requires the tested PTEST
   // operand to be replaced with an equivalent instruction that also sets the
   // flags.
-  Pred->setDesc(get(NewOp));
   PTest->eraseFromParent();
-  if (OpChanged) {
+  if (NewOp) {
+    Pred->setDesc(get(NewOp));
     bool succeeded = UpdateOperandRegClass(*Pred);
     (void)succeeded;
     assert(succeeded && "Operands have incompatible register classes!");
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
index 2f10f80f4bdf70..7cc770a5b4eb49 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -432,6 +432,9 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
   bool optimizePTestInstr(MachineInstr *PTest, unsigned MaskReg,
                           unsigned PredReg,
                           const MachineRegisterInfo *MRI) const;
+  std::pair<bool, unsigned>
+  canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
+                      MachineInstr *Pred, const MachineRegisterInfo *MRI) const;
 };
 
 struct UsedNZCV {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index ee7137b92445bb..66498817f55f73 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3362,6 +3362,15 @@ unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
   return ST->getMaxInterleaveFactor();
 }
 
+ElementCount AArch64TTIImpl::getMaxPredicateLength(ElementCount VF) const {
+  // Do not create masks bigger than `<vscale x 16 x i1>`.
+  unsigned N = ST->hasSVE() ? 16 : 0;
+  // Do not create masks that are more than twice the VF.
+  N = std::min(N, 2 * VF.getKnownMinValue());
+  return VF.isScalable() ? ElementCount::getScalable(N)
+                         : ElementCount::getFixed(N);
+}
+
 // For Falkor, we want to avoid having too many strided loads in a loop since
 // that can exhaust the HW prefetcher resources.  We adjust the unroller
 // MaxCount preference below to attempt to ensure unrolling doesn't create too
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index de39dea2be43e1..6501cc4a85e8d3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -157,6 +157,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
   unsigned getMaxInterleaveFactor(ElementCount VF);
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const;
+
   bool prefersVectorizedAddressing() const;
 
   InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index ece2a34f180cb4..a4e4bd8c2bb4b2 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -196,6 +196,14 @@ class VPBuilder {
   VPValue *createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
                       DebugLoc DL = {}, const Twine &Name = "");
 
+  VPValue *createGetActiveLaneMask(VPValue *IV, VPValue *TC, DebugLoc DL,
+                                   const Twine &Name = "") {
+    auto *ALM = new VPActiveLaneMaskRecipe(IV, TC, DL, Name);
+    if (BB)
+      BB->insert(ALM, InsertPt);
+    return ALM;
+  }
+
   //===--------------------------------------------------------------------===//
   // RAII helpers.
   //===--------------------------------------------------------------------===//
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 49bacb5ae6cc4e..6cf139775475c6 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -597,6 +597,10 @@ class InnerLoopVectorizer {
   /// count of the original loop for both main loop and epilogue vectorization.
   void setTripCount(Value *TC) { TripCount = TC; }
 
+  ElementCount getMaxPredicateLength(ElementCount VF) const {
+    return TTI->getMaxPredicateLength(VF);
+  }
+
 protected:
   friend class LoopVectorizationPlanner;
 
@@ -7525,7 +7529,8 @@ LoopVectorizationPlanner::executePlan(
   LLVM_DEBUG(BestVPlan.dump());
 
   // Perform the actual loop transformation.
-  VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan,
+  VPTransformState State(BestVF, BestUF, TTI.getMaxPredicateLength(BestVF), LI,
+                         DT, ILV.Builder, &ILV, &BestVPlan,
                          OrigLoop->getHeader()->getContext());
 
   // 0. Generate SCEV-dependent code into the preheader, including TripCount,
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index 8ebd75da346546..27ff4c884cc566 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -214,12 +214,13 @@ VPBasicBlock::iterator VPBasicBlock::getFirstNonPhi() {
   return It;
 }
 
-VPTransformState::VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI,
+VPTransformState::VPTransformState(ElementCount VF, unsigned UF,
+                                   ElementCount MaxPred, LoopInfo *LI,
                                    DominatorTree *DT, IRBui...
[truncated]

Copy link

github-actions bot commented Apr 5, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@momchil-velikov momchil-velikov force-pushed the refactor-redundant-ptest branch from 17fe202 to 617ae10 Compare April 18, 2024 17:00
@momchil-velikov momchil-velikov force-pushed the refactor-redundant-ptest branch from 617ae10 to bd6adbb Compare June 3, 2024 14:57
Copy link
Collaborator

@paulwalker-arm paulwalker-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you can remove the non-relevant loop vectorisation commit then I'll review the PR properly.

Broadly it looks good but perhaps it's just me but seeing return {true, 0}; and return {true, 0}; looks weird despite understanding their meaning.

What do you think to canRemovePTestInstr returning std::optional<unsigned> with the failure cases returning nullopt and positive cases returning a real opcode, be that PredOpcode or convertToFlagSettingOpc(PredOpcode) depending on whether an opcode replacement is necessary?

@momchil-velikov momchil-velikov force-pushed the refactor-redundant-ptest branch from bd6adbb to 6eded29 Compare June 12, 2024 10:27
@momchil-velikov
Copy link
Collaborator Author

What do you think to canRemovePTestInstr returning std::optional<unsigned> ...

It'd be nicer, in principle, however this patch is extracting the NFC parts from #81141
where the pair becomes a tuple and gets one more member.

@momchil-velikov
Copy link
Collaborator Author

#87802 and #81141 rebased on top of main, no longer descendants to the wide active lane mask patch (#81140)

@momchil-velikov
Copy link
Collaborator Author

What do you think to canRemovePTestInstr returning std::optional<unsigned> ...

It'd be nicer, in principle, however this patch is extracting the NFC parts from #81141 where the pair becomes a tuple and gets one more member.

Actually ...
I'll have a second look into using std::optional as it results in better generated code ...

@paulwalker-arm
Copy link
Collaborator

A quick look at #81141 suggests at worst you'd just move from the std::optional<unsigned> this patch would introduce to std::optional<std::pair<unsigned,MachineInstr*>>.

Change-Id: I63ff6f4a7f90cd584508cbaa8bba8a39a8ca3f56
@momchil-velikov momchil-velikov force-pushed the refactor-redundant-ptest branch from 6eded29 to 2f5e80c Compare June 12, 2024 14:20
@momchil-velikov
Copy link
Collaborator Author

A quick look at #81141 suggests at worst you'd just move from the std::optional<unsigned> this patch would introduce to std::optional<std::pair<unsigned,MachineInstr*>>.

Done.

Copy link
Collaborator

@paulwalker-arm paulwalker-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer std::nullopt over {} but I see both styles used so will not worry about it.

@momchil-velikov momchil-velikov merged commit 6ec02f7 into llvm:main Jun 18, 2024
7 checks passed
@momchil-velikov momchil-velikov deleted the refactor-redundant-ptest branch November 13, 2024 09:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants