-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[AArch64] Improve index selection for histograms #111150
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Search for extends to the index used in a histogram operation then perform a truncate on it. This avoids the need to split the instruction in two.
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-backend-aarch64 Author: James Chesterman (JamesChesterman) ChangesSearch for extends to the index used in a histogram operation then perform a truncate on it. This avoids the need to split the instruction in two. Full diff: https://github.com/llvm/llvm-project/pull/111150.diff 2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 48e1b96d841efb..545d5b59c64562 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1114,7 +1114,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT,
ISD::VECREDUCE_ADD, ISD::STEP_VECTOR});
- setTargetDAGCombine({ISD::MGATHER, ISD::MSCATTER});
+ setTargetDAGCombine({ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
setTargetDAGCombine(ISD::FP_EXTEND);
@@ -24079,12 +24079,42 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
static SDValue performMaskedGatherScatterCombine(
SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
- MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
- assert(MGS && "Can only combine gather load or scatter store nodes");
+ MaskedHistogramSDNode *HG;
+ MaskedGatherScatterSDNode *MGS;
+ if (N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM) {
+ HG = cast<MaskedHistogramSDNode>(N);
+ } else {
+ MGS = cast<MaskedGatherScatterSDNode>(N);
+ }
+ assert((HG || MGS) &&
+ "Can only combine gather load, scatter store or histogram nodes");
if (!DCI.isBeforeLegalize())
return SDValue();
+ if (N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM) {
+ SDLoc DL(HG);
+ SDValue Index = HG->getIndex();
+ if (ISD::isExtOpcode(Index->getOpcode())) {
+ SDValue Chain = HG->getChain();
+ SDValue Inc = HG->getInc();
+ SDValue Mask = HG->getMask();
+ SDValue BasePtr = HG->getBasePtr();
+ SDValue Scale = HG->getScale();
+ SDValue IntID = HG->getIntID();
+ EVT MemVT = HG->getMemoryVT();
+ MachineMemOperand *MMO = HG->getMemOperand();
+ ISD::MemIndexType IndexType = HG->getIndexType();
+ SDValue ExtOp = Index.getOperand(0);
+ auto SrcType = ExtOp.getValueType();
+ auto TruncatedIndex = DAG.getAnyExtOrTrunc(Index, DL, SrcType);
+ SDValue Ops[] = {Chain, Inc, Mask, BasePtr, TruncatedIndex, Scale, IntID};
+ return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
+ MMO, IndexType);
+ }
+ return SDValue();
+ }
+
SDLoc DL(MGS);
SDValue Chain = MGS->getChain();
SDValue Scale = MGS->getScale();
@@ -26277,6 +26307,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performMSTORECombine(N, DCI, DAG, Subtarget);
case ISD::MGATHER:
case ISD::MSCATTER:
+ case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
return performMaskedGatherScatterCombine(N, DCI, DAG);
case ISD::FP_EXTEND:
return performFPExtendCombine(N, DAG, DCI, Subtarget);
diff --git a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
index dd0b9639a8fc2f..42fff1ec7c532f 100644
--- a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
@@ -267,5 +267,78 @@ define void @histogram_i16_8_lane(ptr %base, <vscale x 8 x i32> %indices, i16 %i
ret void
}
+define void @histogram_i32_extend(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_extend:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: mov z3.s, #1 // =0x1
+; CHECK-NEXT: ld1w { z2.s }, p0/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT: st1w { z1.s }, p0, [x0, z0.s, sxtw #2]
+; CHECK-NEXT: ret
+ %extended = zext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+ %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+ ret void
+}
+define void @histogram_i32_8_lane_extend(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_8_lane_extend:
+; CHECK: // %bb.0:
+; CHECK-NEXT: punpklo p1.h, p0.b
+; CHECK-NEXT: mov z4.s, w1
+; CHECK-NEXT: ptrue p2.s
+; CHECK-NEXT: histcnt z2.s, p1/z, z0.s, z0.s
+; CHECK-NEXT: ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: mad z2.s, p2/m, z4.s, z3.s
+; CHECK-NEXT: st1w { z2.s }, p1, [x0, z0.s, sxtw #2]
+; CHECK-NEXT: histcnt z0.s, p0/z, z1.s, z1.s
+; CHECK-NEXT: ld1w { z2.s }, p0/z, [x0, z1.s, sxtw #2]
+; CHECK-NEXT: mad z0.s, p2/m, z4.s, z2.s
+; CHECK-NEXT: st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
+; CHECK-NEXT: ret
+ %extended = zext <vscale x 8 x i32> %indices to <vscale x 8 x i64>
+ %buckets = getelementptr i32, ptr %base, <vscale x 8 x i64> %extended
+ call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 %inc, <vscale x 8 x i1> %mask)
+ ret void
+}
+define void @histogram_i32_sextend(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0{
+; CHECK-LABEL: histogram_i32_sextend:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: mov z3.s, #1 // =0x1
+; CHECK-NEXT: ld1w { z2.s }, p0/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT: st1w { z1.s }, p0, [x0, z0.s, sxtw #2]
+; CHECK-NEXT: ret
+ %extended = sext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+ %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+ ret void
+}
+define void @histogram_i32_8_lane_sextend(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_8_lane_sextend:
+; CHECK: // %bb.0:
+; CHECK-NEXT: punpklo p1.h, p0.b
+; CHECK-NEXT: mov z4.s, w1
+; CHECK-NEXT: ptrue p2.s
+; CHECK-NEXT: histcnt z2.s, p1/z, z0.s, z0.s
+; CHECK-NEXT: ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: mad z2.s, p2/m, z4.s, z3.s
+; CHECK-NEXT: st1w { z2.s }, p1, [x0, z0.s, sxtw #2]
+; CHECK-NEXT: histcnt z0.s, p0/z, z1.s, z1.s
+; CHECK-NEXT: ld1w { z2.s }, p0/z, [x0, z1.s, sxtw #2]
+; CHECK-NEXT: mad z0.s, p2/m, z4.s, z2.s
+; CHECK-NEXT: st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
+; CHECK-NEXT: ret
+ %extended = sext <vscale x 8 x i32> %indices to <vscale x 8 x i64>
+ %buckets = getelementptr i32, ptr %base, <vscale x 8 x i64> %extended
+ call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 %inc, <vscale x 8 x i1> %mask)
+ ret void
+}
+
attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a few general comments 🙂
Scatter. Now also removes the instruction when its mask is zero.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! Thank you.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Search for extends to the index used in a histogram operation then perform a truncate on it. This avoids the need to split the instruction in two.