From 1c594a400929363677fb7850f1055929ed7951c1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thorsten=20Sch=C3=BCtt?=
Date: Sun, 12 May 2024 23:08:25 +0200
Subject: [PATCH 1/2] [GlobalIsel] combine extract vector element

scalarize compares

extelt (cmp X, Y), Index --> cmp (extelt X, Index), (extelt Y, Index)
---
 .../llvm/CodeGen/GlobalISel/CombinerHelper.h  |  23 +++
 .../include/llvm/Target/GlobalISel/Combine.td |  19 ++-
 .../GlobalISel/CombinerHelperVectorOps.cpp    | 159 ++++++++++++++++++
 .../CodeGen/AArch64/extract-vector-elt.ll     | 128 ++++++++++++++
 4 files changed, 328 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index ecaece8b68342..6edb3f9cd2e89 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -1,3 +1,4 @@
+
 //===-- llvm/CodeGen/GlobalISel/CombinerHelper.h --------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -866,6 +867,16 @@ class CombinerHelper {
   /// Combine insert vector element OOB.
   bool matchInsertVectorElementOOB(MachineInstr &MI, BuildFnTy &MatchInfo);
 
+  /// Combine extract vector element with an integer compare on the vector
+  /// register.
+  bool matchExtractVectorElementWithICmp(const MachineOperand &MO,
+                                         BuildFnTy &MatchInfo);
+
+  /// Combine extract vector element with a floating-point compare on the
+  /// vector register.
+  bool matchExtractVectorElementWithFCmp(const MachineOperand &MO,
+                                         BuildFnTy &MatchInfo);
+
 private:
   /// Checks for legality of an indexed variant of \p LdSt.
   bool isIndexedLoadStoreLegal(GLoadStore &LdSt) const;
@@ -981,6 +992,18 @@ class CombinerHelper {
 
   // Simplify (cmp cc0 x, y) (&& or ||) (cmp cc1 x, y) -> cmp cc2 x, y.
   bool tryFoldLogicOfFCmps(GLogicalBinOp *Logic, BuildFnTy &MatchInfo);
+
+  /// Return true if the register \p Src is cheaper to scalarize than it is to
+  /// leave as a vector operation. If the extract index \p Index is a constant
+  /// integer, then some operations may be cheap to scalarize. The depth
+  /// \p Depth prevents arbitrary recursion.
+  bool isCheapToScalarize(Register Src, const std::optional<APInt> &Index,
+                          unsigned Depth = 0);
+
+  /// Return true if \p Src is def'd by a vector operation that is constant at
+  /// offset \p Index. \p Depth limits recursion when looking through vector
+  /// operations.
+  bool isConstantAtOffset(Register Src, const APInt &Index, unsigned Depth = 0);
 };
 
 } // namespace llvm
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 98d266c8c0b4f..3c71c2a25b2d9 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -1591,6 +1591,20 @@ def insert_vector_elt_oob : GICombineRule<
          [{ return Helper.matchInsertVectorElementOOB(*${root}, ${matchinfo}); }]),
   (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
 
+def extract_vector_element_icmp : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (G_ICMP $src, $pred, $lhs, $rhs),
+         (G_EXTRACT_VECTOR_ELT $root, $src, $idx),
+  [{ return Helper.matchExtractVectorElementWithICmp(${root}, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;
+
+def extract_vector_element_fcmp : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (G_FCMP $fsrc, $fpred, $flhs, $frhs),
+         (G_EXTRACT_VECTOR_ELT $root, $fsrc, $fidx),
+  [{ return Helper.matchExtractVectorElementWithFCmp(${root}, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;
+
 // match_extract_of_element and insert_vector_elt_oob must be the first!
 def vector_ops_combines: GICombineGroup<[
 match_extract_of_element_undef_vector,
@@ -1624,6 +1638,8 @@ extract_vector_element_build_vector_trunc7,
 extract_vector_element_build_vector_trunc8,
 extract_vector_element_freeze,
 extract_vector_element_shuffle_vector,
+extract_vector_element_icmp,
+extract_vector_element_fcmp,
 insert_vector_element_extract_vector_element
 ]>;
 
@@ -1706,7 +1722,8 @@ def all_combines : GICombineGroup<[trivial_combines, vector_ops_combines,
     sub_add_reg, select_to_minmax, redundant_binop_in_equality,
     fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
     combine_concat_vector, double_icmp_zero_and_or_combine, match_addos,
-    sext_trunc, zext_trunc, combine_shuffle_concat]>;
+    sext_trunc, zext_trunc, combine_shuffle_concat
+]>;
 
 // A combine group used to for prelegalizer combiners at -O0. The combines in
 // this group have been selected based on experiments to balance code size and
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
index 21b1eb2628174..64b39e3f82e65 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
@@ -453,3 +453,162 @@ bool CombinerHelper::matchInsertVectorElementOOB(MachineInstr &MI,
 
   return false;
 }
+
+bool CombinerHelper::isConstantAtOffset(Register Src, const APInt &Index,
+                                        unsigned Depth) {
+  assert(MRI.getType(Src).isVector() && "expected a vector as input");
+  if (Depth == 2)
+    return false;
+
+  // We use the look-through variant for a higher hit rate and to increase the
+  // likelihood of constant folding. The actual value is ignored. We only test
+  // *whether* there is a constant.
+
+  MachineInstr *SrcMI = getDefIgnoringCopies(Src, MRI);
+
+  // If Src is def'd by a build vector, then we check the constness at the
+  // offset.
+  if (auto *Build = dyn_cast<GBuildVector>(SrcMI))
+    return getAnyConstantVRegValWithLookThrough(
+               Build->getSourceReg(Index.getZExtValue()), MRI)
+        .has_value();
+
+  // For concat and shuffle vectors, we could recurse.
+  // FIXME: concat vectors
+  // FIXME: shuffle vectors
+  // FIXME: unary ops
+  // FIXME: insert vector element
+  // FIXME: subvector
+
+  return false;
+}
+
+bool CombinerHelper::isCheapToScalarize(Register Src,
+                                        const std::optional<APInt> &Index,
+                                        unsigned Depth) {
+  assert(MRI.getType(Src).isVector() && "expected a vector as input");
+
+  if (Depth >= 2)
+    return false;
+
+  MachineInstr *SrcMI = getDefIgnoringCopies(Src, MRI);
+
+  // If Src is def'd by a binary operator, then scalarizing the op is cheap
+  // when one of its operands is cheap to scalarize.
+  if (auto *BinOp = dyn_cast<GBinOp>(SrcMI))
+    if (MRI.hasOneNonDBGUse(BinOp->getReg(0)))
+      if (isCheapToScalarize(BinOp->getLHSReg(), Index, Depth + 1) ||
+          isCheapToScalarize(BinOp->getRHSReg(), Index, Depth + 1))
+        return true;
+
+  // If Src is def'd by a compare, then scalarizing the cmp is cheap when one
+  // of its operands is cheap to scalarize.
+  if (auto *Cmp = dyn_cast<GAnyCmp>(SrcMI))
+    if (MRI.hasOneNonDBGUse(Cmp->getReg(0)))
+      if (isCheapToScalarize(Cmp->getLHSReg(), Index, Depth + 1) ||
+          isCheapToScalarize(Cmp->getRHSReg(), Index, Depth + 1))
+        return true;
+
+  // FIXME: unary operator
+  // FIXME: casts
+  // FIXME: loads
+  // FIXME: subvector
+
+  // If Index is constant, then Src is cheap to scalarize when it is constant
+  // at offset Index.
+  if (Index)
+    return isConstantAtOffset(Src, *Index, Depth);
+
+  return false;
+}
+
+bool CombinerHelper::matchExtractVectorElementWithICmp(
+    const MachineOperand &MO, BuildFnTy &MatchInfo) {
+  GExtractVectorElement *Extract =
+      cast<GExtractVectorElement>(MRI.getVRegDef(MO.getReg()));
+
+  Register Vector = Extract->getVectorReg();
+
+  GICmp *Cmp = cast<GICmp>(MRI.getVRegDef(Vector));
+
+  std::optional<ValueAndVReg> MaybeIndex =
+      getIConstantVRegValWithLookThrough(Extract->getIndexReg(), MRI);
+  std::optional<APInt> IndexC = std::nullopt;
+
+  if (MaybeIndex)
+    IndexC = MaybeIndex->Value;
+
+  if (!isCheapToScalarize(Vector, IndexC))
+    return false;
+
+  if (!MRI.hasOneNonDBGUse(Cmp->getReg(0)))
+    return false;
+
+  Register Dst = Extract->getReg(0);
+  LLT DstTy = MRI.getType(Dst);
+  LLT IdxTy = MRI.getType(Extract->getIndexReg());
+  LLT VectorTy = MRI.getType(Cmp->getLHSReg());
+  LLT ExtractDstTy = VectorTy.getScalarType();
+
+  if (!isLegalOrBeforeLegalizer(
+          {TargetOpcode::G_ICMP, {DstTy, ExtractDstTy}}) ||
+      !isLegalOrBeforeLegalizer({TargetOpcode::G_EXTRACT_VECTOR_ELT,
+                                 {ExtractDstTy, VectorTy, IdxTy}}))
+    return false;
+
+  MatchInfo = [=](MachineIRBuilder &B) {
+    auto LHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getLHSReg(),
+                                           Extract->getIndexReg());
+    auto RHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getRHSReg(),
+                                           Extract->getIndexReg());
+    B.buildICmp(Cmp->getCond(), Dst, LHS, RHS);
+  };
+
+  return true;
+}
+
+bool CombinerHelper::matchExtractVectorElementWithFCmp(
+    const MachineOperand &MO, BuildFnTy &MatchInfo) {
+  GExtractVectorElement *Extract =
+      cast<GExtractVectorElement>(MRI.getVRegDef(MO.getReg()));
+
+  Register Vector = Extract->getVectorReg();
+
+  GFCmp *Cmp = cast<GFCmp>(MRI.getVRegDef(Vector));
+
+  std::optional<ValueAndVReg> MaybeIndex =
+      getIConstantVRegValWithLookThrough(Extract->getIndexReg(), MRI);
+  std::optional<APInt> IndexC = std::nullopt;
+
+  if (MaybeIndex)
+    IndexC = MaybeIndex->Value;
+
+  if (!isCheapToScalarize(Vector, IndexC))
+    return false;
+
+  if (!MRI.hasOneNonDBGUse(Cmp->getReg(0)))
+    return false;
+
+  Register Dst = Extract->getReg(0);
+  LLT DstTy = MRI.getType(Dst);
+  LLT IdxTy = MRI.getType(Extract->getIndexReg());
+  LLT VectorTy = MRI.getType(Cmp->getLHSReg());
+  LLT ExtractDstTy = VectorTy.getScalarType();
+
+  if (!isLegalOrBeforeLegalizer(
+          {TargetOpcode::G_FCMP, {DstTy, ExtractDstTy}}) ||
+      !isLegalOrBeforeLegalizer({TargetOpcode::G_EXTRACT_VECTOR_ELT,
+                                 {ExtractDstTy, VectorTy, IdxTy}}))
+    return false;
+
+  MatchInfo = [=](MachineIRBuilder &B) {
+    auto LHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getLHSReg(),
+                                           Extract->getIndexReg());
+    auto RHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getRHSReg(),
+                                           Extract->getIndexReg());
+    B.buildFCmp(Cmp->getCond(), Dst, LHS, RHS, Cmp->getFlags());
+  };
+
+  return true;
+}
diff --git a/llvm/test/CodeGen/AArch64/extract-vector-elt.ll b/llvm/test/CodeGen/AArch64/extract-vector-elt.ll
index 0481d997d24fa..42fe5e82cb7de 100644
--- a/llvm/test/CodeGen/AArch64/extract-vector-elt.ll
+++ b/llvm/test/CodeGen/AArch64/extract-vector-elt.ll
@@ -1100,4 +1100,132 @@ ret:
   ret i32 %3
 }
 
+define i32 @extract_v4float_fcmp_const_no_zext(<4 x float> %a, <4 x float> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4float_fcmp_const_no_zext:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    fcmeq v0.4s, v0.4s, v0.4s
+; CHECK-SD-NEXT:    mvn v0.16b, v0.16b
+; CHECK-SD-NEXT:    xtn v0.4h, v0.4s
+; CHECK-SD-NEXT:    umov w8, v0.h[1]
+; CHECK-SD-NEXT:    and w0, w8, #0x1
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_v4float_fcmp_const_no_zext:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    mov s0, v0.s[1]
+; CHECK-GI-NEXT:    fmov s1, #1.00000000
+; CHECK-GI-NEXT:    fcmp s0, s1
+; CHECK-GI-NEXT:    cset w0, vs
+; CHECK-GI-NEXT:    ret
+entry:
+  %vector = fcmp uno <4 x float> %a, <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>
+  %d = extractelement <4 x i1> %vector, i32 1
+  %z = zext i1 %d to i32
+  ret i32 %z
+}
+
+define i32 @extract_v4i32_icmp_const_no_zext(<4 x i32> %a, <4 x i32> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4i32_icmp_const_no_zext:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    adrp x8, .LCPI43_0
+; CHECK-SD-NEXT:    ldr q1, [x8, :lo12:.LCPI43_0]
+; CHECK-SD-NEXT:    cmge v0.4s, v1.4s, v0.4s
+; CHECK-SD-NEXT:    xtn v0.4h, v0.4s
+; CHECK-SD-NEXT:    umov w8, v0.h[1]
+; CHECK-SD-NEXT:    and w0, w8, #0x1
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_v4i32_icmp_const_no_zext:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    mov s0, v0.s[1]
+; CHECK-GI-NEXT:    fmov w8, s0
+; CHECK-GI-NEXT:    cmp w8, #8
+; CHECK-GI-NEXT:    cset w0, le
+; CHECK-GI-NEXT:    ret
+entry:
+  %vector = icmp sle <4 x i32> %a, <i32 8, i32 8, i32 8, i32 8>
+  %d = extractelement <4 x i1> %vector, i32 1
+  %z = zext i1 %d to i32
+  ret i32 %z
+}
+
+define i32 @extract_v4float_fcmp_const_no_zext_fail(<4 x float> %a, <4 x float> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4float_fcmp_const_no_zext_fail:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    sub sp, sp, #16
+; CHECK-SD-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-SD-NEXT:    fcmeq v0.4s, v0.4s, v0.4s
+; CHECK-SD-NEXT:    add x8, sp, #8
+; CHECK-SD-NEXT:    // kill: def $w0 killed $w0 def $x0
+; CHECK-SD-NEXT:    bfi x8, x0, #1, #2
+; CHECK-SD-NEXT:    mvn v0.16b, v0.16b
+; CHECK-SD-NEXT:    xtn v0.4h, v0.4s
+; CHECK-SD-NEXT:    str d0, [sp, #8]
+; CHECK-SD-NEXT:    ldrh w8, [x8]
+; CHECK-SD-NEXT:    and w0, w8, #0x1
+; CHECK-SD-NEXT:    add sp, sp, #16
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_v4float_fcmp_const_no_zext_fail:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    sub sp, sp, #16
+; CHECK-GI-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-GI-NEXT:    fmov v1.4s, #1.00000000
+; CHECK-GI-NEXT:    mov w8, w0
+; CHECK-GI-NEXT:    mov x9, sp
+; CHECK-GI-NEXT:    and x8, x8, #0x3
+; CHECK-GI-NEXT:    fcmge v2.4s, v0.4s, v1.4s
+; CHECK-GI-NEXT:    fcmgt v0.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT:    orr v0.16b, v0.16b, v2.16b
+; CHECK-GI-NEXT:    mvn v0.16b, v0.16b
+; CHECK-GI-NEXT:    str q0, [sp]
+; CHECK-GI-NEXT:    ldr w8, [x9, x8, lsl #2]
+; CHECK-GI-NEXT:    and w0, w8, #0x1
+; CHECK-GI-NEXT:    add sp, sp, #16
+; CHECK-GI-NEXT:    ret
+entry:
+  %vector = fcmp uno <4 x float> %a, <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>
+  %d = extractelement <4 x i1> %vector, i32 %c
+  %z = zext i1 %d to i32
+  ret i32 %z
+}
+
+define i32 @extract_v4i32_icmp_const_no_zext_fail(<4 x i32> %a, <4 x i32> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4i32_icmp_const_no_zext_fail:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    sub sp, sp, #16
+; CHECK-SD-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-SD-NEXT:    adrp x8, .LCPI45_0
+; CHECK-SD-NEXT:    // kill: def $w0 killed $w0 def $x0
+; CHECK-SD-NEXT:    ldr q1, [x8, :lo12:.LCPI45_0]
+; CHECK-SD-NEXT:    add x8, sp, #8
+; CHECK-SD-NEXT:    bfi x8, x0, #1, #2
+; CHECK-SD-NEXT:    cmge v0.4s, v1.4s, v0.4s
+; CHECK-SD-NEXT:    xtn v0.4h, v0.4s
+; CHECK-SD-NEXT:    str d0, [sp, #8]
+; CHECK-SD-NEXT:    ldrh w8, [x8]
+; CHECK-SD-NEXT:    and w0, w8, #0x1
+; CHECK-SD-NEXT:    add sp, sp, #16
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_v4i32_icmp_const_no_zext_fail:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    sub sp, sp, #16
+; CHECK-GI-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-GI-NEXT:    adrp x8, .LCPI45_0
+; CHECK-GI-NEXT:    mov x9, sp
+; CHECK-GI-NEXT:    ldr q1, [x8, :lo12:.LCPI45_0]
+; CHECK-GI-NEXT:    mov w8, w0
+; CHECK-GI-NEXT:    and x8, x8, #0x3
+; CHECK-GI-NEXT:    cmge v0.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT:    str q0, [sp]
+; CHECK-GI-NEXT:    ldr w8, [x9, x8, lsl #2]
+; CHECK-GI-NEXT:    and w0, w8, #0x1
+; CHECK-GI-NEXT:    add sp, sp, #16
+; CHECK-GI-NEXT:    ret
+entry:
+  %vector = icmp sle <4 x i32> %a, <i32 8, i32 8, i32 8, i32 8>
+  %d = extractelement <4 x i1> %vector, i32 %c
+  %z = zext i1 %d to i32
+  ret i32 %z
+}

From 27189e9beb57e8a55db32f11511f0ad2670b1dc0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thorsten=20Sch=C3=BCtt?=
Date: Mon, 13 May 2024 08:59:18 +0200
Subject: [PATCH 2/2] fixups

---
 llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h | 1 -
 llvm/include/llvm/Target/GlobalISel/Combine.td        | 3 +--
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index 6edb3f9cd2e89..c4af944d814f4 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -1,4 +1,3 @@
-
 //===-- llvm/CodeGen/GlobalISel/CombinerHelper.h --------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 3c71c2a25b2d9..d7aa0267fb449 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -1722,8 +1722,7 @@ def all_combines : GICombineGroup<[trivial_combines, vector_ops_combines,
     sub_add_reg, select_to_minmax, redundant_binop_in_equality,
     fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
     combine_concat_vector, double_icmp_zero_and_or_combine, match_addos,
-    sext_trunc, zext_trunc, combine_shuffle_concat
-]>;
+    sext_trunc, zext_trunc, combine_shuffle_concat]>;
 
 // A combine group used to for prelegalizer combiners at -O0. The combines in
 // this group have been selected based on experiments to balance code size and
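
For reference, a minimal IR-level sketch of what the combine does, assuming a constant extract index and a constant splat compare operand (the function names are hypothetical; the constants mirror the extract_v4i32_icmp_const_no_zext test above):

; Before the combine: extract lane 1 of a vector compare. The constant
; build vector operand makes the compare cheap to scalarize.
define i1 @cmp_extract_sketch(<4 x i32> %x) {
  %v = icmp sle <4 x i32> %x, <i32 8, i32 8, i32 8, i32 8>
  %e = extractelement <4 x i1> %v, i32 1
  ret i1 %e
}

; After the combine: compare the extracted lanes instead. The extract of
; the constant splat folds away, leaving a scalar compare against 8.
define i1 @cmp_extract_sketch_after(<4 x i32> %x) {
  %lane = extractelement <4 x i32> %x, i32 1
  %e = icmp sle i32 %lane, 8
  ret i1 %e
}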