diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index b0150fb6367a5..09c6218b3dfd9 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -618,6 +618,8 @@ namespace { SDValue CombineConsecutiveLoads(SDNode *N, EVT VT); SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI); + SDValue foldPartialReduceMLAMulOp(SDNode *N); + SDValue foldPartialReduceAdd(SDNode *N); SDValue CombineExtLoad(SDNode *N); SDValue CombineZExtLogicopShiftLoad(SDNode *N); @@ -12601,12 +12603,20 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) { + if (SDValue Res = foldPartialReduceMLAMulOp(N)) + return Res; + if (SDValue Res = foldPartialReduceAdd(N)) + return Res; + return SDValue(); +} + // partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1)) // -> partial_reduce_*mla(acc, a, b) // // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1)) // -> partial_reduce_*mla(acc, x, C) -SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) { +SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { SDLoc DL(N); auto *Context = DAG.getContext(); SDValue Acc = N->getOperand(0); @@ -12672,6 +12682,43 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) { RHSExtOp); } +// Fold an extend into partial.reduce.{u,s}mla (umla/smla below) with splat(1): +// umla(acc, zext(op), splat(1)) -> umla(acc, op, splat(trunc(1))) +// smla(acc, sext(op), splat(1)) -> smla(acc, op, splat(trunc(1))) +// The extend kind picks the new opcode; see the signedness check in the body. +SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) { + SDLoc DL(N); + SDValue Acc = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + SDValue Op2 = N->getOperand(2); + + APInt ConstantOne; + if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) || + !ConstantOne.isOne()) // Multiplier operand must be a splat of 1. + return SDValue(); + + unsigned Op1Opcode = Op1.getOpcode(); + if (!ISD::isExtOpcode(Op1Opcode)) // Input must be an extend. + return SDValue(); 
+ + SDValue UnextOp1 = Op1.getOperand(0); + EVT UnextOp1VT = UnextOp1.getValueType(); + if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT)) + return SDValue(); + + bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND; + bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA; + EVT AccElemVT = Acc.getValueType().getVectorElementType(); + if (Op1IsSigned != NodeIsSigned && + Op1.getValueType().getVectorElementType() != AccElemVT) + return SDValue(); + + unsigned NewOpcode = + Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA; + return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1, + DAG.getConstant(1, DL, UnextOp1VT)); +} + SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) { auto *SLD = cast(N); EVT EltVT = SLD->getValueType(0).getVectorElementType(); diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll index 5326bccbbc3d5..67be3f58e8a24 100644 --- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll +++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll @@ -516,16 +516,8 @@ define @udot_no_bin_op( %acc, %a to %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( %acc, %a.ext) @@ -541,16 +533,8 @@ define @sdot_no_bin_op( %acc, %a to %partial.reduce = tail call @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32( %acc, %a.ext) @@ -566,16 +550,8 @@ define @udot_no_bin_op_wide( %acc, %a to @@ -592,16 +568,8 @@ define @sdot_no_bin_op_wide( %acc, %a to