Skip to content

[AArch64] Improve index selection for histograms #111150

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Oct 22, 2024

Conversation

JamesChesterman
Copy link
Contributor

Search for an extend feeding the index of a histogram operation and, when one is found, truncate the index back to its pre-extension type. This avoids the need to split the instruction in two.

Search for an extend feeding the index of a histogram operation
and, when one is found, truncate the index back to its pre-extension
type. This avoids the need to split the instruction in two.
@llvmbot
Copy link
Member

llvmbot commented Oct 4, 2024

@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-aarch64

Author: James Chesterman (JamesChesterman)

Changes

Search for an extend feeding the index of a histogram operation and, when one is found, truncate the index back to its pre-extension type. This avoids the need to split the instruction in two.


Full diff: https://github.com/llvm/llvm-project/pull/111150.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+34-3)
  • (modified) llvm/test/CodeGen/AArch64/sve2-histcnt.ll (+73)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 48e1b96d841efb..545d5b59c64562 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1114,7 +1114,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                        ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT,
                        ISD::VECREDUCE_ADD, ISD::STEP_VECTOR});
 
-  setTargetDAGCombine({ISD::MGATHER, ISD::MSCATTER});
+  setTargetDAGCombine({ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
 
   setTargetDAGCombine(ISD::FP_EXTEND);
 
@@ -24079,12 +24079,42 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
 
 static SDValue performMaskedGatherScatterCombine(
     SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
-  MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
-  assert(MGS && "Can only combine gather load or scatter store nodes");
+  MaskedHistogramSDNode *HG;
+  MaskedGatherScatterSDNode *MGS;
+  if (N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM) {
+    HG = cast<MaskedHistogramSDNode>(N);
+  } else {
+    MGS = cast<MaskedGatherScatterSDNode>(N);
+  }
+  assert((HG || MGS) &&
+         "Can only combine gather load, scatter store or histogram nodes");
 
   if (!DCI.isBeforeLegalize())
     return SDValue();
 
+  if (N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM) {
+    SDLoc DL(HG);
+    SDValue Index = HG->getIndex();
+    if (ISD::isExtOpcode(Index->getOpcode())) {
+      SDValue Chain = HG->getChain();
+      SDValue Inc = HG->getInc();
+      SDValue Mask = HG->getMask();
+      SDValue BasePtr = HG->getBasePtr();
+      SDValue Scale = HG->getScale();
+      SDValue IntID = HG->getIntID();
+      EVT MemVT = HG->getMemoryVT();
+      MachineMemOperand *MMO = HG->getMemOperand();
+      ISD::MemIndexType IndexType = HG->getIndexType();
+      SDValue ExtOp = Index.getOperand(0);
+      auto SrcType = ExtOp.getValueType();
+      auto TruncatedIndex = DAG.getAnyExtOrTrunc(Index, DL, SrcType);
+      SDValue Ops[] = {Chain, Inc, Mask, BasePtr, TruncatedIndex, Scale, IntID};
+      return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
+                                    MMO, IndexType);
+    }
+    return SDValue();
+  }
+
   SDLoc DL(MGS);
   SDValue Chain = MGS->getChain();
   SDValue Scale = MGS->getScale();
@@ -26277,6 +26307,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performMSTORECombine(N, DCI, DAG, Subtarget);
   case ISD::MGATHER:
   case ISD::MSCATTER:
+  case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return performMaskedGatherScatterCombine(N, DCI, DAG);
   case ISD::FP_EXTEND:
     return performFPExtendCombine(N, DAG, DCI, Subtarget);
diff --git a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
index dd0b9639a8fc2f..42fff1ec7c532f 100644
--- a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
@@ -267,5 +267,78 @@ define void @histogram_i16_8_lane(ptr %base, <vscale x 8 x i32> %indices, i16 %i
   ret void
 }
 
+define void @histogram_i32_extend(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_extend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+define void @histogram_i32_8_lane_extend(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_8_lane_extend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    mov z4.s, w1
+; CHECK-NEXT:    ptrue p2.s
+; CHECK-NEXT:    histcnt z2.s, p1/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mad z2.s, p2/m, z4.s, z3.s
+; CHECK-NEXT:    st1w { z2.s }, p1, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    histcnt z0.s, p0/z, z1.s, z1.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    mad z0.s, p2/m, z4.s, z2.s
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 8 x i32> %indices to <vscale x 8 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 8 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 %inc, <vscale x 8 x i1> %mask)
+  ret void
+}
+define void @histogram_i32_sextend(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0{
+; CHECK-LABEL: histogram_i32_sextend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    ret
+  %extended = sext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+define void @histogram_i32_8_lane_sextend(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_8_lane_sextend:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    mov z4.s, w1
+; CHECK-NEXT:    ptrue p2.s
+; CHECK-NEXT:    histcnt z2.s, p1/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mad z2.s, p2/m, z4.s, z3.s
+; CHECK-NEXT:    st1w { z2.s }, p1, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    histcnt z0.s, p0/z, z1.s, z1.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    mad z0.s, p2/m, z4.s, z2.s
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    ret
+  %extended = sext <vscale x 8 x i32> %indices to <vscale x 8 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 8 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 %inc, <vscale x 8 x i1> %mask)
+  ret void
+}
+
 
 attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }

Copy link

github-actions bot commented Oct 4, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a few general comments 🙂

@JamesChesterman JamesChesterman marked this pull request as draft October 9, 2024 15:45
Scatter. Now also removes the instruction when its mask is zero.
@JamesChesterman JamesChesterman marked this pull request as ready for review October 16, 2024 10:49
@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label Oct 16, 2024
Copy link
Collaborator

@SamTebbs33 SamTebbs33 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Thank you.

Copy link
Collaborator

@huntergr-arm huntergr-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@JamesChesterman JamesChesterman merged commit 11c8188 into llvm:main Oct 22, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AArch64 llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants