
Commit 5352c79

[RISCV] Add a combine to form masked.load from unit strided load (#65674)
Add a DAG combine to form a masked.load from a masked_strided_load intrinsic whose stride equals the element size. This covers a couple of extra test cases and lets us simplify and common up some existing code in the concat_vectors(load, ...) to strided-load transform. This is the first in a mini-patch series aimed at generalizing our strided load and gather matching to handle more cases, and at commoning up the different approaches to the same problems in different places.
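For illustration, here is a reduced LLVM IR sketch of the pattern the new combine targets. The v4i32 intrinsic mangling and the function name are assumptions modeled on the v32i8 test in strided-load-store-intrinsics.ll below, and the before/after codegen in the comments mirrors the updated RV32 check lines in this patch.

; Hypothetical example (not part of the patch): a masked strided load whose
; constant stride (4 bytes) equals the i32 element size.
declare <4 x i32> @llvm.riscv.masked.strided.load.v4i32.p0.i64(<4 x i32>, ptr, i64, <4 x i1>)

define <4 x i32> @unit_stride_load_i32(ptr %p, <4 x i1> %m) {
  ; Before this combine: selected as a strided load, e.g. "vlse32.v v8, (a0), a1" with a1 = 4.
  ; With this combine: rewritten to llvm.masked.load and selected as a unit-strided
  ; "vle32.v v8, (a0), v0.t", dropping the stride register entirely.
  %v = call <4 x i32> @llvm.riscv.masked.strided.load.v4i32.p0.i64(<4 x i32> undef, ptr %p, i64 4, <4 x i1> %m)
  ret <4 x i32> %v
}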
1 parent c154ba8 commit 5352c79

3 files changed: +32 -39 lines changed


llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 29 additions & 33 deletions
@@ -13371,27 +13371,6 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
     return SDValue();
   }
 
-  // A special case is if the stride is exactly the width of one of the loads,
-  // in which case it's contiguous and can be combined into a regular vle
-  // without changing the element size
-  if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
-      ConstStride && !Reversed &&
-      ConstStride->getZExtValue() == BaseLdVT.getFixedSizeInBits() / 8) {
-    MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
-        BaseLd->getPointerInfo(), BaseLd->getMemOperand()->getFlags(),
-        VT.getStoreSize(), Align);
-    // Can't do the combine if the load isn't naturally aligned with the element
-    // type
-    if (!TLI.allowsMemoryAccessForAlignment(*DAG.getContext(),
-                                            DAG.getDataLayout(), VT, *MMO))
-      return SDValue();
-
-    SDValue WideLoad = DAG.getLoad(VT, DL, BaseLd->getChain(), BasePtr, MMO);
-    for (SDValue Ld : N->ops())
-      DAG.makeEquivalentMemoryOrdering(cast<LoadSDNode>(Ld), WideLoad);
-    return WideLoad;
-  }
-
   // Get the widened scalar type, e.g. v4i8 -> i64
   unsigned WideScalarBitWidth =
       BaseLdVT.getScalarSizeInBits() * BaseLdVT.getVectorNumElements();
@@ -13406,20 +13385,22 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   if (!TLI.isLegalStridedLoadStore(WideVecVT, Align))
     return SDValue();
 
-  MVT ContainerVT = TLI.getContainerForFixedLengthVector(WideVecVT);
-  SDValue VL =
-      getDefaultVLOps(WideVecVT, ContainerVT, DL, DAG, Subtarget).second;
-  SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
+  SDVTList VTs = DAG.getVTList({WideVecVT, MVT::Other});
   SDValue IntID =
-      DAG.getTargetConstant(Intrinsic::riscv_vlse, DL, Subtarget.getXLenVT());
+      DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
+                            Subtarget.getXLenVT());
   if (Reversed)
     Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
+  SDValue AllOneMask =
+      DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL,
+                   DAG.getConstant(1, DL, MVT::i1));
+
   SDValue Ops[] = {BaseLd->getChain(),
                    IntID,
-                   DAG.getUNDEF(ContainerVT),
+                   DAG.getUNDEF(WideVecVT),
                    BasePtr,
                    Stride,
-                   VL};
+                   AllOneMask};
 
   uint64_t MemSize;
   if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
@@ -13441,11 +13422,7 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   for (SDValue Ld : N->ops())
     DAG.makeEquivalentMemoryOrdering(cast<LoadSDNode>(Ld), StridedLoad);
 
-  // Note: Perform the bitcast before the convertFromScalableVector so we have
-  // balanced pairs of convertFromScalable/convertToScalable
-  SDValue Res = DAG.getBitcast(
-      TLI.getContainerForFixedLengthVector(VT.getSimpleVT()), StridedLoad);
-  return convertFromScalableVector(VT, Res, DAG, Subtarget);
+  return DAG.getBitcast(VT.getSimpleVT(), StridedLoad);
 }
 
 static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
@@ -14184,6 +14161,25 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
       // By default we do not combine any intrinsic.
     default:
       return SDValue();
+    case Intrinsic::riscv_masked_strided_load: {
+      MVT VT = N->getSimpleValueType(0);
+      auto *Load = cast<MemIntrinsicSDNode>(N);
+      SDValue PassThru = N->getOperand(2);
+      SDValue Base = N->getOperand(3);
+      SDValue Stride = N->getOperand(4);
+      SDValue Mask = N->getOperand(5);
+
+      // If the stride is equal to the element size in bytes, we can use
+      // a masked.load.
+      const unsigned ElementSize = VT.getScalarStoreSize();
+      if (auto *StrideC = dyn_cast<ConstantSDNode>(Stride);
+          StrideC && StrideC->getZExtValue() == ElementSize)
+        return DAG.getMaskedLoad(VT, DL, Load->getChain(), Base,
+                                 DAG.getUNDEF(XLenVT), Mask, PassThru,
+                                 Load->getMemoryVT(), Load->getMemOperand(),
+                                 ISD::UNINDEXED, ISD::NON_EXTLOAD);
+      return SDValue();
+    }
     case Intrinsic::riscv_vcpop:
     case Intrinsic::riscv_vcpop_mask:
     case Intrinsic::riscv_vfirst:

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll

Lines changed: 2 additions & 4 deletions
@@ -13010,9 +13010,8 @@ define <4 x i32> @mgather_broadcast_load_masked(ptr %base, <4 x i1> %m) {
 define <4 x i32> @mgather_unit_stride_load(ptr %base) {
 ; RV32-LABEL: mgather_unit_stride_load:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    li a1, 4
 ; RV32-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; RV32-NEXT:    vlse32.v v8, (a0), a1
+; RV32-NEXT:    vle32.v v8, (a0)
 ; RV32-NEXT:    ret
 ;
 ; RV64V-LABEL: mgather_unit_stride_load:
@@ -13082,9 +13081,8 @@ define <4 x i32> @mgather_unit_stride_load_with_offset(ptr %base) {
 ; RV32-LABEL: mgather_unit_stride_load_with_offset:
 ; RV32:       # %bb.0:
 ; RV32-NEXT:    addi a0, a0, 16
-; RV32-NEXT:    li a1, 4
 ; RV32-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; RV32-NEXT:    vlse32.v v8, (a0), a1
+; RV32-NEXT:    vle32.v v8, (a0)
 ; RV32-NEXT:    ret
 ;
 ; RV64V-LABEL: mgather_unit_stride_load_with_offset:

llvm/test/CodeGen/RISCV/rvv/strided-load-store-intrinsics.ll

Lines changed: 1 addition & 2 deletions
@@ -55,9 +55,8 @@ define <32 x i8> @strided_load_i8_nostride(ptr %p, <32 x i1> %m) {
 ; CHECK-LABEL: strided_load_i8_nostride:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    li a1, 32
-; CHECK-NEXT:    li a2, 1
 ; CHECK-NEXT:    vsetvli zero, a1, e8, m2, ta, ma
-; CHECK-NEXT:    vlse8.v v8, (a0), a2, v0.t
+; CHECK-NEXT:    vle8.v v8, (a0), v0.t
 ; CHECK-NEXT:    ret
   %res = call <32 x i8> @llvm.riscv.masked.strided.load.v32i8.p0.i64(<32 x i8> undef, ptr %p, i64 1, <32 x i1> %m)
   ret <32 x i8> %res
