@@ -21375,19 +21375,31 @@ static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,
21375
21375
21376
21376
// Helper function for comparing all bits of a vector against zero.
21377
21377
static SDValue LowerVectorAllZero(const SDLoc &DL, SDValue V, ISD::CondCode CC,
21378
+ const APInt &Mask,
21378
21379
const X86Subtarget &Subtarget,
21379
21380
SelectionDAG &DAG, X86::CondCode &X86CC) {
21380
21381
EVT VT = V.getValueType();
21382
+ assert(Mask.getBitWidth() == VT.getScalarSizeInBits() &&
21383
+ "Element Mask vs Vector bitwidth mismatch");
21381
21384
21382
21385
assert((CC == ISD::SETEQ || CC == ISD::SETNE) && "Unsupported ISD::CondCode");
21383
21386
X86CC = (CC == ISD::SETEQ ? X86::COND_E : X86::COND_NE);
21384
21387
21388
+ auto MaskBits = [&](SDValue Src) {
21389
+ if (Mask.isAllOnesValue())
21390
+ return Src;
21391
+ EVT SrcVT = Src.getValueType();
21392
+ SDValue MaskValue = DAG.getConstant(Mask, DL, SrcVT);
21393
+ return DAG.getNode(ISD::AND, DL, SrcVT, Src, MaskValue);
21394
+ };
21395
+
21385
21396
// For sub-128-bit vector, cast to (legal) integer and compare with zero.
21386
21397
if (VT.getSizeInBits() < 128) {
21387
21398
EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
21388
21399
if (!DAG.getTargetLoweringInfo().isTypeLegal(IntVT))
21389
21400
return SDValue();
21390
- return DAG.getNode(X86ISD::CMP, DL, MVT::i32, DAG.getBitcast(IntVT, V),
21401
+ return DAG.getNode(X86ISD::CMP, DL, MVT::i32,
21402
+ DAG.getBitcast(IntVT, MaskBits(V)),
21391
21403
DAG.getConstant(0, DL, IntVT));
21392
21404
}
21393
21405
@@ -21406,11 +21418,16 @@ static SDValue LowerVectorAllZero(const SDLoc &DL, SDValue V, ISD::CondCode CC,
21406
21418
bool UsePTEST = Subtarget.hasSSE41();
21407
21419
if (UsePTEST) {
21408
21420
MVT TestVT = VT.is128BitVector() ? MVT::v2i64 : MVT::v4i64;
21409
- V = DAG.getBitcast(TestVT, V );
21421
+ V = DAG.getBitcast(TestVT, MaskBits(V) );
21410
21422
return DAG.getNode(X86ISD::PTEST, DL, MVT::i32, V, V);
21411
21423
}
21412
21424
21413
- V = DAG.getBitcast(MVT::v16i8, V);
21425
+ // Without PTEST, a masked v2i64 or-reduction is not faster than
21426
+ // scalarization.
21427
+ if (!Mask.isAllOnesValue() && VT.getScalarSizeInBits() > 32)
21428
+ return SDValue();
21429
+
21430
+ V = DAG.getBitcast(MVT::v16i8, MaskBits(V));
21414
21431
V = DAG.getNode(X86ISD::PCMPEQ, DL, MVT::v16i8, V,
21415
21432
getZeroVector(MVT::v16i8, Subtarget, DAG, DL));
21416
21433
V = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, V);
@@ -21429,6 +21446,26 @@ static SDValue MatchVectorAllZeroTest(SDValue Op, ISD::CondCode CC,
21429
21446
if (!Subtarget.hasSSE2() || !Op->hasOneUse())
21430
21447
return SDValue();
21431
21448
21449
+ // Check whether we're masking/truncating an OR-reduction result, in which
21450
+ // case track the masked bits.
21451
+ APInt Mask = APInt::getAllOnesValue(Op.getScalarValueSizeInBits());
21452
+ switch (Op.getOpcode()) {
21453
+ case ISD::TRUNCATE: {
21454
+ SDValue Src = Op.getOperand(0);
21455
+ Mask = APInt::getLowBitsSet(Src.getScalarValueSizeInBits(),
21456
+ Op.getScalarValueSizeInBits());
21457
+ Op = Src;
21458
+ break;
21459
+ }
21460
+ case ISD::AND: {
21461
+ if (auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
21462
+ Mask = Cst->getAPIntValue();
21463
+ Op = Op.getOperand(0);
21464
+ }
21465
+ break;
21466
+ }
21467
+ }
21468
+
21432
21469
SmallVector<SDValue, 8> VecIns;
21433
21470
if (Op.getOpcode() == ISD::OR && matchScalarReduction(Op, ISD::OR, VecIns)) {
21434
21471
EVT VT = VecIns[0].getValueType();
@@ -21451,8 +21488,8 @@ static SDValue MatchVectorAllZeroTest(SDValue Op, ISD::CondCode CC,
21451
21488
}
21452
21489
21453
21490
X86::CondCode CCode;
21454
- if (SDValue V =
21455
- LowerVectorAllZero(DL, VecIns.back(), CC, Subtarget, DAG, CCode)) {
21491
+ if (SDValue V = LowerVectorAllZero(DL, VecIns.back(), CC, Mask, Subtarget,
21492
+ DAG, CCode)) {
21456
21493
X86CC = DAG.getTargetConstant(CCode, DL, MVT::i8);
21457
21494
return V;
21458
21495
}
@@ -21464,7 +21501,7 @@ static SDValue MatchVectorAllZeroTest(SDValue Op, ISD::CondCode CC,
21464
21501
DAG.matchBinOpReduction(Op.getNode(), BinOp, {ISD::OR})) {
21465
21502
X86::CondCode CCode;
21466
21503
if (SDValue V =
21467
- LowerVectorAllZero(DL, Match, CC, Subtarget, DAG, CCode)) {
21504
+ LowerVectorAllZero(DL, Match, CC, Mask, Subtarget, DAG, CCode)) {
21468
21505
X86CC = DAG.getTargetConstant(CCode, DL, MVT::i8);
21469
21506
return V;
21470
21507
}
0 commit comments