Skip to content

Conversation

heiher
Copy link
Member

@heiher heiher commented Jul 18, 2025

Reported-by: tangyan [email protected]

@llvmbot
Copy link
Member

llvmbot commented Jul 18, 2025

@llvm/pr-subscribers-backend-loongarch

Author: hev (heiher)

Changes

Reported-by: tangyan <[email protected]>


Full diff: https://github.com/llvm/llvm-project/pull/149442.diff

2 Files Affected:

  • (modified) llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp (+120-100)
  • (modified) llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll (+15)
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index 2378664ca8155..c870271213dc6 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -4560,6 +4560,80 @@ static SDValue signExtendBitcastSrcVector(SelectionDAG &DAG, EVT SExtVT,
   llvm_unreachable("Unexpected node type for vXi1 sign extension");
 }
 
+static SDValue
+performSETCC_BITCASTCombine(SDNode *N, SelectionDAG &DAG,
+                            TargetLowering::DAGCombinerInfo &DCI,
+                            const LoongArchSubtarget &Subtarget) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  SDValue Src = N->getOperand(0);
+  EVT SrcVT = Src.getValueType();
+
+  if (Src.getOpcode() != ISD::SETCC || !Src.hasOneUse())
+    return SDValue();
+
+  bool UseLASX;
+  unsigned Opc = ISD::DELETED_NODE;
+  EVT CmpVT = Src.getOperand(0).getValueType();
+  EVT EltVT = CmpVT.getVectorElementType();
+
+  if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() == 128)
+    UseLASX = false;
+  else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
+           CmpVT.getSizeInBits() == 256)
+    UseLASX = true;
+  else
+    return SDValue();
+
+  SDValue SrcN1 = Src.getOperand(1);
+  switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
+  default:
+    break;
+  case ISD::SETEQ:
+    // x == 0 => not (vmsknez.b x)
+    if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+      Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
+    break;
+  case ISD::SETGT:
+    // x > -1 => vmskgez.b x
+    if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
+      Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
+    break;
+  case ISD::SETGE:
+    // x >= 0 => vmskgez.b x
+    if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+      Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
+    break;
+  case ISD::SETLT:
+    // x < 0 => vmskltz.{b,h,w,d} x
+    if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
+        (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
+         EltVT == MVT::i64))
+      Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
+    break;
+  case ISD::SETLE:
+    // x <= -1 => vmskltz.{b,h,w,d} x
+    if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
+        (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
+         EltVT == MVT::i64))
+      Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
+    break;
+  case ISD::SETNE:
+    // x != 0 => vmsknez.b x
+    if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+      Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
+    break;
+  }
+
+  if (Opc == ISD::DELETED_NODE)
+    return SDValue();
+
+  SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src.getOperand(0));
+  EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
+  V = DAG.getZExtOrTrunc(V, DL, T);
+  return DAG.getBitcast(VT, V);
+}
+
 static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
                                      TargetLowering::DAGCombinerInfo &DCI,
                                      const LoongArchSubtarget &Subtarget) {
@@ -4574,110 +4648,56 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
   if (!SrcVT.isSimple() || SrcVT.getScalarType() != MVT::i1)
     return SDValue();
 
-  unsigned Opc = ISD::DELETED_NODE;
   // Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible
-  if (Src.getOpcode() == ISD::SETCC && Src.hasOneUse()) {
-    bool UseLASX;
-    EVT CmpVT = Src.getOperand(0).getValueType();
-    EVT EltVT = CmpVT.getVectorElementType();
-
-    if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() <= 128)
-      UseLASX = false;
-    else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
-             CmpVT.getSizeInBits() <= 256)
-      UseLASX = true;
-    else
-      return SDValue();
-
-    SDValue SrcN1 = Src.getOperand(1);
-    switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
-    default:
-      break;
-    case ISD::SETEQ:
-      // x == 0 => not (vmsknez.b x)
-      if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
-        Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
-      break;
-    case ISD::SETGT:
-      // x > -1 => vmskgez.b x
-      if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
-        Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
-      break;
-    case ISD::SETGE:
-      // x >= 0 => vmskgez.b x
-      if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
-        Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
-      break;
-    case ISD::SETLT:
-      // x < 0 => vmskltz.{b,h,w,d} x
-      if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
-          (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
-           EltVT == MVT::i64))
-        Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
-      break;
-    case ISD::SETLE:
-      // x <= -1 => vmskltz.{b,h,w,d} x
-      if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
-          (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
-           EltVT == MVT::i64))
-        Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
-      break;
-    case ISD::SETNE:
-      // x != 0 => vmsknez.b x
-      if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
-        Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
-      break;
-    }
-  }
+  SDValue Res = performSETCC_BITCASTCombine(N, DAG, DCI, Subtarget);
+  if (Res)
+    return Res;
 
   // Generate vXi1 using [X]VMSKLTZ
