diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 51580d15451ca..239223813220e 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -1374,13 +1374,14 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setPrefLoopAlignment(Subtarget.getPrefLoopAlignment()); setTargetDAGCombine({ISD::INTRINSIC_VOID, ISD::INTRINSIC_W_CHAIN, - ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::MUL, - ISD::AND, ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT}); + ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND, + ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT}); if (Subtarget.is64Bit()) setTargetDAGCombine(ISD::SRA); if (Subtarget.hasStdExtFOrZfinx()) - setTargetDAGCombine({ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM}); + setTargetDAGCombine( + {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FMAXNUM, ISD::FMINNUM}); if (Subtarget.hasStdExtZbb()) setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN}); @@ -12848,6 +12849,9 @@ namespace { // apply a combine. struct CombineResult; +// Supported extension kind to be folded. +enum class SupportExt { ZExt, SExt, FPExt }; + /// Helper class for folding sign/zero extensions. /// In particular, this class is used for the following combines: /// add | add_vl -> vwadd(u) | vwadd(u)_w @@ -12878,6 +12882,8 @@ struct NodeExtensionHelper { /// instance, a splat constant (e.g., 3), would support being both sign and /// zero extended. bool SupportsSExt; + /// Records if this operand is like being floating-point extended. + bool SupportsFPExt; /// This boolean captures whether we care if this operand would still be /// around after the folding happens. 
bool EnforceOneUse; @@ -12899,9 +12905,13 @@ struct NodeExtensionHelper { switch (OrigOperand.getOpcode()) { case ISD::ZERO_EXTEND: case ISD::SIGN_EXTEND: + case ISD::FP_EXTEND: case RISCVISD::VSEXT_VL: case RISCVISD::VZEXT_VL: + case RISCVISD::FP_EXTEND_VL: return OrigOperand.getOperand(0); + case ISD::SPLAT_VECTOR: + return OrigOperand.getOperand(0)->getOperand(0); default: return OrigOperand; } @@ -12909,7 +12919,21 @@ struct NodeExtensionHelper { /// Check if this instance represents a splat. bool isSplat() const { - return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL; + return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL || + (OrigOperand.getOpcode() == ISD::SPLAT_VECTOR && + OrigOperand.getOperand(0).getOpcode() == ISD::FP_EXTEND); + } + + /// Get the extended opcode. + unsigned getExtOpc(SupportExt Ext) const { + switch (Ext) { + case SupportExt::ZExt: + return RISCVISD::VZEXT_VL; + case SupportExt::SExt: + return RISCVISD::VSEXT_VL; + case SupportExt::FPExt: + return RISCVISD::FP_EXTEND_VL; + } } /// Get or create a value that can feed \p Root with the given extension \p @@ -12917,30 +12941,35 @@ struct NodeExtensionHelper { /// \see ::getSource(). SDValue getOrCreateExtendedOp(SDNode *Root, SelectionDAG &DAG, const RISCVSubtarget &Subtarget, - std::optional SExt) const { - if (!SExt.has_value()) + std::optional Ext) const { + if (!Ext.has_value()) return OrigOperand; - MVT NarrowVT = getNarrowType(Root); + MVT NarrowVT = getNarrowType(Root, Subtarget); SDValue Source = getSource(); if (Source.getValueType() == NarrowVT) return Source; - unsigned ExtOpc = *SExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL; - + unsigned OrigOpc = OrigOperand.getOpcode(); // If we need an extension, we should be changing the type. 
SDLoc DL(Root); auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget); - switch (OrigOperand.getOpcode()) { + switch (OrigOpc) { case ISD::ZERO_EXTEND: case ISD::SIGN_EXTEND: + case ISD::FP_EXTEND: case RISCVISD::VSEXT_VL: case RISCVISD::VZEXT_VL: + case RISCVISD::FP_EXTEND_VL: { + unsigned ExtOpc = getExtOpc(*Ext); return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL); + } + case ISD::SPLAT_VECTOR: + return DAG.getNode(ISD::SPLAT_VECTOR, DL, NarrowVT, Source); case RISCVISD::VMV_V_X_VL: - return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT, - DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL); + return DAG.getNode(OrigOpc, DL, NarrowVT, DAG.getUNDEF(NarrowVT), + Source.getOperand(1), VL); default: // Other opcodes can only come from the original LHS of VW(ADD|SUB)_W_VL // and that operand should already have the right NarrowVT so no @@ -12954,67 +12983,163 @@ struct NodeExtensionHelper { /// element by 2. E.g., if Root's type <2xi16> -> narrow type <2xi8>. /// \pre The size of the type of the elements of Root must be a multiple of 2 /// and be greater than 16. - static MVT getNarrowType(const SDNode *Root) { + static MVT getNarrowType(const SDNode *Root, + const RISCVSubtarget &Subtarget) { MVT VT = Root->getSimpleValueType(0); // Determine the narrow size. unsigned NarrowSize = VT.getScalarSizeInBits() / 2; - assert(NarrowSize >= 8 && "Trying to extend something we can't represent"); - MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize), - VT.getVectorElementCount()); - return NarrowVT; + + MVT NarrowScalarVT = VT.isInteger() ? 
MVT::getIntegerVT(NarrowSize) + : MVT::getFloatingPointVT(NarrowSize); + MVT NarrowVectorVT = + MVT::getVectorVT(NarrowScalarVT, VT.getVectorElementCount()); + + assert(Subtarget.getTargetLowering()->isTypeLegal(NarrowVectorVT) && + "Trying to extend something we can't represent"); + return NarrowVectorVT; } - /// Return the opcode required to materialize the folding of the sign - /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for - /// both operands for \p Opcode. - /// Put differently, get the opcode to materialize: - /// - ISExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b) - /// - ISExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b) - /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()). - static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) { + /// Get full widening (2*SEW = SEW +/-/* SEW) signed integer add/sub/mul + /// opcode. + static unsigned getSignedFullWidenOpcode(unsigned Opcode) { switch (Opcode) { case ISD::ADD: case RISCVISD::ADD_VL: case RISCVISD::VWADD_W_VL: - case RISCVISD::VWADDU_W_VL: - return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL; + return RISCVISD::VWADD_VL; case ISD::MUL: case RISCVISD::MUL_VL: - return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL; + return RISCVISD::VWMUL_VL; case ISD::SUB: case RISCVISD::SUB_VL: case RISCVISD::VWSUB_W_VL: + return RISCVISD::VWSUB_VL; + default: + llvm_unreachable("Unexpected Opcode"); + } + } + + /// Get full widening (2*SEW = SEW +/-/* SEW) unsigned integer add/sub/mul + /// opcode. + static unsigned getUnsignedFullWidenOpcode(unsigned Opcode) { + switch (Opcode) { + case ISD::ADD: + case RISCVISD::ADD_VL: + case RISCVISD::VWADDU_W_VL: + return RISCVISD::VWADDU_VL; + case ISD::MUL: + case RISCVISD::MUL_VL: + return RISCVISD::VWMULU_VL; + case ISD::SUB: + case RISCVISD::SUB_VL: case RISCVISD::VWSUBU_W_VL: - return IsSExt ? 
RISCVISD::VWSUB_VL : RISCVISD::VWSUBU_VL; + return RISCVISD::VWSUBU_VL; default: - llvm_unreachable("Unexpected opcode"); + llvm_unreachable("Unexpected Opcode"); + } + } + + /// Get full widening (2*SEW = SEW +/-/* SEW) FP add/sub/mul opcode. + static unsigned getFloatFullWidenOpcode(unsigned Opcode) { + switch (Opcode) { + case ISD::FADD: + case RISCVISD::FADD_VL: + case RISCVISD::VFWADD_W_VL: + return RISCVISD::VFWADD_VL; + case ISD::FSUB: + case RISCVISD::FSUB_VL: + case RISCVISD::VFWSUB_W_VL: + return RISCVISD::VFWSUB_VL; + case ISD::FMUL: + case RISCVISD::FMUL_VL: + return RISCVISD::VFWMUL_VL; + default: + llvm_unreachable("Unexpected Opcode"); + } + } + + /// Return the opcode required to materialize the folding of the extensions + /// of kind \p Ext (sign, zero or floating-point) for both operands of + /// \p OrigOpcode. Put differently, get the opcode to materialize: + /// - SExt: \p OrigOpcode(sext(a), sext(b)) -> newOpcode(a, b) + /// - ZExt: \p OrigOpcode(zext(a), zext(b)) -> newOpcode(a, b) + /// - FPExt: \p OrigOpcode(fpext(a), fpext(b)) -> newOpcode(a, b) + /// \pre \p OrigOpcode represents a supported root (\see ::isSupportedRoot()). + static unsigned getFullWidenOpcode(unsigned OrigOpcode, SupportExt Ext) { + switch (Ext) { + case SupportExt::SExt: + return getSignedFullWidenOpcode(OrigOpcode); + case SupportExt::ZExt: + return getUnsignedFullWidenOpcode(OrigOpcode); + case SupportExt::FPExt: + return getFloatFullWidenOpcode(OrigOpcode); } } /// Get the opcode to materialize \p Opcode(sext(a), zext(b)) -> /// newOpcode(a, b). - static unsigned getSUOpcode(unsigned Opcode) { + static unsigned getSignedUnsignedWidenOpcode(unsigned Opcode) { assert((Opcode == RISCVISD::MUL_VL || Opcode == ISD::MUL) && "SU is only supported for MUL"); return RISCVISD::VWMULSU_VL; } - /// Get the opcode to materialize \p Opcode(a, s|zext(b)) -> - /// newOpcode(a, b). - static unsigned getWOpcode(unsigned Opcode, bool IsSExt) { + /// Get half widening (2*SEW = 2*SEW +/- SEW) signed integer add/sub opcode. 
+ static unsigned getSignedHalfWidenOpcode(unsigned Opcode) { switch (Opcode) { case ISD::ADD: case RISCVISD::ADD_VL: - return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL; + return RISCVISD::VWADD_W_VL; case ISD::SUB: case RISCVISD::SUB_VL: - return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL; + return RISCVISD::VWSUB_W_VL; default: llvm_unreachable("Unexpected opcode"); } } + /// Get half widening (2*SEW = 2*SEW +/- SEW) unsigned integer add/sub opcode. + static unsigned getUnsignedHalfWidenOpcode(unsigned Opcode) { + switch (Opcode) { + case ISD::ADD: + case RISCVISD::ADD_VL: + return RISCVISD::VWADDU_W_VL; + case ISD::SUB: + case RISCVISD::SUB_VL: + return RISCVISD::VWSUBU_W_VL; + default: + llvm_unreachable("Unexpected opcode"); + } + } + + /// Get half widening (2*SEW = 2*SEW +/- SEW) FP add/sub opcode. + static unsigned getFloatHalfWidenOpcode(unsigned Opcode) { + switch (Opcode) { + case ISD::FADD: + case RISCVISD::FADD_VL: + return RISCVISD::VFWADD_W_VL; + case ISD::FSUB: + case RISCVISD::FSUB_VL: + return RISCVISD::VFWSUB_W_VL; + default: + llvm_unreachable("Unexpected opcode"); + } + } + + /// Get the opcode to materialize \p Opcode(a, s|z|fpext(b)) -> + /// newOpcode(a, b). 
+ static unsigned getHalfWidenOpcode(unsigned Opcode, SupportExt Ext) { + switch (Ext) { + case SupportExt::SExt: + return getSignedHalfWidenOpcode(Opcode); + case SupportExt::ZExt: + return getUnsignedHalfWidenOpcode(Opcode); + case SupportExt::FPExt: + return getFloatHalfWidenOpcode(Opcode); + } + } + using CombineToTry = std::function( SDNode * /*Root*/, const NodeExtensionHelper & /*LHS*/, const NodeExtensionHelper & /*RHS*/, SelectionDAG &, @@ -13029,15 +13154,18 @@ struct NodeExtensionHelper { const RISCVSubtarget &Subtarget) { SupportsZExt = false; SupportsSExt = false; + SupportsFPExt = false; EnforceOneUse = true; CheckMask = true; unsigned Opc = OrigOperand.getOpcode(); switch (Opc) { case ISD::ZERO_EXTEND: - case ISD::SIGN_EXTEND: { + case ISD::SIGN_EXTEND: + case ISD::FP_EXTEND: { if (OrigOperand.getValueType().isVector()) { SupportsZExt = Opc == ISD::ZERO_EXTEND; SupportsSExt = Opc == ISD::SIGN_EXTEND; + SupportsFPExt = Opc == ISD::FP_EXTEND; SDLoc DL(Root); MVT VT = Root->getSimpleValueType(0); std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget); @@ -13045,15 +13173,28 @@ struct NodeExtensionHelper { break; } case RISCVISD::VZEXT_VL: - SupportsZExt = true; - Mask = OrigOperand.getOperand(1); - VL = OrigOperand.getOperand(2); - break; case RISCVISD::VSEXT_VL: - SupportsSExt = true; + case RISCVISD::FP_EXTEND_VL: + SupportsZExt = Opc == RISCVISD::VZEXT_VL; + SupportsSExt = Opc == RISCVISD::VSEXT_VL; + SupportsFPExt = Opc == RISCVISD::FP_EXTEND_VL; Mask = OrigOperand.getOperand(1); VL = OrigOperand.getOperand(2); break; + case ISD::SPLAT_VECTOR: { + if (OrigOperand.getOperand(0)->getOpcode() != ISD::FP_EXTEND) + break; + + MVT VT = Root->getSimpleValueType(0); + if (VT.isFloatingPoint() && + VT.getScalarSizeInBits() >= (Subtarget.hasStdExtZvfh() ? 
32 : 64)) { + SupportsFPExt = true; + SDLoc DL(Root); + std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget); + break; + } + break; + } case RISCVISD::VMV_V_X_VL: { // Historically, we didn't care about splat values not disappearing during // combines. @@ -13099,8 +13240,16 @@ struct NodeExtensionHelper { } /// Check if \p Root supports any extension folding combines. - static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG) { + static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { switch (Root->getOpcode()) { + case ISD::FADD: + case ISD::FSUB: + case ISD::FMUL: + if (Root->getValueType(0).getScalarSizeInBits() < + (Subtarget.hasStdExtZvfh() ? 32 : 64)) + return false; + [[fallthrough]]; case ISD::ADD: case ISD::SUB: case ISD::MUL: { @@ -13109,6 +13258,13 @@ struct NodeExtensionHelper { return false; return Root->getValueType(0).isScalableVector(); } + case RISCVISD::FADD_VL: + case RISCVISD::FSUB_VL: + case RISCVISD::FMUL_VL: + if (Root->getValueType(0).getScalarSizeInBits() < + (Subtarget.hasStdExtZvfh() ? 32 : 64)) + return false; + [[fallthrough]]; case RISCVISD::ADD_VL: case RISCVISD::MUL_VL: case RISCVISD::VWADD_W_VL: @@ -13116,6 +13272,8 @@ struct NodeExtensionHelper { case RISCVISD::SUB_VL: case RISCVISD::VWSUB_W_VL: case RISCVISD::VWSUBU_W_VL: + case RISCVISD::VFWADD_W_VL: + case RISCVISD::VFWSUB_W_VL: return true; default: return false; @@ -13125,8 +13283,9 @@ struct NodeExtensionHelper { /// Build a NodeExtensionHelper for \p Root.getOperand(\p OperandIdx). 
NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - assert(isSupportedRoot(Root, DAG) && "Trying to build an helper with an " - "unsupported root"); + assert(isSupportedRoot(Root, DAG, Subtarget) && + "Trying to build an helper with an " + "unsupported root"); assert(OperandIdx < 2 && "Requesting something else than LHS or RHS"); OrigOperand = Root->getOperand(OperandIdx); @@ -13138,10 +13297,15 @@ struct NodeExtensionHelper { case RISCVISD::VWADDU_W_VL: case RISCVISD::VWSUB_W_VL: case RISCVISD::VWSUBU_W_VL: + case RISCVISD::VFWADD_W_VL: + case RISCVISD::VFWSUB_W_VL: if (OperandIdx == 1) { SupportsZExt = Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL; - SupportsSExt = !SupportsZExt; + SupportsSExt = + Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL; + SupportsFPExt = + Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL; std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget); CheckMask = true; // There's no existing extension here, so we don't have to worry about @@ -13170,11 +13334,14 @@ struct NodeExtensionHelper { static std::pair getMaskAndVL(const SDNode *Root, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - assert(isSupportedRoot(Root, DAG) && "Unexpected root"); + assert(isSupportedRoot(Root, DAG, Subtarget) && "Unexpected root"); switch (Root->getOpcode()) { case ISD::ADD: case ISD::SUB: - case ISD::MUL: { + case ISD::MUL: + case ISD::FADD: + case ISD::FSUB: + case ISD::FMUL: { SDLoc DL(Root); MVT VT = Root->getSimpleValueType(0); return getDefaultScalableVLOps(VT, DL, DAG, Subtarget); @@ -13197,15 +13364,23 @@ struct NodeExtensionHelper { switch (N->getOpcode()) { case ISD::ADD: case ISD::MUL: + case ISD::FADD: + case ISD::FMUL: case RISCVISD::ADD_VL: case RISCVISD::MUL_VL: case RISCVISD::VWADD_W_VL: case RISCVISD::VWADDU_W_VL: + case RISCVISD::FADD_VL: + case RISCVISD::FMUL_VL: + case RISCVISD::VFWADD_W_VL: return true; case ISD::SUB: + case ISD::FSUB: 
case RISCVISD::SUB_VL: case RISCVISD::VWSUB_W_VL: case RISCVISD::VWSUBU_W_VL: + case RISCVISD::FSUB_VL: + case RISCVISD::VFWSUB_W_VL: return false; default: llvm_unreachable("Unexpected opcode"); @@ -13227,22 +13402,23 @@ struct NodeExtensionHelper { struct CombineResult { /// Opcode to be generated when materializing the combine. unsigned TargetOpcode; - // No value means no extension is needed. If extension is needed, the value - // indicates if it needs to be sign extended. - std::optional SExtLHS; - std::optional SExtRHS; - /// Root of the combine. - SDNode *Root; /// LHS of the TargetOpcode. NodeExtensionHelper LHS; + /// Extension of the LHS + std::optional ExtLHS; /// RHS of the TargetOpcode. NodeExtensionHelper RHS; + /// Extension of the RHS + std::optional ExtRHS; + /// Root of the combine. + SDNode *Root; - CombineResult(unsigned TargetOpcode, SDNode *Root, - const NodeExtensionHelper &LHS, std::optional SExtLHS, - const NodeExtensionHelper &RHS, std::optional SExtRHS) - : TargetOpcode(TargetOpcode), SExtLHS(SExtLHS), SExtRHS(SExtRHS), - Root(Root), LHS(LHS), RHS(RHS) {} + CombineResult( + unsigned TargetOpcode, SDNode *Root, + std::pair> LHS, + std::pair> RHS) + : TargetOpcode(TargetOpcode), LHS(LHS.first), ExtLHS(LHS.second), + RHS(RHS.first), ExtRHS(RHS.second), Root(Root) {} /// Return a value that uses TargetOpcode and that can be used to replace /// Root. 
@@ -13259,12 +13435,15 @@ struct CombineResult { case ISD::ADD: case ISD::SUB: case ISD::MUL: + case ISD::FADD: + case ISD::FSUB: + case ISD::FMUL: Merge = DAG.getUNDEF(Root->getValueType(0)); break; } return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0), - LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtLHS), - RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtRHS), + LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, ExtLHS), + RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, ExtRHS), Merge, Mask, VL); } }; @@ -13279,24 +13458,30 @@ struct CombineResult { /// /// \returns std::nullopt if the pattern doesn't match or a CombineResult that /// can be used to apply the pattern. -static std::optional -canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS, - const NodeExtensionHelper &RHS, bool AllowSExt, - bool AllowZExt, SelectionDAG &DAG, - const RISCVSubtarget &Subtarget) { - assert((AllowSExt || AllowZExt) && "Forgot to set what you want?"); +static std::optional canFoldToVWWithSameExtensionImpl( + SDNode *Root, const NodeExtensionHelper &LHS, + const NodeExtensionHelper &RHS, bool AllowSExt, bool AllowZExt, + bool AllowFPExt, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { + assert((AllowSExt || AllowZExt || AllowFPExt) && + "Forgot to set what you want?"); if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) || !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget)) return std::nullopt; if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt) - return CombineResult(NodeExtensionHelper::getSameExtensionOpcode( - Root->getOpcode(), /*IsSExt=*/false), - Root, LHS, /*SExtLHS=*/false, RHS, /*SExtRHS=*/false); + return CombineResult(NodeExtensionHelper::getFullWidenOpcode( + Root->getOpcode(), SupportExt::ZExt), + Root, {LHS, SupportExt::ZExt}, + {RHS, SupportExt::ZExt}); if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt) - return CombineResult(NodeExtensionHelper::getSameExtensionOpcode( - Root->getOpcode(), 
/*IsSExt=*/true), - Root, LHS, /*SExtLHS=*/true, RHS, - /*SExtRHS=*/true); + return CombineResult(NodeExtensionHelper::getFullWidenOpcode( + Root->getOpcode(), SupportExt::SExt), + Root, {LHS, SupportExt::SExt}, + {RHS, SupportExt::SExt}); + if (AllowFPExt && LHS.SupportsFPExt && RHS.SupportsFPExt) + return CombineResult(NodeExtensionHelper::getFullWidenOpcode( + Root->getOpcode(), SupportExt::FPExt), + Root, {LHS, SupportExt::FPExt}, + {RHS, SupportExt::FPExt}); return std::nullopt; } @@ -13311,7 +13496,8 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS, const NodeExtensionHelper &RHS, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true, - /*AllowZExt=*/true, DAG, Subtarget); + /*AllowZExt=*/true, true, DAG, + Subtarget); } /// Check if \p Root follows a pattern Root(LHS, ext(RHS)) @@ -13330,13 +13516,17 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS, // Control this behavior behind an option (AllowSplatInVW_W) for testing // purposes. 
if (RHS.SupportsZExt && (!RHS.isSplat() || AllowSplatInVW_W)) - return CombineResult( - NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/false), - Root, LHS, /*SExtLHS=*/std::nullopt, RHS, /*SExtRHS=*/false); + return CombineResult(NodeExtensionHelper::getHalfWidenOpcode( + Root->getOpcode(), SupportExt::ZExt), + Root, {LHS, std::nullopt}, {RHS, SupportExt::ZExt}); if (RHS.SupportsSExt && (!RHS.isSplat() || AllowSplatInVW_W)) - return CombineResult( - NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/true), - Root, LHS, /*SExtLHS=*/std::nullopt, RHS, /*SExtRHS=*/true); + return CombineResult(NodeExtensionHelper::getHalfWidenOpcode( + Root->getOpcode(), SupportExt::SExt), + Root, {LHS, std::nullopt}, {RHS, SupportExt::SExt}); + if (RHS.SupportsFPExt && (!RHS.isSplat() || AllowSplatInVW_W)) + return CombineResult(NodeExtensionHelper::getHalfWidenOpcode( + Root->getOpcode(), SupportExt::FPExt), + Root, {LHS, std::nullopt}, {RHS, SupportExt::FPExt}); return std::nullopt; } @@ -13349,7 +13539,8 @@ canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS, const NodeExtensionHelper &RHS, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true, - /*AllowZExt=*/false, DAG, Subtarget); + /*AllowZExt=*/false, false, DAG, + Subtarget); } /// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS)) @@ -13361,7 +13552,17 @@ canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS, const NodeExtensionHelper &RHS, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false, - /*AllowZExt=*/true, DAG, Subtarget); + /*AllowZExt=*/true, false, DAG, + Subtarget); +} + +static std::optional +canFoldToVFWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS, + const NodeExtensionHelper &RHS, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + return canFoldToVWWithSameExtensionImpl(Root, LHS, 
RHS, /*AllowSExt=*/false, + /*AllowZExt=*/false, true, DAG, + Subtarget); } /// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS)) @@ -13378,8 +13579,9 @@ canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS, if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) || !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget)) return std::nullopt; - return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()), - Root, LHS, /*SExtLHS=*/true, RHS, /*SExtRHS=*/false); + return CombineResult( + NodeExtensionHelper::getSignedUnsignedWidenOpcode(Root->getOpcode()), + Root, {LHS, SupportExt::SExt}, {RHS, SupportExt::ZExt}); } SmallVector @@ -13388,13 +13590,21 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) { switch (Root->getOpcode()) { case ISD::ADD: case ISD::SUB: + case ISD::FADD: + case ISD::FSUB: case RISCVISD::ADD_VL: case RISCVISD::SUB_VL: + case RISCVISD::FADD_VL: + case RISCVISD::FSUB_VL: // add|sub -> vwadd(u)|vwsub(u) Strategies.push_back(canFoldToVWWithSameExtension); // add|sub -> vwadd(u)_w|vwsub(u)_w Strategies.push_back(canFoldToVW_W); break; + case ISD::FMUL: + case RISCVISD::FMUL_VL: + Strategies.push_back(canFoldToVWWithSameExtension); + break; case ISD::MUL: case RISCVISD::MUL_VL: // mul -> vwmul(u) @@ -13412,6 +13622,10 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) { // vwaddu_w|vwsubu_w -> vwaddu|vwsubu Strategies.push_back(canFoldToVWWithZEXT); break; + case RISCVISD::VFWADD_W_VL: + case RISCVISD::VFWSUB_W_VL: + Strategies.push_back(canFoldToVFWWithFPEXT); + break; default: llvm_unreachable("Unexpected opcode"); } @@ -13431,7 +13645,7 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N, const RISCVSubtarget &Subtarget) { SelectionDAG &DAG = DCI.DAG; - if (!NodeExtensionHelper::isSupportedRoot(N, DAG)) + if (!NodeExtensionHelper::isSupportedRoot(N, DAG, Subtarget)) return SDValue(); SmallVector Worklist; @@ -13442,7 +13656,7 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N, while 
(!Worklist.empty()) { SDNode *Root = Worklist.pop_back_val(); - if (!NodeExtensionHelper::isSupportedRoot(Root, DAG)) + if (!NodeExtensionHelper::isSupportedRoot(Root, DAG, Subtarget)) return SDValue(); NodeExtensionHelper LHS(N, 0, DAG, Subtarget); @@ -13481,9 +13695,9 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N, // All the inputs that are extended need to be folded, otherwise // we would be leaving the old input (since it is may still be used), // and the new one. - if (Res->SExtLHS.has_value()) + if (Res->ExtLHS.has_value()) AppendUsersIfNeeded(LHS); - if (Res->SExtRHS.has_value()) + if (Res->ExtRHS.has_value()) AppendUsersIfNeeded(RHS); break; } @@ -13990,7 +14204,8 @@ static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG, N->getOperand(2), Mask, VL); } -static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG, +static SDValue performVFMUL_VLCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, const RISCVSubtarget &Subtarget) { if (N->getValueType(0).isScalableVector() && N->getValueType(0).getVectorElementType() == MVT::f32 && @@ -14002,36 +14217,11 @@ static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG, // FIXME: Ignore strict opcodes for now. assert(!N->isTargetStrictFPOpcode() && "Unexpected opcode"); - // Try to form widening multiply. - SDValue Op0 = N->getOperand(0); - SDValue Op1 = N->getOperand(1); - SDValue Merge = N->getOperand(2); - SDValue Mask = N->getOperand(3); - SDValue VL = N->getOperand(4); - - if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL || - Op1.getOpcode() != RISCVISD::FP_EXTEND_VL) - return SDValue(); - - // TODO: Refactor to handle more complex cases similar to - // combineBinOp_VLToVWBinOp_VL. - if ((!Op0.hasOneUse() || !Op1.hasOneUse()) && - (Op0 != Op1 || !Op0->hasNUsesOfValue(2, 0))) - return SDValue(); - - // Check the mask and VL are the same. 
- if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL || - Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL) - return SDValue(); - - Op0 = Op0.getOperand(0); - Op1 = Op1.getOperand(0); - - return DAG.getNode(RISCVISD::VFWMUL_VL, SDLoc(N), N->getValueType(0), Op0, - Op1, Merge, Mask, VL); + return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget); } -static SDValue performFADDSUB_VLCombine(SDNode *N, SelectionDAG &DAG, +static SDValue performFADDSUB_VLCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, const RISCVSubtarget &Subtarget) { if (N->getValueType(0).isScalableVector() && N->getValueType(0).getVectorElementType() == MVT::f32 && @@ -14040,55 +14230,7 @@ static SDValue performFADDSUB_VLCombine(SDNode *N, SelectionDAG &DAG, return SDValue(); } - SDValue Op0 = N->getOperand(0); - SDValue Op1 = N->getOperand(1); - SDValue Merge = N->getOperand(2); - SDValue Mask = N->getOperand(3); - SDValue VL = N->getOperand(4); - - bool IsAdd = N->getOpcode() == RISCVISD::FADD_VL; - - // Look for foldable FP_EXTENDS. - bool Op0IsExtend = - Op0.getOpcode() == RISCVISD::FP_EXTEND_VL && - (Op0.hasOneUse() || (Op0 == Op1 && Op0->hasNUsesOfValue(2, 0))); - bool Op1IsExtend = - (Op0 == Op1 && Op0IsExtend) || - (Op1.getOpcode() == RISCVISD::FP_EXTEND_VL && Op1.hasOneUse()); - - // Check the mask and VL. - if (Op0IsExtend && (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL)) - Op0IsExtend = false; - if (Op1IsExtend && (Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)) - Op1IsExtend = false; - - // Canonicalize. - if (!Op1IsExtend) { - // Sub requires at least operand 1 to be an extend. - if (!IsAdd) - return SDValue(); - - // Add is commutable, if the other operand is foldable, swap them. - if (!Op0IsExtend) - return SDValue(); - - std::swap(Op0, Op1); - std::swap(Op0IsExtend, Op1IsExtend); - } - - // Op1 is a foldable extend. Op0 might be foldable. 
- Op1 = Op1.getOperand(0); - if (Op0IsExtend) - Op0 = Op0.getOperand(0); - - unsigned Opc; - if (IsAdd) - Opc = Op0IsExtend ? RISCVISD::VFWADD_VL : RISCVISD::VFWADD_W_VL; - else - Opc = Op0IsExtend ? RISCVISD::VFWSUB_VL : RISCVISD::VFWSUB_W_VL; - - return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Op0, Op1, Merge, Mask, - VL); + return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget); } static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG, @@ -15109,7 +15251,13 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget)) return V; return performMULCombine(N, DAG); + case ISD::FSUB: + case ISD::FMUL: + return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget); case ISD::FADD: + if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget)) + return V; + [[fallthrough]]; case ISD::UMAX: case ISD::UMIN: case ISD::SMAX: @@ -15604,10 +15752,10 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, case RISCVISD::STRICT_VFNMSUB_VL: return performVFMADD_VLCombine(N, DAG, Subtarget); case RISCVISD::FMUL_VL: - return performVFMUL_VLCombine(N, DAG, Subtarget); + return performVFMUL_VLCombine(N, DCI, Subtarget); case RISCVISD::FADD_VL: case RISCVISD::FSUB_VL: - return performFADDSUB_VLCombine(N, DAG, Subtarget); + return performFADDSUB_VLCombine(N, DCI, Subtarget); case ISD::LOAD: case ISD::STORE: { if (DCI.isAfterLegalizeDAG()) diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll index c9dc75e18774f..dd3a50cfd7737 100644 --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll @@ -396,12 +396,10 @@ define <32 x double> @vfwadd_vf_v32f32(ptr %x, float %y) { ; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma ; CHECK-NEXT: vle32.v v24, (a0) ; CHECK-NEXT: vsetivli zero, 16, e32, m8, ta, ma -; CHECK-NEXT: vslidedown.vi v0, v24, 16 +; CHECK-NEXT: vslidedown.vi v8, v24, 16 ; 
CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma -; CHECK-NEXT: vfmv.v.f v16, fa0 -; CHECK-NEXT: vfwcvt.f.f.v v8, v16 -; CHECK-NEXT: vfwadd.wv v16, v8, v0 -; CHECK-NEXT: vfwadd.wv v8, v8, v24 +; CHECK-NEXT: vfwadd.vf v16, v8, fa0 +; CHECK-NEXT: vfwadd.vf v8, v24, fa0 ; CHECK-NEXT: ret %a = load <32 x float>, ptr %x %b = insertelement <32 x float> poison, float %y, i32 0 diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll index 8ad858d4c7659..7eaa1856ce221 100644 --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll @@ -394,18 +394,12 @@ define <32 x double> @vfwmul_vf_v32f32(ptr %x, float %y) { ; CHECK: # %bb.0: ; CHECK-NEXT: li a1, 32 ; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma -; CHECK-NEXT: vle32.v v16, (a0) -; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma -; CHECK-NEXT: vfwcvt.f.f.v v8, v16 +; CHECK-NEXT: vle32.v v24, (a0) ; CHECK-NEXT: vsetivli zero, 16, e32, m8, ta, ma -; CHECK-NEXT: vslidedown.vi v16, v16, 16 +; CHECK-NEXT: vslidedown.vi v8, v24, 16 ; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma -; CHECK-NEXT: vfwcvt.f.f.v v24, v16 -; CHECK-NEXT: vfmv.v.f v16, fa0 -; CHECK-NEXT: vfwcvt.f.f.v v0, v16 -; CHECK-NEXT: vsetvli zero, zero, e64, m8, ta, ma -; CHECK-NEXT: vfmul.vv v16, v24, v0 -; CHECK-NEXT: vfmul.vv v8, v8, v0 +; CHECK-NEXT: vfwmul.vf v16, v8, fa0 +; CHECK-NEXT: vfwmul.vf v8, v24, fa0 ; CHECK-NEXT: ret %a = load <32 x float>, ptr %x %b = insertelement <32 x float> poison, float %y, i32 0 diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll index d22781d6a97ac..8cf7c5f175865 100644 --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll @@ -394,18 +394,12 @@ define <32 x double> @vfwsub_vf_v32f32(ptr %x, float %y) { ; CHECK: # %bb.0: ; CHECK-NEXT: li a1, 32 ; CHECK-NEXT: vsetvli zero, a1, e32, 
m8, ta, ma -; CHECK-NEXT: vle32.v v16, (a0) -; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma -; CHECK-NEXT: vfwcvt.f.f.v v8, v16 +; CHECK-NEXT: vle32.v v24, (a0) ; CHECK-NEXT: vsetivli zero, 16, e32, m8, ta, ma -; CHECK-NEXT: vslidedown.vi v16, v16, 16 +; CHECK-NEXT: vslidedown.vi v8, v24, 16 ; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma -; CHECK-NEXT: vfwcvt.f.f.v v24, v16 -; CHECK-NEXT: vfmv.v.f v16, fa0 -; CHECK-NEXT: vfwcvt.f.f.v v0, v16 -; CHECK-NEXT: vsetvli zero, zero, e64, m8, ta, ma -; CHECK-NEXT: vfsub.vv v16, v24, v0 -; CHECK-NEXT: vfsub.vv v8, v8, v0 +; CHECK-NEXT: vfwsub.vf v16, v8, fa0 +; CHECK-NEXT: vfwsub.vf v8, v24, fa0 ; CHECK-NEXT: ret %a = load <32 x float>, ptr %x %b = insertelement <32 x float> poison, float %y, i32 0