[AArch64] Extend custom lowering for SVE types in @llvm.experimental.vector.compress #105515

Open · wants to merge 6 commits into base: main
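For reference, `@llvm.experimental.vector.compress` takes a vector, a mask, and a passthru: lanes of the vector whose mask bit is set are packed into the low lanes of the result, and the remaining lanes are taken from the passthru. A minimal IR example with one of the element counts this patch adds (the function name is illustrative):

declare <vscale x 8 x i16> @llvm.experimental.vector.compress.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i1>, <vscale x 8 x i16>)

define <vscale x 8 x i16> @compress_example(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask, <vscale x 8 x i16> %passthru) {
  ; Active lanes of %vec move to the front; the tail comes from %passthru.
  %out = call <vscale x 8 x i16> @llvm.experimental.vector.compress.nxv8i16(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask, <vscale x 8 x i16> %passthru)
  ret <vscale x 8 x i16> %out
}

Before this patch, the custom lowering bailed out for these element counts (see the removed MinElmts check in the diff below), leaving such calls to the generic expansion.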
108 changes: 88 additions & 20 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1781,16 +1781,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                  MVT::v2f32, MVT::v4f32, MVT::v2f64})
     setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);

-  // We can lower types that have <vscale x {2|4}> elements to compact.
+  // We can lower all legal (or smaller) SVE types to `compact`.
   for (auto VT :
        {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
-        MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
+        MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32,
+        MVT::nxv8i8, MVT::nxv8i16, MVT::nxv16i8, MVT::nxv8f16, MVT::nxv8bf16})
     setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);

   // If we have SVE, we can use SVE logic for legal (or smaller than legal)
   // NEON vectors in the lowest bits of the SVE register.
   for (auto VT : {MVT::v2i8, MVT::v2i16, MVT::v2i32, MVT::v2i64, MVT::v2f32,
-                  MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32})
+                  MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32,
+                  MVT::v8i8, MVT::v8i16, MVT::v16i8})
     setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);

   // Histcnt is SVE2 only
@@ -6648,6 +6650,7 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
   EVT ElmtVT = VecVT.getVectorElementType();
   const bool IsFixedLength = VecVT.isFixedLengthVector();
   const bool HasPassthru = !Passthru.isUndef();
+  bool CompressedViaStack = false;
   unsigned MinElmts = VecVT.getVectorElementCount().getKnownMinValue();
   EVT FixedVecVT = MVT::getVectorVT(ElmtVT.getSimpleVT(), MinElmts);

@@ -6659,10 +6662,6 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
   if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
     return SDValue();

-  // Only <vscale x {4|2} x {i32|i64}> supported for compact.
-  if (MinElmts != 2 && MinElmts != 4)
-    return SDValue();
-
   // We can use the SVE register containing the NEON vector in its lowest bits.
   if (IsFixedLength) {
     EVT ScalableVecVT =
@@ -6690,19 +6689,83 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
   EVT ContainerVT = getSVEContainerType(VecVT);
   EVT CastVT = VecVT.changeVectorElementTypeToInteger();

-  // Convert to i32 or i64 for smaller types, as these are the only supported
-  // sizes for compact.
-  if (ContainerVT != VecVT) {
-    Vec = DAG.getBitcast(CastVT, Vec);
-    Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
-  }
+  // These vector types aren't supported by the `compact` instruction, so
+  // we split and compact them as <vscale x 4 x i32>, store them on the stack,
+  // and then merge them again. In the other cases, emit compact directly.
+  SDValue Compressed;
+  if (VecVT == MVT::nxv8i16 || VecVT == MVT::nxv8i8 || VecVT == MVT::nxv16i8 ||
+      VecVT == MVT::nxv8f16 || VecVT == MVT::nxv8bf16) {
+    SDValue Chain = DAG.getEntryNode();
+    SDValue StackPtr = DAG.CreateStackTemporary(
+        VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
+    MachineFunction &MF = DAG.getMachineFunction();
+
+    bool isFloatingPoint = ElmtVT.isFloatingPoint();
+    if (isFloatingPoint)
+      Vec = DAG.getBitcast(CastVT, Vec);
+
+    EVT PartialVecVT = EVT::getVectorVT(
+        *DAG.getContext(), Vec.getValueType().getVectorElementType(), 4,
+        /*isScalable*/ true);
+    EVT OffsetVT = getVectorIdxTy(DAG.getDataLayout());
+    SDValue Offset = DAG.getConstant(0, DL, OffsetVT);
+
+    for (unsigned I = 0; I < MinElmts; I += 4) {
+      SDValue VectorIdx = DAG.getVectorIdxConstant(I, DL);
+      SDValue PartialVec =
+          DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, PartialVecVT, Vec, VectorIdx);
+      PartialVec = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv4i32, PartialVec);
+
+      SDValue PartialMask =
+          DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::nxv4i1, Mask, VectorIdx);
+
+      SDValue PartialCompressed = DAG.getNode(
+          ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4i32,
+          DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64),
+          PartialMask, PartialVec);
+      PartialCompressed =
+          DAG.getNode(ISD::TRUNCATE, DL, PartialVecVT, PartialCompressed);
+
+      SDValue OutPtr = DAG.getNode(
+          ISD::ADD, DL, StackPtr.getValueType(), StackPtr,
+          DAG.getNode(
+              ISD::MUL, DL, OffsetVT, Offset,
+              DAG.getConstant(ElmtVT.getScalarSizeInBits() / 8, DL, OffsetVT)));
+      Chain = DAG.getStore(Chain, DL, PartialCompressed, OutPtr,
+                           MachinePointerInfo::getUnknownStack(MF));
+
+      SDValue PartialOffset = DAG.getNode(
+          ISD::INTRINSIC_WO_CHAIN, DL, OffsetVT,
+          DAG.getConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64),
+          PartialMask, PartialMask);
+      Offset = DAG.getNode(ISD::ADD, DL, OffsetVT, Offset, PartialOffset);
+    }
+
+    MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(
+        MF, cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex());
+    Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
+
+    if (isFloatingPoint)
+      Compressed = DAG.getBitcast(VecVT, Compressed);
+
+    CompressedViaStack = true;
+  } else {
+    // Convert to i32 or i64 for smaller types, as these are the only
+    // supported sizes for compact.
+    if (ContainerVT != VecVT) {
+      Vec = DAG.getBitcast(CastVT, Vec);
+      Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
+    }

-  SDValue Compressed = DAG.getNode(
-      ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
-      DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec);
+    Compressed = DAG.getNode(
+        ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
+        DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask,
+        Vec);
+  }

   // compact fills with 0s, so if our passthru is all 0s, do nothing here.
-  if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) {
+  if (HasPassthru && (!ISD::isConstantSplatVectorAllZeros(Passthru.getNode()) ||
+                      CompressedViaStack)) {
     SDValue Offset = DAG.getNode(
         ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
         DAG.getConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64), Mask, Mask);
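An aside on the stack-based path above: hand-unrolled for <vscale x 8 x i16>, it computes roughly the following IR. This is only a sketch; the lowering builds SelectionDAG nodes, loops over every <vscale x 4 x ...> part, and uses any-extend where this sketch uses zext, and all names here are made up.

declare <vscale x 4 x i16> @llvm.vector.extract.nxv4i16.nxv8i16(<vscale x 8 x i16>, i64)
declare <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1>, i64)
declare <vscale x 4 x i32> @llvm.aarch64.sve.compact.nxv4i32(<vscale x 4 x i1>, <vscale x 4 x i32>)
declare i64 @llvm.aarch64.sve.cntp.nxv4i1(<vscale x 4 x i1>, <vscale x 4 x i1>)

; Illustrative sketch, not compiler output.
define <vscale x 8 x i16> @compress_nxv8i16_sketch(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %mask) {
  %buf = alloca <vscale x 8 x i16>
  ; Lower half: extract, widen to the i32 elements compact supports,
  ; compact, narrow again, and store the survivors at element offset 0.
  %lo = call <vscale x 4 x i16> @llvm.vector.extract.nxv4i16.nxv8i16(<vscale x 8 x i16> %vec, i64 0)
  %lo.m = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %mask, i64 0)
  %lo.ext = zext <vscale x 4 x i16> %lo to <vscale x 4 x i32>
  %lo.cmp = call <vscale x 4 x i32> @llvm.aarch64.sve.compact.nxv4i32(<vscale x 4 x i1> %lo.m, <vscale x 4 x i32> %lo.ext)
  %lo.res = trunc <vscale x 4 x i32> %lo.cmp to <vscale x 4 x i16>
  store <vscale x 4 x i16> %lo.res, ptr %buf
  ; cntp counts the lower half's survivors; the upper half's survivors
  ; must start right after them.
  %n.lo = call i64 @llvm.aarch64.sve.cntp.nxv4i1(<vscale x 4 x i1> %lo.m, <vscale x 4 x i1> %lo.m)
  %hi.ptr = getelementptr i16, ptr %buf, i64 %n.lo
  ; Upper half: same steps, stored at the running offset.
  %hi = call <vscale x 4 x i16> @llvm.vector.extract.nxv4i16.nxv8i16(<vscale x 8 x i16> %vec, i64 4)
  %hi.m = call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %mask, i64 4)
  %hi.ext = zext <vscale x 4 x i16> %hi to <vscale x 4 x i32>
  %hi.cmp = call <vscale x 4 x i32> @llvm.aarch64.sve.compact.nxv4i32(<vscale x 4 x i1> %hi.m, <vscale x 4 x i32> %hi.ext)
  %hi.res = trunc <vscale x 4 x i32> %hi.cmp to <vscale x 4 x i16>
  store <vscale x 4 x i16> %hi.res, ptr %hi.ptr
  ; Reload the merged vector. Lanes past the last store may hold
  ; uninitialized stack bytes, which is why the passthru fix-up below is
  ; forced whenever CompressedViaStack is set.
  %res = load <vscale x 8 x i16>, ptr %buf
  ret <vscale x 8 x i16> %res
}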
@@ -6712,8 +6775,13 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
         DAG.getConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64),
         DAG.getConstant(0, DL, MVT::i64), Offset);

-    Compressed =
-        DAG.getNode(ISD::VSELECT, DL, VecVT, IndexMask, Compressed, Passthru);
+    if (ContainerVT != VecVT) {
+      Passthru = DAG.getBitcast(CastVT, Passthru);
+      Passthru = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Passthru);
+    }
+
+    Compressed = DAG.getNode(ISD::VSELECT, DL, Vec.getValueType(), IndexMask,
+                             Compressed, Passthru);
   }

   // Extracting from a legal SVE type before truncating produces better code.
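The fix-up above counts the surviving lanes with cntp and builds a whilelo predicate that is true exactly for lanes [0, count), then selects between the compressed vector and the (possibly widened) passthru. As standalone IR for a type compact handles natively, again only as an illustrative sketch:

declare i64 @llvm.aarch64.sve.cntp.nxv4i1(<vscale x 4 x i1>, <vscale x 4 x i1>)
declare <vscale x 4 x i1> @llvm.aarch64.sve.whilelo.nxv4i1.i64(i64, i64)

; Illustrative sketch, not compiler output.
define <vscale x 4 x i32> @merge_passthru_sketch(<vscale x 4 x i32> %compressed, <vscale x 4 x i32> %passthru, <vscale x 4 x i1> %mask) {
  ; Number of survivors packed at the front of %compressed.
  %n = call i64 @llvm.aarch64.sve.cntp.nxv4i1(<vscale x 4 x i1> %mask, <vscale x 4 x i1> %mask)
  ; Predicate covering lanes [0, %n).
  %keep = call <vscale x 4 x i1> @llvm.aarch64.sve.whilelo.nxv4i1.i64(i64 0, i64 %n)
  ; Keep the survivors, take every other lane from the passthru.
  %res = select <vscale x 4 x i1> %keep, <vscale x 4 x i32> %compressed, <vscale x 4 x i32> %passthru
  ret <vscale x 4 x i32> %res
}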
@@ -6727,7 +6795,7 @@ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
   }

   // If we changed the element type before, we need to convert it back.
-  if (ContainerVT != VecVT) {
+  if (ContainerVT != VecVT && !CompressedViaStack) {
     Compressed = DAG.getNode(ISD::TRUNCATE, DL, CastVT, Compressed);
     Compressed = DAG.getBitcast(VecVT, Compressed);
   }