Skip to content

Commit 82de018

Browse files
committed
[X86][SSE] LowerVectorAllZero - add support for masked OR-reductions
If we're masking the result of an OR-reduction before comparing against zero, we can fold the mask into the PTEST() / MOVMSK(CMPEQ()) codegen by pre-masking the source value. This works particularly well with PTEST, which performs the AND as part of its operation, but the MOVMSK variant also benefits for non-V2I64 cases. Fixes PR44781.
1 parent 05c4794 commit 82de018

File tree

2 files changed

+159
-200
lines changed

2 files changed

+159
-200
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

+43-6
Original file line numberDiff line numberDiff line change
@@ -21375,19 +21375,31 @@ static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,
2137521375

2137621376
// Helper function for comparing all bits of a vector against zero.
2137721377
static SDValue LowerVectorAllZero(const SDLoc &DL, SDValue V, ISD::CondCode CC,
21378+
const APInt &Mask,
2137821379
const X86Subtarget &Subtarget,
2137921380
SelectionDAG &DAG, X86::CondCode &X86CC) {
2138021381
EVT VT = V.getValueType();
21382+
assert(Mask.getBitWidth() == VT.getScalarSizeInBits() &&
21383+
"Element Mask vs Vector bitwidth mismatch");
2138121384

2138221385
assert((CC == ISD::SETEQ || CC == ISD::SETNE) && "Unsupported ISD::CondCode");
2138321386
X86CC = (CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE);
2138421387

21388+
auto MaskBits = [&](SDValue Src) {
21389+
if (Mask.isAllOnesValue())
21390+
return Src;
21391+
EVT SrcVT = Src.getValueType();
21392+
SDValue MaskValue = DAG.getConstant(Mask, DL, SrcVT);
21393+
return DAG.getNode(ISD::AND, DL, SrcVT, Src, MaskValue);
21394+
};
21395+
2138521396
// For sub-128-bit vector, cast to (legal) integer and compare with zero.
2138621397
if (VT.getSizeInBits() < 128) {
2138721398
EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
2138821399
if (!DAG.getTargetLoweringInfo().isTypeLegal(IntVT))
2138921400
return SDValue();
21390-
return DAG.getNode(X86ISD::CMP, DL, MVT::i32, DAG.getBitcast(IntVT, V),
21401+
return DAG.getNode(X86ISD::CMP, DL, MVT::i32,
21402+
DAG.getBitcast(IntVT, MaskBits(V)),
2139121403
DAG.getConstant(0, DL, IntVT));
2139221404
}
2139321405

@@ -21406,11 +21418,16 @@ static SDValue LowerVectorAllZero(const SDLoc &DL, SDValue V, ISD::CondCode CC,
2140621418
bool UsePTEST = Subtarget.hasSSE41();
2140721419
if (UsePTEST) {
2140821420
MVT TestVT = VT.is128BitVector() ? MVT::v2i64 : MVT::v4i64;
21409-
V = DAG.getBitcast(TestVT, V);
21421+
V = DAG.getBitcast(TestVT, MaskBits(V));
2141021422
return DAG.getNode(X86ISD::PTEST, DL, MVT::i32, V, V);
2141121423
}
2141221424

21413-
V = DAG.getBitcast(MVT::v16i8, V);
21425+
// Without PTEST, a masked v2i64 or-reduction is not faster than
21426+
// scalarization.
21427+
if (!Mask.isAllOnesValue() && VT.getScalarSizeInBits() > 32)
21428+
return SDValue();
21429+
21430+
V = DAG.getBitcast(MVT::v16i8, MaskBits(V));
2141421431
V = DAG.getNode(X86ISD::PCMPEQ, DL, MVT::v16i8, V,
2141521432
getZeroVector(MVT::v16i8, Subtarget, DAG, DL));
2141621433
V = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, V);
@@ -21429,6 +21446,26 @@ static SDValue MatchVectorAllZeroTest(SDValue Op, ISD::CondCode CC,
2142921446
if (!Subtarget.hasSSE2() || !Op->hasOneUse())
2143021447
return SDValue();
2143121448

21449+
// Check whether we're masking/truncating an OR-reduction result, in which
21450+
// case track the masked bits.
21451+
APInt Mask = APInt::getAllOnesValue(Op.getScalarValueSizeInBits());
21452+
switch (Op.getOpcode()) {
21453+
case ISD::TRUNCATE: {
21454+
SDValue Src = Op.getOperand(0);
21455+
Mask = APInt::getLowBitsSet(Src.getScalarValueSizeInBits(),
21456+
Op.getScalarValueSizeInBits());
21457+
Op = Src;
21458+
break;
21459+
}
21460+
case ISD::AND: {
21461+
if (auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
21462+
Mask = Cst->getAPIntValue();
21463+
Op = Op.getOperand(0);
21464+
}
21465+
break;
21466+
}
21467+
}
21468+
2143221469
SmallVector<SDValue, 8> VecIns;
2143321470
if (Op.getOpcode() == ISD::OR && matchScalarReduction(Op, ISD::OR, VecIns)) {
2143421471
EVT VT = VecIns[0].getValueType();
@@ -21451,8 +21488,8 @@ static SDValue MatchVectorAllZeroTest(SDValue Op, ISD::CondCode CC,
2145121488
}
2145221489

2145321490
X86::CondCode CCode;
21454-
if (SDValue V =
21455-
LowerVectorAllZero(DL, VecIns.back(), CC, Subtarget, DAG, CCode)) {
21491+
if (SDValue V = LowerVectorAllZero(DL, VecIns.back(), CC, Mask, Subtarget,
21492+
DAG, CCode)) {
2145621493
X86CC = DAG.getTargetConstant(CCode, DL, MVT::i8);
2145721494
return V;
2145821495
}
@@ -21464,7 +21501,7 @@ static SDValue MatchVectorAllZeroTest(SDValue Op, ISD::CondCode CC,
2146421501
DAG.matchBinOpReduction(Op.getNode(), BinOp, {ISD::OR})) {
2146521502
X86::CondCode CCode;
2146621503
if (SDValue V =
21467-
LowerVectorAllZero(DL, Match, CC, Subtarget, DAG, CCode)) {
21504+
LowerVectorAllZero(DL, Match, CC, Mask, Subtarget, DAG, CCode)) {
2146821505
X86CC = DAG.getTargetConstant(CCode, DL, MVT::i8);
2146921506
return V;
2147021507
}

0 commit comments

Comments
 (0)