Skip to content

Conversation

llvmbot
Copy link
Member

@llvmbot llvmbot commented Jul 21, 2025

Backport 8a307ae

Requested by: @heiher

@llvmbot
Copy link
Member Author

llvmbot commented Jul 21, 2025

@tangaac What do you think about merging this PR to the release branch?

@llvmbot
Copy link
Member Author

llvmbot commented Jul 21, 2025

@llvm/pr-subscribers-backend-loongarch

Author: None (llvmbot)

Changes

Backport 8a307ae

Requested by: @heiher


Full diff: https://github.com/llvm/llvm-project/pull/149778.diff

2 Files Affected:

  • (modified) llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp (+124-97)
  • (modified) llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll (+15)
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index c47987fbf683b..12cf04bbbab56 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -4563,6 +4563,80 @@ static SDValue signExtendBitcastSrcVector(SelectionDAG &DAG, EVT SExtVT,
   llvm_unreachable("Unexpected node type for vXi1 sign extension");
 }
 
+static SDValue
+performSETCC_BITCASTCombine(SDNode *N, SelectionDAG &DAG,
+                            TargetLowering::DAGCombinerInfo &DCI,
+                            const LoongArchSubtarget &Subtarget) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  SDValue Src = N->getOperand(0);
+  EVT SrcVT = Src.getValueType();
+
+  if (Src.getOpcode() != ISD::SETCC || !Src.hasOneUse())
+    return SDValue();
+
+  bool UseLASX;
+  unsigned Opc = ISD::DELETED_NODE;
+  EVT CmpVT = Src.getOperand(0).getValueType();
+  EVT EltVT = CmpVT.getVectorElementType();
+
+  if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() == 128)
+    UseLASX = false;
+  else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
+           CmpVT.getSizeInBits() == 256)
+    UseLASX = true;
+  else
+    return SDValue();
+
+  SDValue SrcN1 = Src.getOperand(1);
+  switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
+  default:
+    break;
+  case ISD::SETEQ:
+    // x == 0 => not (vmsknez.b x)
+    if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+      Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
+    break;
+  case ISD::SETGT:
+    // x > -1 => vmskgez.b x
+    if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
+      Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
+    break;
+  case ISD::SETGE:
+    // x >= 0 => vmskgez.b x
+    if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+      Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
+    break;
+  case ISD::SETLT:
+    // x < 0 => vmskltz.{b,h,w,d} x
+    if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
+        (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
+         EltVT == MVT::i64))
+      Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
+    break;
+  case ISD::SETLE:
+    // x <= -1 => vmskltz.{b,h,w,d} x
+    if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
+        (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
+         EltVT == MVT::i64))
+      Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
+    break;
+  case ISD::SETNE:
+    // x != 0 => vmsknez.b x
+    if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+      Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
+    break;
+  }
+
+  if (Opc == ISD::DELETED_NODE)
+    return SDValue();
+
+  SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src.getOperand(0));
+  EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
+  V = DAG.getZExtOrTrunc(V, DL, T);
+  return DAG.getBitcast(VT, V);
+}
+
 static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
                                      TargetLowering::DAGCombinerInfo &DCI,
                                      const LoongArchSubtarget &Subtarget) {
@@ -4577,110 +4651,63 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
   if (!SrcVT.isSimple() || SrcVT.getScalarType() != MVT::i1)
     return SDValue();
 
-  unsigned Opc = ISD::DELETED_NODE;
   // Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible
+  SDValue Res = performSETCC_BITCASTCombine(N, DAG, DCI, Subtarget);
+  if (Res)
+    return Res;
+
+  // Generate vXi1 using [X]VMSKLTZ
+  MVT SExtVT;
+  unsigned Opc;
+  bool UseLASX = false;
+  bool PropagateSExt = false;
+
   if (Src.getOpcode() == ISD::SETCC && Src.hasOneUse()) {
-    bool UseLASX;
     EVT CmpVT = Src.getOperand(0).getValueType();
-    EVT EltVT = CmpVT.getVectorElementType();
-
-    if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() <= 128)
-      UseLASX = false;
-    else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
-             CmpVT.getSizeInBits() <= 256)
-      UseLASX = true;
-    else
+    if (CmpVT.getSizeInBits() > 256)
       return SDValue();
-
-    SDValue SrcN1 = Src.getOperand(1);
-    switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
-    default:
-      break;
-    case ISD::SETEQ:
-      // x == 0 => not (vmsknez.b x)
-      if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
-        Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
-      break;
-    case ISD::SETGT:
-      // x > -1 => vmskgez.b x
-      if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
-        Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
-      break;
-    case ISD::SETGE:
-      // x >= 0 => vmskgez.b x
-      if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
-        Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
-      break;
-    case ISD::SETLT:
-      // x < 0 => vmskltz.{b,h,w,d} x
-      if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
-          (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
-           EltVT == MVT::i64))
-        Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
-      break;
-    case ISD::SETLE:
-      // x <= -1 => vmskltz.{b,h,w,d} x
-      if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
-          (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
-           EltVT == MVT::i64))
-        Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
-      break;
-    case ISD::SETNE:
-      // x != 0 => vmsknez.b x
-      if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
-        Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
-      break;
-    }
   }
 
