diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index bd5b1a879f32b..72b2e5e78c299 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -7739,17 +7739,41 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
     Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
   }
 
-  MVT OrigContainerVT = ContainerVT;
-  SDValue OrigVec = Vec;
   // If we know the index we're going to insert at, we can shrink Vec so that
   // we're performing the scalar inserts and slideup on a smaller LMUL.
-  if (auto *CIdx = dyn_cast<ConstantSDNode>(Idx)) {
-    if (auto ShrunkVT = getSmallestVTForIndex(ContainerVT, CIdx->getZExtValue(),
+  MVT OrigContainerVT = ContainerVT;
+  SDValue OrigVec = Vec;
+  SDValue AlignedIdx;
+  if (auto *IdxC = dyn_cast<ConstantSDNode>(Idx)) {
+    const unsigned OrigIdx = IdxC->getZExtValue();
+    // Do we know an upper bound on LMUL?
+    if (auto ShrunkVT = getSmallestVTForIndex(ContainerVT, OrigIdx,
                                               DL, DAG, Subtarget)) {
       ContainerVT = *ShrunkVT;
+      AlignedIdx = DAG.getVectorIdxConstant(0, DL);
+    }
+
+    // If we're compiling for an exact VLEN value, we can always perform
+    // the insert in m1 as we can determine the register corresponding to
+    // the index in the register group.
+    const unsigned MinVLen = Subtarget.getRealMinVLen();
+    const unsigned MaxVLen = Subtarget.getRealMaxVLen();
+    const MVT M1VT = getLMUL1VT(ContainerVT);
+    if (MinVLen == MaxVLen && ContainerVT.bitsGT(M1VT)) {
+      EVT ElemVT = VecVT.getVectorElementType();
+      unsigned ElemsPerVReg = MinVLen / ElemVT.getFixedSizeInBits();
+      unsigned RemIdx = OrigIdx % ElemsPerVReg;
+      unsigned SubRegIdx = OrigIdx / ElemsPerVReg;
+      unsigned ExtractIdx =
+          SubRegIdx * M1VT.getVectorElementCount().getKnownMinValue();
+      AlignedIdx = DAG.getVectorIdxConstant(ExtractIdx, DL);
+      Idx = DAG.getVectorIdxConstant(RemIdx, DL);
+      ContainerVT = M1VT;
+    }
+
+    if (AlignedIdx)
       Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec,
-                        DAG.getVectorIdxConstant(0, DL));
-    }
+                        AlignedIdx);
   }
 
   MVT XLenVT = Subtarget.getXLenVT();
@@ -7779,9 +7803,9 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
         Val = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Val);
       Vec = DAG.getNode(Opc, DL, ContainerVT, Vec, Val, VL);
 
-      if (ContainerVT != OrigContainerVT)
+      if (AlignedIdx)
         Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigContainerVT, OrigVec,
-                          Vec, DAG.getVectorIdxConstant(0, DL));
+                          Vec, AlignedIdx);
       if (!VecVT.isFixedLengthVector())
         return Vec;
       return convertFromScalableVector(VecVT, Vec, DAG, Subtarget);
@@ -7814,10 +7838,10 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
     // Bitcast back to the right container type.
     ValInVec = DAG.getBitcast(ContainerVT, ValInVec);
 
-    if (ContainerVT != OrigContainerVT)
+    if (AlignedIdx)
       ValInVec =
           DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigContainerVT, OrigVec,
-                      ValInVec, DAG.getVectorIdxConstant(0, DL));
+                      ValInVec, AlignedIdx);
     if (!VecVT.isFixedLengthVector())
       return ValInVec;
     return convertFromScalableVector(VecVT, ValInVec, DAG, Subtarget);
@@ -7849,9 +7873,9 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
   SDValue Slideup = getVSlideup(DAG, Subtarget, DL, ContainerVT, Vec, ValInVec,
                                 Idx, Mask, InsertVL, Policy);
 
-  if (ContainerVT != OrigContainerVT)
+  if (AlignedIdx)
     Slideup = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigContainerVT, OrigVec,
-                          Slideup, DAG.getVectorIdxConstant(0, DL));
+                          Slideup, AlignedIdx);
   if (!VecVT.isFixedLengthVector())
     return Slideup;
   return convertFromScalableVector(VecVT, Slideup, DAG, Subtarget);
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll
index de5c4fbc08764..a3f41fd842222 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll
@@ -614,9 +614,8 @@ define <16 x i32> @insertelt_c3_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_ra
 define <16 x i32> @insertelt_c12_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_range(2,2) {
 ; CHECK-LABEL: insertelt_c12_v16xi32_exact:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 13, e32, m4, tu, ma
-; CHECK-NEXT:    vmv.s.x v12, a0
-; CHECK-NEXT:    vslideup.vi v8, v12, 12
+; CHECK-NEXT:    vsetivli zero, 16, e32, m1, tu, ma
+; CHECK-NEXT:    vmv.s.x v11, a0
 ; CHECK-NEXT:    ret
   %v = insertelement <16 x i32> %vin, i32 %a, i32 12
   ret <16 x i32> %v
@@ -625,9 +624,9 @@ define <16 x i32> @insertelt_c12_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_r
 define <16 x i32> @insertelt_c13_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_range(2,2) {
 ; CHECK-LABEL: insertelt_c13_v16xi32_exact:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 14, e32, m4, tu, ma
+; CHECK-NEXT:    vsetivli zero, 2, e32, m1, tu, ma
 ; CHECK-NEXT:    vmv.s.x v12, a0
-; CHECK-NEXT:    vslideup.vi v8, v12, 13
+; CHECK-NEXT:    vslideup.vi v11, v12, 1
 ; CHECK-NEXT:    ret
   %v = insertelement <16 x i32> %vin, i32 %a, i32 13
   ret <16 x i32> %v
@@ -636,9 +635,9 @@ define <16 x i32> @insertelt_c13_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_r
 define <16 x i32> @insertelt_c14_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_range(2,2) {
 ; CHECK-LABEL: insertelt_c14_v16xi32_exact:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 15, e32, m4, tu, ma
+; CHECK-NEXT:    vsetivli zero, 3, e32, m1, tu, ma
 ; CHECK-NEXT:    vmv.s.x v12, a0
-; CHECK-NEXT:    vslideup.vi v8, v12, 14
+; CHECK-NEXT:    vslideup.vi v11, v12, 2
 ; CHECK-NEXT:    ret
   %v = insertelement <16 x i32> %vin, i32 %a, i32 14
   ret <16 x i32> %v
@@ -647,9 +646,9 @@ define <16 x i32> @insertelt_c14_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_r
 define <16 x i32> @insertelt_c15_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_range(2,2) {
 ; CHECK-LABEL: insertelt_c15_v16xi32_exact:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, tu, ma
 ; CHECK-NEXT:    vmv.s.x v12, a0
-; CHECK-NEXT:    vslideup.vi v8, v12, 15
+; CHECK-NEXT:    vslideup.vi v11, v12, 3
 ; CHECK-NEXT:    ret
   %v = insertelement <16 x i32> %vin, i32 %a, i32 15
   ret <16 x i32> %v
@@ -658,18 +657,15 @@ define <16 x i32> @insertelt_c15_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_r
 define <8 x i64> @insertelt_c4_v8xi64_exact(<8 x i64> %vin, i64 %a) vscale_range(2,2) {
 ; RV32-LABEL: insertelt_c4_v8xi64_exact:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    vsetivli zero, 2, e32, m4, ta, ma
-; RV32-NEXT:    vslide1down.vx v12, v8, a0
-; RV32-NEXT:    vslide1down.vx v12, v12, a1
-; RV32-NEXT:    vsetivli zero, 5, e64, m4, tu, ma
-; RV32-NEXT:    vslideup.vi v8, v12, 4
+; RV32-NEXT:    vsetivli zero, 2, e32, m1, tu, ma
+; RV32-NEXT:    vslide1down.vx v10, v10, a0
+; RV32-NEXT:    vslide1down.vx v10, v10, a1
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: insertelt_c4_v8xi64_exact:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 5, e64, m4, tu, ma
-; RV64-NEXT:    vmv.s.x v12, a0
-; RV64-NEXT:    vslideup.vi v8, v12, 4
+; RV64-NEXT:    vsetivli zero, 8, e64, m1, tu, ma
+; RV64-NEXT:    vmv.s.x v10, a0
 ; RV64-NEXT:    ret
   %v = insertelement <8 x i64> %vin, i64 %a, i32 4
   ret <8 x i64> %v
@@ -678,18 +674,18 @@ define <8 x i64> @insertelt_c4_v8xi64_exact(<8 x i64> %vin, i64 %a) vscale_range
 define <8 x i64> @insertelt_c5_v8xi64_exact(<8 x i64> %vin, i64 %a) vscale_range(2,2) {
 ; RV32-LABEL: insertelt_c5_v8xi64_exact:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    vsetivli zero, 2, e32, m4, ta, ma
+; RV32-NEXT:    vsetivli zero, 2, e32, m1, ta, ma
 ; RV32-NEXT:    vslide1down.vx v12, v8, a0
 ; RV32-NEXT:    vslide1down.vx v12, v12, a1
-; RV32-NEXT:    vsetivli zero, 6, e64, m4, tu, ma
-; RV32-NEXT:    vslideup.vi v8, v12, 5
+; RV32-NEXT:    vsetivli zero, 2, e64, m1, tu, ma
+; RV32-NEXT:    vslideup.vi v10, v12, 1
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: insertelt_c5_v8xi64_exact:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 6, e64, m4, tu, ma
+; RV64-NEXT:    vsetivli zero, 2, e64, m1, tu, ma
 ; RV64-NEXT:    vmv.s.x v12, a0
-; RV64-NEXT:    vslideup.vi v8, v12, 5
+; RV64-NEXT:    vslideup.vi v10, v12, 1
 ; RV64-NEXT:    ret
   %v = insertelement <8 x i64> %vin, i64 %a, i32 5
   ret <8 x i64> %v
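For reference, the register/element split that the new lowering performs (and that the updated checks reflect) can be worked through in isolation. The following is a minimal standalone C++ sketch, not LLVM code: `M1Insert` and `splitIndexForExactVLEN` are made-up names, and it assumes VLEN=128 (what `vscale_range(2,2)` implies) with the source register group starting at v8.

```cpp
#include <cassert>
#include <cstdio>

// Hypothetical standalone model of the index math above: with an exact VLEN,
// a constant-index insert into an LMUL>1 register group can be done in a
// single m1 register of the group.
struct M1Insert {
  unsigned SubIdx;  // which vreg of the group holds the element (v8 + SubIdx)
  unsigned ElemIdx; // element index within that single vreg (RemIdx above)
};

static M1Insert splitIndexForExactVLEN(unsigned OrigIdx, unsigned EltBits,
                                       unsigned VLen) {
  assert(VLen % EltBits == 0 && "element size must divide VLEN");
  unsigned ElemsPerVReg = VLen / EltBits; // elements held by one m1 register
  return {OrigIdx / ElemsPerVReg, OrigIdx % ElemsPerVReg};
}

int main() {
  // <16 x i32>, VLEN=128: index 12 lands at element 0 of the 4th register
  // (v8 + 3 == v11), so no vslideup is needed at all.
  M1Insert I = splitIndexForExactVLEN(/*OrigIdx=*/12, /*EltBits=*/32,
                                      /*VLen=*/128);
  std::printf("idx 12 -> v%u, elem %u\n", 8 + I.SubIdx, I.ElemIdx); // v11, 0

  // Index 13 lands at element 1 of v11, matching "vslideup.vi v11, v12, 1"
  // in insertelt_c13_v16xi32_exact.
  I = splitIndexForExactVLEN(13, 32, 128);
  std::printf("idx 13 -> v%u, elem %u\n", 8 + I.SubIdx, I.ElemIdx); // v11, 1

  // <8 x i64>, VLEN=128: index 5 -> element 1 of v10 (ElemsPerVReg == 2).
  I = splitIndexForExactVLEN(5, 64, 128);
  std::printf("idx 5  -> v%u, elem %u\n", 8 + I.SubIdx, I.ElemIdx); // v10, 1
  return 0;
}
```

When the remainder is zero (the index-12 case), the insert degenerates to a single `vmv.s.x` into the right m1 register with no `vslideup`, which is exactly what the updated `insertelt_c12_v16xi32_exact` checks show.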