diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index ecaece8b68342..c4af944d814f4 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -866,6 +866,16 @@ class CombinerHelper {
   /// Combine insert vector element OOB.
   bool matchInsertVectorElementOOB(MachineInstr &MI, BuildFnTy &MatchInfo);
 
+  /// Combine extract vector element with an integer compare on the vector
+  /// register.
+  bool matchExtractVectorElementWithICmp(const MachineOperand &MO,
+                                         BuildFnTy &MatchInfo);
+
+  /// Combine extract vector element with a floating-point compare on the
+  /// vector register.
+  bool matchExtractVectorElementWithFCmp(const MachineOperand &MO,
+                                         BuildFnTy &MatchInfo);
+
 private:
   /// Checks for legality of an indexed variant of \p LdSt.
   bool isIndexedLoadStoreLegal(GLoadStore &LdSt) const;
@@ -981,6 +991,18 @@ class CombinerHelper {
 
   // Simplify (cmp cc0 x, y) (&& or ||) (cmp cc1 x, y) -> cmp cc2 x, y.
   bool tryFoldLogicOfFCmps(GLogicalBinOp *Logic, BuildFnTy &MatchInfo);
+
+  /// Return true if the register \p Src is cheaper to scalarize than it is to
+  /// leave as a vector operation. If the extract index \p Index is a constant
+  /// integer, then some operations may be cheap to scalarize. The depth
+  /// \p Depth prevents arbitrary recursion.
+  bool isCheapToScalarize(Register Src, const std::optional<APInt> &Index,
+                          unsigned Depth = 0);
+
+  /// Return true if \p Src is defined by a vector operation that is constant
+  /// at offset \p Index. \p Depth limits recursion into look-through vector
+  /// operations.
+  bool isConstantAtOffset(Register Src, const APInt &Index, unsigned Depth = 0);
 };
 } // namespace llvm
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 98d266c8c0b4f..d7aa0267fb449 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -1591,6 +1591,20 @@ def insert_vector_elt_oob : GICombineRule<
    [{ return Helper.matchInsertVectorElementOOB(*${root}, ${matchinfo}); }]),
   (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
 
+def extract_vector_element_icmp : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (G_ICMP $src, $pred, $lhs, $rhs),
+         (G_EXTRACT_VECTOR_ELT $root, $src, $idx),
+   [{ return Helper.matchExtractVectorElementWithICmp(${root}, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;
+
+def extract_vector_element_fcmp : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (G_FCMP $fsrc, $fpred, $flhs, $frhs),
+         (G_EXTRACT_VECTOR_ELT $root, $fsrc, $fidx),
+   [{ return Helper.matchExtractVectorElementWithFCmp(${root}, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;
+
 // match_extract_of_element and insert_vector_elt_oob must be the first!
 def vector_ops_combines: GICombineGroup<[
 match_extract_of_element_undef_vector,
@@ -1624,6 +1638,8 @@ extract_vector_element_build_vector_trunc7,
 extract_vector_element_build_vector_trunc8,
 extract_vector_element_freeze,
 extract_vector_element_shuffle_vector,
+extract_vector_element_icmp,
+extract_vector_element_fcmp,
 insert_vector_element_extract_vector_element
 ]>;
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
index 21b1eb2628174..64b39e3f82e65 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelperVectorOps.cpp
@@ -453,3 +453,162 @@ bool CombinerHelper::matchInsertVectorElementOOB(MachineInstr &MI,
 
   return false;
 }
+
+bool CombinerHelper::isConstantAtOffset(Register Src, const APInt &Index,
+                                        unsigned Depth) {
+  assert(MRI.getType(Src).isVector() && "expected a vector as input");
+  if (Depth == 2)
+    return false;
+
+  // We use the look-through variant for a higher hit rate and to increase the
+  // likelihood of constant folding. The actual value is ignored; we only test
+  // *whether* there is a constant.
+
+  MachineInstr *SrcMI = getDefIgnoringCopies(Src, MRI);
+
+  // If Src is def'd by a build vector, check whether the element at the
+  // offset is constant.
+  if (auto *Build = dyn_cast<GBuildVector>(SrcMI))
+    return getAnyConstantVRegValWithLookThrough(
+               Build->getSourceReg(Index.getZExtValue()), MRI)
+        .has_value();
+
+  // For concat and shuffle vectors, we could recurse.
+  // FIXME: concat vectors
+  // FIXME: shuffle vectors
+  // FIXME: unary ops
+  // FIXME: insert vector element
+  // FIXME: subvector
+
+  return false;
+}
+
+bool CombinerHelper::isCheapToScalarize(Register Src,
+                                        const std::optional<APInt> &Index,
+                                        unsigned Depth) {
+  assert(MRI.getType(Src).isVector() && "expected a vector as input");
+
+  if (Depth >= 2)
+    return false;
+
+  MachineInstr *SrcMI = getDefIgnoringCopies(Src, MRI);
+
+  // If Src is def'd by a binary operator, then scalarizing the op is cheap
+  // when one of its operands is cheap to scalarize.
+  if (auto *BinOp = dyn_cast<GBinOp>(SrcMI))
+    if (MRI.hasOneNonDBGUse(BinOp->getReg(0)))
+      if (isCheapToScalarize(BinOp->getLHSReg(), Index, Depth + 1) ||
+          isCheapToScalarize(BinOp->getRHSReg(), Index, Depth + 1))
+        return true;
+
+  // If Src is def'd by a compare, then scalarizing the cmp is cheap when one
+  // of its operands is cheap to scalarize.
+  if (auto *Cmp = dyn_cast<GAnyCmp>(SrcMI))
+    if (MRI.hasOneNonDBGUse(Cmp->getReg(0)))
+      if (isCheapToScalarize(Cmp->getLHSReg(), Index, Depth + 1) ||
+          isCheapToScalarize(Cmp->getRHSReg(), Index, Depth + 1))
+        return true;
+
+  // FIXME: unary operator
+  // FIXME: casts
+  // FIXME: loads
+  // FIXME: subvector
+
+  // If Index is constant, then Src is cheap to scalarize when it is constant
+  // at offset Index.
+  if (Index)
+    return isConstantAtOffset(Src, *Index, Depth);
+
+  return false;
+}
+
+bool CombinerHelper::matchExtractVectorElementWithICmp(
+    const MachineOperand &MO, BuildFnTy &MatchInfo) {
+  GExtractVectorElement *Extract =
+      cast<GExtractVectorElement>(MRI.getVRegDef(MO.getReg()));
+
+  Register Vector = Extract->getVectorReg();
+
+  GICmp *Cmp = cast<GICmp>(MRI.getVRegDef(Vector));
+
+  std::optional<ValueAndVReg> MaybeIndex =
+      getIConstantVRegValWithLookThrough(Extract->getIndexReg(), MRI);
+  std::optional<APInt> IndexC = std::nullopt;
+
+  if (MaybeIndex)
+    IndexC = MaybeIndex->Value;
+
+  if (!isCheapToScalarize(Vector, IndexC))
+    return false;
+
+  if (!MRI.hasOneNonDBGUse(Cmp->getReg(0)))
+    return false;
+
+  Register Dst = Extract->getReg(0);
+  LLT DstTy = MRI.getType(Dst);
+  LLT IdxTy = MRI.getType(Extract->getIndexReg());
+  LLT VectorTy = MRI.getType(Cmp->getLHSReg());
+  LLT ExtractDstTy = VectorTy.getScalarType();
+
+  if (!isLegalOrBeforeLegalizer(
+          {TargetOpcode::G_ICMP, {DstTy, ExtractDstTy}}) ||
+      !isLegalOrBeforeLegalizer({TargetOpcode::G_EXTRACT_VECTOR_ELT,
+                                 {ExtractDstTy, VectorTy, IdxTy}}))
+    return false;
+
+  MatchInfo = [=](MachineIRBuilder &B) {
+    auto LHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getLHSReg(),
+                                           Extract->getIndexReg());
+    auto RHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getRHSReg(),
+                                           Extract->getIndexReg());
+    B.buildICmp(Cmp->getCond(), Dst, LHS, RHS);
+  };
+
+  return true;
+}
+
+bool CombinerHelper::matchExtractVectorElementWithFCmp(
+    const MachineOperand &MO, BuildFnTy &MatchInfo) {
+  GExtractVectorElement *Extract =
+      cast<GExtractVectorElement>(MRI.getVRegDef(MO.getReg()));
+
+  Register Vector = Extract->getVectorReg();
+
+  GFCmp *Cmp = cast<GFCmp>(MRI.getVRegDef(Vector));
+
+  std::optional<ValueAndVReg> MaybeIndex =
+      getIConstantVRegValWithLookThrough(Extract->getIndexReg(), MRI);
+  std::optional<APInt> IndexC = std::nullopt;
+
+  if (MaybeIndex)
+    IndexC = MaybeIndex->Value;
+
+  if (!isCheapToScalarize(Vector, IndexC))
+    return false;
+
+  if (!MRI.hasOneNonDBGUse(Cmp->getReg(0)))
+    return false;
+
+  Register Dst = Extract->getReg(0);
+  LLT DstTy = MRI.getType(Dst);
+  LLT IdxTy = MRI.getType(Extract->getIndexReg());
+  LLT VectorTy = MRI.getType(Cmp->getLHSReg());
+  LLT ExtractDstTy = VectorTy.getScalarType();
+
+  if (!isLegalOrBeforeLegalizer(
+          {TargetOpcode::G_FCMP, {DstTy, ExtractDstTy}}) ||
+      !isLegalOrBeforeLegalizer({TargetOpcode::G_EXTRACT_VECTOR_ELT,
+                                 {ExtractDstTy, VectorTy, IdxTy}}))
+    return false;
+
+  MatchInfo = [=](MachineIRBuilder &B) {
+    auto LHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getLHSReg(),
+                                           Extract->getIndexReg());
+    auto RHS = B.buildExtractVectorElement(ExtractDstTy, Cmp->getRHSReg(),
+                                           Extract->getIndexReg());
+    B.buildFCmp(Cmp->getCond(), Dst, LHS, RHS, Cmp->getFlags());
+  };
+
+  return true;
+}
diff --git a/llvm/test/CodeGen/AArch64/extract-vector-elt.ll b/llvm/test/CodeGen/AArch64/extract-vector-elt.ll
index 0481d997d24fa..42fe5e82cb7de 100644
--- a/llvm/test/CodeGen/AArch64/extract-vector-elt.ll
+++ b/llvm/test/CodeGen/AArch64/extract-vector-elt.ll
@@ -1100,4 +1100,132 @@ ret:
   ret i32 %3
 }
 
+define i32 @extract_v4float_fcmp_const_no_zext(<4 x float> %a, <4 x float> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4float_fcmp_const_no_zext:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    fcmeq v0.4s, v0.4s, v0.4s
+; CHECK-SD-NEXT:    mvn v0.16b, v0.16b
+; CHECK-SD-NEXT:    xtn v0.4h, v0.4s
+; CHECK-SD-NEXT:    umov w8, v0.h[1]
+; CHECK-SD-NEXT:    and w0, w8, #0x1
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_v4float_fcmp_const_no_zext:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    mov s0, v0.s[1]
+; CHECK-GI-NEXT:    fmov s1, #1.00000000
+; CHECK-GI-NEXT:    fcmp s0, s1
+; CHECK-GI-NEXT:    cset w0, vs
+; CHECK-GI-NEXT:    ret
+entry:
+  %vector = fcmp uno <4 x float> %a, <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>
+  %d = extractelement <4 x i1> %vector, i32 1
+  %z = zext i1 %d to i32
+  ret i32 %z
+}
+
+define i32 @extract_v4i32_icmp_const_no_zext(<4 x i32> %a, <4 x i32> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4i32_icmp_const_no_zext:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    adrp x8, .LCPI43_0
+; CHECK-SD-NEXT:    ldr q1, [x8, :lo12:.LCPI43_0]
+; CHECK-SD-NEXT:    cmge v0.4s, v1.4s, v0.4s
+; CHECK-SD-NEXT:    xtn v0.4h, v0.4s
+; CHECK-SD-NEXT:    umov w8, v0.h[1]
+; CHECK-SD-NEXT:    and w0, w8, #0x1
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_v4i32_icmp_const_no_zext:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    mov s0, v0.s[1]
+; CHECK-GI-NEXT:    fmov w8, s0
+; CHECK-GI-NEXT:    cmp w8, #8
+; CHECK-GI-NEXT:    cset w0, le
+; CHECK-GI-NEXT:    ret
+entry:
+  %vector = icmp sle <4 x i32> %a, <i32 7, i32 8, i32 7, i32 9>
+  %d = extractelement <4 x i1> %vector, i32 1
+  %z = zext i1 %d to i32
+  ret i32 %z
+}
+
+define i32 @extract_v4float_fcmp_const_no_zext_fail(<4 x float> %a, <4 x float> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4float_fcmp_const_no_zext_fail:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    sub sp, sp, #16
+; CHECK-SD-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-SD-NEXT:    fcmeq v0.4s, v0.4s, v0.4s
+; CHECK-SD-NEXT:    add x8, sp, #8
+; CHECK-SD-NEXT:    // kill: def $w0 killed $w0 def $x0
+; CHECK-SD-NEXT:    bfi x8, x0, #1, #2
+; CHECK-SD-NEXT:    mvn v0.16b, v0.16b
+; CHECK-SD-NEXT:    xtn v0.4h, v0.4s
+; CHECK-SD-NEXT:    str d0, [sp, #8]
+; CHECK-SD-NEXT:    ldrh w8, [x8]
+; CHECK-SD-NEXT:    and w0, w8, #0x1
+; CHECK-SD-NEXT:    add sp, sp, #16
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_v4float_fcmp_const_no_zext_fail:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    sub sp, sp, #16
+; CHECK-GI-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-GI-NEXT:    fmov v1.4s, #1.00000000
+; CHECK-GI-NEXT:    mov w8, w0
+; CHECK-GI-NEXT:    mov x9, sp
+; CHECK-GI-NEXT:    and x8, x8, #0x3
+; CHECK-GI-NEXT:    fcmge v2.4s, v0.4s, v1.4s
+; CHECK-GI-NEXT:    fcmgt v0.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT:    orr v0.16b, v0.16b, v2.16b
+; CHECK-GI-NEXT:    mvn v0.16b, v0.16b
+; CHECK-GI-NEXT:    str q0, [sp]
+; CHECK-GI-NEXT:    ldr w8, [x9, x8, lsl #2]
+; CHECK-GI-NEXT:    and w0, w8, #0x1
+; CHECK-GI-NEXT:    add sp, sp, #16
+; CHECK-GI-NEXT:    ret
+entry:
+  %vector = fcmp uno <4 x float> %a, <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>
+  %d = extractelement <4 x i1> %vector, i32 %c
+  %z = zext i1 %d to i32
+  ret i32 %z
+}
+
+define i32 @extract_v4i32_icmp_const_no_zext_fail(<4 x i32> %a, <4 x i32> %b, i32 %c) {
+; CHECK-SD-LABEL: extract_v4i32_icmp_const_no_zext_fail:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    sub sp, sp, #16
+; CHECK-SD-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-SD-NEXT:    adrp x8, .LCPI45_0
+; CHECK-SD-NEXT:    // kill: def $w0 killed $w0 def $x0
+; CHECK-SD-NEXT:    ldr q1, [x8, :lo12:.LCPI45_0]
+; CHECK-SD-NEXT:    add x8, sp, #8
+; CHECK-SD-NEXT:    bfi x8, x0, #1, #2
+; CHECK-SD-NEXT:    cmge v0.4s, v1.4s, v0.4s
+; CHECK-SD-NEXT:    xtn v0.4h, v0.4s
+; CHECK-SD-NEXT:    str d0, [sp, #8]
+; CHECK-SD-NEXT:    ldrh w8, [x8]
+; CHECK-SD-NEXT:    and w0, w8, #0x1
+; CHECK-SD-NEXT:    add sp, sp, #16
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_v4i32_icmp_const_no_zext_fail:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    sub sp, sp, #16
+; CHECK-GI-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-GI-NEXT:    adrp x8, .LCPI45_0
+; CHECK-GI-NEXT:    mov x9, sp
+; CHECK-GI-NEXT:    ldr q1, [x8, :lo12:.LCPI45_0]
+; CHECK-GI-NEXT:    mov w8, w0
+; CHECK-GI-NEXT:    and x8, x8, #0x3
+; CHECK-GI-NEXT:    cmge v0.4s, v1.4s, v0.4s
+; CHECK-GI-NEXT:    str q0, [sp]
+; CHECK-GI-NEXT:    ldr w8, [x9, x8, lsl #2]
+; CHECK-GI-NEXT:    and w0, w8, #0x1
+; CHECK-GI-NEXT:    add sp, sp, #16
+; CHECK-GI-NEXT:    ret
+entry:
+  %vector = icmp sle <4 x i32> %a, <i32 7, i32 8, i32 7, i32 9>
+  %d = extractelement <4 x i1> %vector, i32 %c
+  %z = zext i1 %d to i32
+  ret i32 %z
+}
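
Note for reviewers: the new combines rewrite an extract of a single lane of a vector compare into a scalar compare of the extracted operands. Roughly, in generic MIR (a hand-written sketch; register names, types, and the predicate are illustrative, not combiner output):

  ; before
  %cmp:_(<4 x s1>) = G_ICMP intpred(sle), %a(<4 x s32>), %cst(<4 x s32>)
  %one:_(s64) = G_CONSTANT i64 1
  %elt:_(s1) = G_EXTRACT_VECTOR_ELT %cmp(<4 x s1>), %one(s64)

  ; after
  %one:_(s64) = G_CONSTANT i64 1
  %lhs:_(s32) = G_EXTRACT_VECTOR_ELT %a(<4 x s32>), %one(s64)
  %rhs:_(s32) = G_EXTRACT_VECTOR_ELT %cst(<4 x s32>), %one(s64)
  %elt:_(s1) = G_ICMP intpred(sle), %lhs(s32), %rhs(s32)

The rewrite only fires when the compare has a single non-debug use and isCheapToScalarize holds, e.g. when the extract index is constant and one compare operand is a G_BUILD_VECTOR with a constant element at that index; the *_fail tests above use a variable index, so the compare stays in vector form.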