-  if (Opc == ISD::DELETED_NODE) {
-    MVT SExtVT;
-    bool UseLASX = false;
-    bool PropagateSExt = false;
-    switch (SrcVT.getSimpleVT().SimpleTy) {
-    default:
-      return SDValue();
-    case MVT::v2i1:
-      SExtVT = MVT::v2i64;
-      break;
-    case MVT::v4i1:
-      SExtVT = MVT::v4i32;
-      if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
-        SExtVT = MVT::v4i64;
-        UseLASX = true;
-        PropagateSExt = true;
-      }
-      break;
-    case MVT::v8i1:
-      SExtVT = MVT::v8i16;
-      if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
-        SExtVT = MVT::v8i32;
-        UseLASX = true;
-        PropagateSExt = true;
-      }
-      break;
-    case MVT::v16i1:
-      SExtVT = MVT::v16i8;
-      if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
-        SExtVT = MVT::v16i16;
-        UseLASX = true;
-        PropagateSExt = true;
-      }
-      break;
-    case MVT::v32i1:
-      SExtVT = MVT::v32i8;
+  MVT SExtVT;
+  unsigned Opc;
+  bool UseLASX = false;
+  bool PropagateSExt = false;
+  switch (SrcVT.getSimpleVT().SimpleTy) {
+  default:
+    return SDValue();
+  case MVT::v2i1:
+    SExtVT = MVT::v2i64;
+    break;
+  case MVT::v4i1:
+    SExtVT = MVT::v4i32;
+    if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+      SExtVT = MVT::v4i64;
       UseLASX = true;
-      break;
-    };
-    if (UseLASX && !Subtarget.has32S() && !Subtarget.hasExtLASX())
-      return SDValue();
-    Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
-                        : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
-    Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
-  } else {
-    Src = Src.getOperand(0);
-  }
+      PropagateSExt = true;
+    }
+    break;
+  case MVT::v8i1:
+    SExtVT = MVT::v8i16;
+    if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+      SExtVT = MVT::v8i32;
+      UseLASX = true;
+      PropagateSExt = true;
+    }
+    break;
+  case MVT::v16i1:
+    SExtVT = MVT::v16i8;
+    if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+      SExtVT = MVT::v16i16;
+      UseLASX = true;
+      PropagateSExt = true;
+    }
+    break;
+  case MVT::v32i1:
+    SExtVT = MVT::v32i8;
+    UseLASX = true;
+    break;
+  };
+  if (UseLASX && !Subtarget.has32S() && !Subtarget.hasExtLASX())
+    return SDValue();
+  Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
+                      : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
+  Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
 
   SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src);
   EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
diff --git a/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll b/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll
index 0ee30120f77a6..ad57bbf9ee5c0 100644
--- a/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll
+++ b/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll
@@ -588,3 +588,18 @@ define i2 @vmsk_trunc_i64(<2 x i64> %a) {
   %res = bitcast <2 x i1> %y to i2
   ret i2 %res
 }
+
+define i4 @vmsk_eq_allzeros_v4i8(<4 x i8> %a) {
+; CHECK-LABEL: vmsk_eq_allzeros_v4i8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vseqi.b $vr0, $vr0, 0
+; CHECK-NEXT:    vilvl.b $vr0, $vr0, $vr0
+; CHECK-NEXT:    vilvl.h $vr0, $vr0, $vr0
+; CHECK-NEXT:    vslli.w $vr0, $vr0, 24
+; CHECK-NEXT:    vmskltz.w $vr0, $vr0
+; CHECK-NEXT:    vpickve2gr.hu $a0, $vr0, 0
+; CHECK-NEXT:    ret
+  %1 = icmp eq <4 x i8> %a, zeroinitializer
+  %2 = bitcast <4 x i1> %1 to i4
+  ret i4 %2
+}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE: performSETCC_BITCASTCombine was split out from performBITCASTCombine. Additionally, the condition was changed from CmpVT.getSizeInBits() <= 128 to CmpVT.getSizeInBits() == 128. The same applies to the 256-bit case.

@heiher heiher marked this pull request as draft July 18, 2025 06:59
@heiher heiher marked this pull request as ready for review July 18, 2025 11:17
Copy link
Member

@tangaac tangaac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@heiher heiher merged commit 8a307ae into llvm:main Jul 21, 2025
9 checks passed
@heiher heiher deleted the fix-vmsk branch July 21, 2025 08:36
@heiher heiher added this to the LLVM 21.x Release milestone Jul 21, 2025
@github-project-automation github-project-automation bot moved this to Needs Triage in LLVM Release Status Jul 21, 2025
@heiher
Copy link
Member Author

heiher commented Jul 21, 2025

/cherry-pick 8a307ae

@llvmbot
Copy link
Member

llvmbot commented Jul 21, 2025

/pull-request #149778

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

Development

Successfully merging this pull request may close these issues.

3 participants