-  // Generate vXi1 using [X]VMSKLTZ
-  if (Opc == ISD::DELETED_NODE) {
-    MVT SExtVT;
-    bool UseLASX = false;
-    bool PropagateSExt = false;
-    switch (SrcVT.getSimpleVT().SimpleTy) {
-    default:
-      return SDValue();
-    case MVT::v2i1:
-      SExtVT = MVT::v2i64;
-      break;
-    case MVT::v4i1:
-      SExtVT = MVT::v4i32;
-      if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
-        SExtVT = MVT::v4i64;
-        UseLASX = true;
-        PropagateSExt = true;
-      }
-      break;
-    case MVT::v8i1:
-      SExtVT = MVT::v8i16;
-      if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
-        SExtVT = MVT::v8i32;
-        UseLASX = true;
-        PropagateSExt = true;
-      }
-      break;
-    case MVT::v16i1:
-      SExtVT = MVT::v16i8;
-      if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
-        SExtVT = MVT::v16i16;
-        UseLASX = true;
-        PropagateSExt = true;
-      }
-      break;
-    case MVT::v32i1:
-      SExtVT = MVT::v32i8;
+  switch (SrcVT.getSimpleVT().SimpleTy) {
+  default:
+    return SDValue();
+  case MVT::v2i1:
+    SExtVT = MVT::v2i64;
+    break;
+  case MVT::v4i1:
+    SExtVT = MVT::v4i32;
+    if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+      SExtVT = MVT::v4i64;
       UseLASX = true;
-      break;
-    };
-    if (UseLASX && !Subtarget.has32S() && !Subtarget.hasExtLASX())
-      return SDValue();
-    Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
-                        : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
-    Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
-  } else {
-    Src = Src.getOperand(0);
-  }
+      PropagateSExt = true;
+    }
+    break;
+  case MVT::v8i1:
+    SExtVT = MVT::v8i16;
+    if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+      SExtVT = MVT::v8i32;
+      UseLASX = true;
+      PropagateSExt = true;
+    }
+    break;
+  case MVT::v16i1:
+    SExtVT = MVT::v16i8;
+    if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+      SExtVT = MVT::v16i16;
+      UseLASX = true;
+      PropagateSExt = true;
+    }
+    break;
+  case MVT::v32i1:
+    SExtVT = MVT::v32i8;
+    UseLASX = true;
+    break;
+  };
+  if (UseLASX && !(Subtarget.has32S() && Subtarget.hasExtLASX()))
+    return SDValue();
+  Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
+                      : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
+  Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
 
   SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src);
   EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
diff --git a/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll b/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll
index 0ee30120f77a6..ad57bbf9ee5c0 100644
--- a/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll
+++ b/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll
@@ -588,3 +588,18 @@ define i2 @vmsk_trunc_i64(<2 x i64> %a) {
   %res = bitcast <2 x i1> %y to i2
   ret i2 %res
 }
+
+define i4 @vmsk_eq_allzeros_v4i8(<4 x i8> %a) {
+; CHECK-LABEL: vmsk_eq_allzeros_v4i8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vseqi.b $vr0, $vr0, 0
+; CHECK-NEXT:    vilvl.b $vr0, $vr0, $vr0
+; CHECK-NEXT:    vilvl.h $vr0, $vr0, $vr0
+; CHECK-NEXT:    vslli.w $vr0, $vr0, 24
+; CHECK-NEXT:    vmskltz.w $vr0, $vr0
+; CHECK-NEXT:    vpickve2gr.hu $a0, $vr0, 0
+; CHECK-NEXT:    ret
+  %1 = icmp eq <4 x i8> %a, zeroinitializer
+  %2 = bitcast <4 x i1> %1 to i4
+  ret i4 %2
+}

@tangaac
Copy link
Member

tangaac commented Jul 22, 2025

@tangaac What do you think about merging this PR to the release branch?

I agree.

@tru tru moved this from Needs Triage to Needs Merge in LLVM Release Status Jul 22, 2025
@tru tru merged commit 7fd16ee into llvm:release/21.x Jul 22, 2025
@github-project-automation github-project-automation bot moved this from Needs Merge to Done in LLVM Release Status Jul 22, 2025
Copy link

@heiher (or anyone else). If you would like to add a note about this fix in the release notes (completely optional). Please reply to this comment with a one or two sentence description of the fix. When you are done, please add the release:note label to this PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

Development

Successfully merging this pull request may close these issues.

4 participants