
Commit 02cbae4

[RISCV] Work on subreg for insert_vector_elt when vlen is known (#72666) (#73680)
If we have a constant index and a known vlen, then we can identify which register of a register group is being accessed. Given this, we can reuse the (slightly generalized) existing handling for working on sub-register groups, so every constant-index insert with a known vlen becomes an m1 operation.

One bit of weirdness to highlight and explain: the existing code uses the VL from the original vector type, not the inner vector type. This is correct because the inner register group must be smaller than the original (possibly fixed-length) vector type.

Overall, this seems to be a reasonable codegen tradeoff, as it biases us towards immediate AVLs and avoids the vsetvli form that clobbers a GPR for no real purpose. The downside is that for large fixed-length vectors we end up materializing an immediate in a register for little value. We should probably generalize this idea and try to optimize the large fixed-length vector case, but that can be done as separate work.
1 parent c846f8b commit 02cbae4
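To make the index split concrete, here is a small standalone sketch. It is not part of the patch and not the LLVM API; the helper name is hypothetical, and the constants (exact VLEN of 128 bits, i32 elements) are taken from the tests' vscale_range(2,2) attribute.

// Standalone sketch (assumed helper, not the patch's code) of the split
// computed in lowerINSERT_VECTOR_ELT for a constant insert index.
#include <cstdio>

struct SubRegSplit {
  unsigned SubRegIdx; // which m1 register within the register group
  unsigned RemIdx;    // element offset within that m1 register
};

static SubRegSplit splitConstantIndex(unsigned OrigIdx, unsigned VLenBits,
                                      unsigned ElemBits) {
  // With an exact VLEN, each vector register holds a known number of elements.
  unsigned ElemsPerVReg = VLenBits / ElemBits;
  return {OrigIdx / ElemsPerVReg, OrigIdx % ElemsPerVReg};
}

int main() {
  // insertelement <16 x i32> %vin, i32 %a, i32 13, register group in v8..v11,
  // VLEN = 128 (vscale_range(2,2)): 128 / 32 = 4 elements per register.
  SubRegSplit S = splitConstantIndex(13, 128, 32);
  std::printf("subreg %u, remainder %u\n", S.SubRegIdx, S.RemIdx); // 3, 1
  return 0;
}

For index 13 this yields sub-register 3 and remainder 1, i.e. the insert is performed in v11 (v8 + 3) with a vslideup offset of 1, matching the updated insertelt_c13_v16xi32_exact check lines below. The insertelt_c12 case also shows the VL note above: its m1 vsetivli still uses the original vector's VL of 16.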

File tree

2 files changed: +54 -34 lines changed


llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 36 additions & 12 deletions
@@ -7739,17 +7739,41 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
     Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
   }
 
-  MVT OrigContainerVT = ContainerVT;
-  SDValue OrigVec = Vec;
   // If we know the index we're going to insert at, we can shrink Vec so that
   // we're performing the scalar inserts and slideup on a smaller LMUL.
-  if (auto *CIdx = dyn_cast<ConstantSDNode>(Idx)) {
-    if (auto ShrunkVT = getSmallestVTForIndex(ContainerVT, CIdx->getZExtValue(),
+  MVT OrigContainerVT = ContainerVT;
+  SDValue OrigVec = Vec;
+  SDValue AlignedIdx;
+  if (auto *IdxC = dyn_cast<ConstantSDNode>(Idx)) {
+    const unsigned OrigIdx = IdxC->getZExtValue();
+    // Do we know an upper bound on LMUL?
+    if (auto ShrunkVT = getSmallestVTForIndex(ContainerVT, OrigIdx,
                                               DL, DAG, Subtarget)) {
       ContainerVT = *ShrunkVT;
+      AlignedIdx = DAG.getVectorIdxConstant(0, DL);
+    }
+
+    // If we're compiling for an exact VLEN value, we can always perform
+    // the insert in m1 as we can determine the register corresponding to
+    // the index in the register group.
+    const unsigned MinVLen = Subtarget.getRealMinVLen();
+    const unsigned MaxVLen = Subtarget.getRealMaxVLen();
+    const MVT M1VT = getLMUL1VT(ContainerVT);
+    if (MinVLen == MaxVLen && ContainerVT.bitsGT(M1VT)) {
+      EVT ElemVT = VecVT.getVectorElementType();
+      unsigned ElemsPerVReg = MinVLen / ElemVT.getFixedSizeInBits();
+      unsigned RemIdx = OrigIdx % ElemsPerVReg;
+      unsigned SubRegIdx = OrigIdx / ElemsPerVReg;
+      unsigned ExtractIdx =
+          SubRegIdx * M1VT.getVectorElementCount().getKnownMinValue();
+      AlignedIdx = DAG.getVectorIdxConstant(ExtractIdx, DL);
+      Idx = DAG.getVectorIdxConstant(RemIdx, DL);
+      ContainerVT = M1VT;
+    }
+
+    if (AlignedIdx)
       Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec,
-                        DAG.getVectorIdxConstant(0, DL));
-    }
+                        AlignedIdx);
   }
 
   MVT XLenVT = Subtarget.getXLenVT();
@@ -7779,9 +7803,9 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
         Val = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Val);
       Vec = DAG.getNode(Opc, DL, ContainerVT, Vec, Val, VL);
 
-      if (ContainerVT != OrigContainerVT)
+      if (AlignedIdx)
         Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigContainerVT, OrigVec,
-                          Vec, DAG.getVectorIdxConstant(0, DL));
+                          Vec, AlignedIdx);
       if (!VecVT.isFixedLengthVector())
         return Vec;
       return convertFromScalableVector(VecVT, Vec, DAG, Subtarget);
@@ -7814,10 +7838,10 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
     // Bitcast back to the right container type.
     ValInVec = DAG.getBitcast(ContainerVT, ValInVec);
 
-    if (ContainerVT != OrigContainerVT)
+    if (AlignedIdx)
       ValInVec =
           DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigContainerVT, OrigVec,
-                      ValInVec, DAG.getVectorIdxConstant(0, DL));
+                      ValInVec, AlignedIdx);
     if (!VecVT.isFixedLengthVector())
       return ValInVec;
     return convertFromScalableVector(VecVT, ValInVec, DAG, Subtarget);
@@ -7849,9 +7873,9 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
   SDValue Slideup = getVSlideup(DAG, Subtarget, DL, ContainerVT, Vec, ValInVec,
                                 Idx, Mask, InsertVL, Policy);
 
-  if (ContainerVT != OrigContainerVT)
+  if (AlignedIdx)
     Slideup = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigContainerVT, OrigVec,
-                          Slideup, DAG.getVectorIdxConstant(0, DL));
+                          Slideup, AlignedIdx);
   if (!VecVT.isFixedLengthVector())
     return Slideup;
   return convertFromScalableVector(VecVT, Slideup, DAG, Subtarget);

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll

Lines changed: 18 additions & 22 deletions
@@ -614,9 +614,8 @@ define <16 x i32> @insertelt_c3_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_ra
 define <16 x i32> @insertelt_c12_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_range(2,2) {
 ; CHECK-LABEL: insertelt_c12_v16xi32_exact:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 13, e32, m4, tu, ma
-; CHECK-NEXT:    vmv.s.x v12, a0
-; CHECK-NEXT:    vslideup.vi v8, v12, 12
+; CHECK-NEXT:    vsetivli zero, 16, e32, m1, tu, ma
+; CHECK-NEXT:    vmv.s.x v11, a0
 ; CHECK-NEXT:    ret
   %v = insertelement <16 x i32> %vin, i32 %a, i32 12
   ret <16 x i32> %v
@@ -625,9 +624,9 @@ define <16 x i32> @insertelt_c12_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_r
 define <16 x i32> @insertelt_c13_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_range(2,2) {
 ; CHECK-LABEL: insertelt_c13_v16xi32_exact:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 14, e32, m4, tu, ma
+; CHECK-NEXT:    vsetivli zero, 2, e32, m1, tu, ma
 ; CHECK-NEXT:    vmv.s.x v12, a0
-; CHECK-NEXT:    vslideup.vi v8, v12, 13
+; CHECK-NEXT:    vslideup.vi v11, v12, 1
 ; CHECK-NEXT:    ret
   %v = insertelement <16 x i32> %vin, i32 %a, i32 13
   ret <16 x i32> %v
@@ -636,9 +635,9 @@ define <16 x i32> @insertelt_c13_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_r
 define <16 x i32> @insertelt_c14_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_range(2,2) {
 ; CHECK-LABEL: insertelt_c14_v16xi32_exact:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 15, e32, m4, tu, ma
+; CHECK-NEXT:    vsetivli zero, 3, e32, m1, tu, ma
 ; CHECK-NEXT:    vmv.s.x v12, a0
-; CHECK-NEXT:    vslideup.vi v8, v12, 14
+; CHECK-NEXT:    vslideup.vi v11, v12, 2
 ; CHECK-NEXT:    ret
   %v = insertelement <16 x i32> %vin, i32 %a, i32 14
   ret <16 x i32> %v
@@ -647,9 +646,9 @@ define <16 x i32> @insertelt_c14_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_r
 define <16 x i32> @insertelt_c15_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_range(2,2) {
 ; CHECK-LABEL: insertelt_c15_v16xi32_exact:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, tu, ma
 ; CHECK-NEXT:    vmv.s.x v12, a0
-; CHECK-NEXT:    vslideup.vi v8, v12, 15
+; CHECK-NEXT:    vslideup.vi v11, v12, 3
 ; CHECK-NEXT:    ret
   %v = insertelement <16 x i32> %vin, i32 %a, i32 15
   ret <16 x i32> %v
@@ -658,18 +657,15 @@ define <16 x i32> @insertelt_c15_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_r
 define <8 x i64> @insertelt_c4_v8xi64_exact(<8 x i64> %vin, i64 %a) vscale_range(2,2) {
 ; RV32-LABEL: insertelt_c4_v8xi64_exact:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    vsetivli zero, 2, e32, m4, ta, ma
-; RV32-NEXT:    vslide1down.vx v12, v8, a0
-; RV32-NEXT:    vslide1down.vx v12, v12, a1
-; RV32-NEXT:    vsetivli zero, 5, e64, m4, tu, ma
-; RV32-NEXT:    vslideup.vi v8, v12, 4
+; RV32-NEXT:    vsetivli zero, 2, e32, m1, tu, ma
+; RV32-NEXT:    vslide1down.vx v10, v10, a0
+; RV32-NEXT:    vslide1down.vx v10, v10, a1
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: insertelt_c4_v8xi64_exact:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 5, e64, m4, tu, ma
-; RV64-NEXT:    vmv.s.x v12, a0
-; RV64-NEXT:    vslideup.vi v8, v12, 4
+; RV64-NEXT:    vsetivli zero, 8, e64, m1, tu, ma
+; RV64-NEXT:    vmv.s.x v10, a0
 ; RV64-NEXT:    ret
   %v = insertelement <8 x i64> %vin, i64 %a, i32 4
   ret <8 x i64> %v
@@ -678,18 +674,18 @@ define <8 x i64> @insertelt_c4_v8xi64_exact(<8 x i64> %vin, i64 %a) vscale_range
 define <8 x i64> @insertelt_c5_v8xi64_exact(<8 x i64> %vin, i64 %a) vscale_range(2,2) {
 ; RV32-LABEL: insertelt_c5_v8xi64_exact:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    vsetivli zero, 2, e32, m4, ta, ma
+; RV32-NEXT:    vsetivli zero, 2, e32, m1, ta, ma
 ; RV32-NEXT:    vslide1down.vx v12, v8, a0
 ; RV32-NEXT:    vslide1down.vx v12, v12, a1
-; RV32-NEXT:    vsetivli zero, 6, e64, m4, tu, ma
-; RV32-NEXT:    vslideup.vi v8, v12, 5
+; RV32-NEXT:    vsetivli zero, 2, e64, m1, tu, ma
+; RV32-NEXT:    vslideup.vi v10, v12, 1
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: insertelt_c5_v8xi64_exact:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 6, e64, m4, tu, ma
+; RV64-NEXT:    vsetivli zero, 2, e64, m1, tu, ma
 ; RV64-NEXT:    vmv.s.x v12, a0
-; RV64-NEXT:    vslideup.vi v8, v12, 5
+; RV64-NEXT:    vslideup.vi v10, v12, 1
 ; RV64-NEXT:    ret
   %v = insertelement <8 x i64> %vin, i64 %a, i32 5
   ret <8 x i64> %v
