-
Notifications
You must be signed in to change notification settings - Fork 14.8k
release/21.x: [LoongArch] Fix failure to widen operand for [X]VMSK{LT,GE,NE}Z
(#149442)
#149778
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@tangaac What do you think about merging this PR to the release branch? |
@llvm/pr-subscribers-backend-loongarch Author: None (llvmbot) ChangesBackport 8a307ae Requested by: @heiher Full diff: https://github.com/llvm/llvm-project/pull/149778.diff 2 Files Affected:
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index c47987fbf683b..12cf04bbbab56 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -4563,6 +4563,80 @@ static SDValue signExtendBitcastSrcVector(SelectionDAG &DAG, EVT SExtVT,
llvm_unreachable("Unexpected node type for vXi1 sign extension");
}
+static SDValue
+performSETCC_BITCASTCombine(SDNode *N, SelectionDAG &DAG,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const LoongArchSubtarget &Subtarget) {
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+ SDValue Src = N->getOperand(0);
+ EVT SrcVT = Src.getValueType();
+
+ if (Src.getOpcode() != ISD::SETCC || !Src.hasOneUse())
+ return SDValue();
+
+ bool UseLASX;
+ unsigned Opc = ISD::DELETED_NODE;
+ EVT CmpVT = Src.getOperand(0).getValueType();
+ EVT EltVT = CmpVT.getVectorElementType();
+
+ if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() == 128)
+ UseLASX = false;
+ else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
+ CmpVT.getSizeInBits() == 256)
+ UseLASX = true;
+ else
+ return SDValue();
+
+ SDValue SrcN1 = Src.getOperand(1);
+ switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
+ default:
+ break;
+ case ISD::SETEQ:
+ // x == 0 => not (vmsknez.b x)
+ if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+ Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
+ break;
+ case ISD::SETGT:
+ // x > -1 => vmskgez.b x
+ if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
+ break;
+ case ISD::SETGE:
+ // x >= 0 => vmskgez.b x
+ if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
+ break;
+ case ISD::SETLT:
+ // x < 0 => vmskltz.{b,h,w,d} x
+ if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
+ EltVT == MVT::i64))
+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
+ break;
+ case ISD::SETLE:
+ // x <= -1 => vmskltz.{b,h,w,d} x
+ if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
+ EltVT == MVT::i64))
+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
+ break;
+ case ISD::SETNE:
+ // x != 0 => vmsknez.b x
+ if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+ Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
+ break;
+ }
+
+ if (Opc == ISD::DELETED_NODE)
+ return SDValue();
+
+ SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src.getOperand(0));
+ EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
+ V = DAG.getZExtOrTrunc(V, DL, T);
+ return DAG.getBitcast(VT, V);
+}
+
static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const LoongArchSubtarget &Subtarget) {
@@ -4577,110 +4651,63 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
if (!SrcVT.isSimple() || SrcVT.getScalarType() != MVT::i1)
return SDValue();
- unsigned Opc = ISD::DELETED_NODE;
// Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible
+ SDValue Res = performSETCC_BITCASTCombine(N, DAG, DCI, Subtarget);
+ if (Res)
+ return Res;
+
+ // Generate vXi1 using [X]VMSKLTZ
+ MVT SExtVT;
+ unsigned Opc;
+ bool UseLASX = false;
+ bool PropagateSExt = false;
+
if (Src.getOpcode() == ISD::SETCC && Src.hasOneUse()) {
- bool UseLASX;
EVT CmpVT = Src.getOperand(0).getValueType();
- EVT EltVT = CmpVT.getVectorElementType();
-
- if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() <= 128)
- UseLASX = false;
- else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
- CmpVT.getSizeInBits() <= 256)
- UseLASX = true;
- else
+ if (CmpVT.getSizeInBits() > 256)
return SDValue();
-
- SDValue SrcN1 = Src.getOperand(1);
- switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
- default:
- break;
- case ISD::SETEQ:
- // x == 0 => not (vmsknez.b x)
- if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
- Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
- break;
- case ISD::SETGT:
- // x > -1 => vmskgez.b x
- if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
- Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
- break;
- case ISD::SETGE:
- // x >= 0 => vmskgez.b x
- if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
- Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
- break;
- case ISD::SETLT:
- // x < 0 => vmskltz.{b,h,w,d} x
- if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
- (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
- EltVT == MVT::i64))
- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
- break;
- case ISD::SETLE:
- // x <= -1 => vmskltz.{b,h,w,d} x
- if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
- (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
- EltVT == MVT::i64))
- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
- break;
- case ISD::SETNE:
- // x != 0 => vmsknez.b x
- if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
- Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
- break;
- }
}
- // Generate vXi1 using [X]VMSKLTZ
- if (Opc == ISD::DELETED_NODE) {
- MVT SExtVT;
- bool UseLASX = false;
- bool PropagateSExt = false;
- switch (SrcVT.getSimpleVT().SimpleTy) {
- default:
- return SDValue();
- case MVT::v2i1:
- SExtVT = MVT::v2i64;
- break;
- case MVT::v4i1:
- SExtVT = MVT::v4i32;
- if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
- SExtVT = MVT::v4i64;
- UseLASX = true;
- PropagateSExt = true;
- }
- break;
- case MVT::v8i1:
- SExtVT = MVT::v8i16;
- if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
- SExtVT = MVT::v8i32;
- UseLASX = true;
- PropagateSExt = true;
- }
- break;
- case MVT::v16i1:
- SExtVT = MVT::v16i8;
- if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
- SExtVT = MVT::v16i16;
- UseLASX = true;
- PropagateSExt = true;
- }
- break;
- case MVT::v32i1:
- SExtVT = MVT::v32i8;
+ switch (SrcVT.getSimpleVT().SimpleTy) {
+ default:
+ return SDValue();
+ case MVT::v2i1:
+ SExtVT = MVT::v2i64;
+ break;
+ case MVT::v4i1:
+ SExtVT = MVT::v4i32;
+ if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+ SExtVT = MVT::v4i64;
UseLASX = true;
- break;
- };
- if (UseLASX && !Subtarget.has32S() && !Subtarget.hasExtLASX())
- return SDValue();
- Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
- : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
- } else {
- Src = Src.getOperand(0);
- }
+ PropagateSExt = true;
+ }
+ break;
+ case MVT::v8i1:
+ SExtVT = MVT::v8i16;
+ if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+ SExtVT = MVT::v8i32;
+ UseLASX = true;
+ PropagateSExt = true;
+ }
+ break;
+ case MVT::v16i1:
+ SExtVT = MVT::v16i8;
+ if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+ SExtVT = MVT::v16i16;
+ UseLASX = true;
+ PropagateSExt = true;
+ }
+ break;
+ case MVT::v32i1:
+ SExtVT = MVT::v32i8;
+ UseLASX = true;
+ break;
+ };
+ if (UseLASX && !(Subtarget.has32S() && Subtarget.hasExtLASX()))
+ return SDValue();
+ Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
+ : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src);
EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
diff --git a/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll b/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll
index 0ee30120f77a6..ad57bbf9ee5c0 100644
--- a/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll
+++ b/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll
@@ -588,3 +588,18 @@ define i2 @vmsk_trunc_i64(<2 x i64> %a) {
%res = bitcast <2 x i1> %y to i2
ret i2 %res
}
+
+define i4 @vmsk_eq_allzeros_v4i8(<4 x i8> %a) {
+; CHECK-LABEL: vmsk_eq_allzeros_v4i8:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vseqi.b $vr0, $vr0, 0
+; CHECK-NEXT: vilvl.b $vr0, $vr0, $vr0
+; CHECK-NEXT: vilvl.h $vr0, $vr0, $vr0
+; CHECK-NEXT: vslli.w $vr0, $vr0, 24
+; CHECK-NEXT: vmskltz.w $vr0, $vr0
+; CHECK-NEXT: vpickve2gr.hu $a0, $vr0, 0
+; CHECK-NEXT: ret
+ %1 = icmp eq <4 x i8> %a, zeroinitializer
+ %2 = bitcast <4 x i1> %1 to i4
+ ret i4 %2
+}
|
I agree. |
…vm#149442) Reported-by: tangyan <[email protected]> (cherry picked from commit 8a307ae)
@heiher (or anyone else). If you would like to add a note about this fix in the release notes (completely optional). Please reply to this comment with a one or two sentence description of the fix. When you are done, please add the release:note label to this PR. |
Backport 8a307ae
Requested by: @heiher