[msan] Reland with even more improvement: Improve packed multiply-add instrumentation #153353

Merged
merged 13 commits into llvm:main on Aug 15, 2025

Conversation

thurstond
Contributor

@thurstond thurstond commented Aug 13, 2025

This reverts commit cf00284, i.e., relands ba603b5. The original change was reverted because it was subtly wrong: multiplying by an uninitialized zero should not result in an initialized zero.

This reland fixes the issue by using instrumentation analogous to visitAnd (bitwise AND of an initialized zero and an uninitialized value results in an initialized zero). Additionally, this reland expands a test case, fixes the commit message, and optimizes the change to avoid the need for horizontalReduce.

The current instrumentation has false positives: it does not take into account that multiplying an initialized zero value by an uninitialized value results in an initialized zero value. This change fixes the issue during the multiplication step. The horizontal add step is modeled using bitwise OR.
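
As a rough scalar sketch of the multiplication-step rule (illustrative C++, not MSan's actual API; the Lane type and productPoisoned helper are invented for exposition):

#include <cstdint>

// One lane of the shadow rule: the product is poisoned only if poisoned bits
// can affect it. An initialized zero in either operand forces the product to
// a known zero, so the output lane stays clean.
struct Lane {
  uint16_t val;     // concrete value
  uint16_t shadow;  // nonzero => poisoned
};

bool productPoisoned(Lane a, Lane b) {
  bool SaNZ = a.shadow != 0, SbNZ = b.shadow != 0;
  bool VaNZ = a.val != 0, VbNZ = b.val != 0;
  return (SaNZ && SbNZ) || (VaNZ && SbNZ) || (SaNZ && VbNZ);
}

For example, a lane holding an initialized zero yields false regardless of the other operand, while an uninitialized zero multiplied by a nonzero constant yields true.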

Future work can apply this improved handler to the AVX512 equivalent intrinsics (x86_avx512_pmaddw_d_512, x86_avx512_pmaddubs_w_512) and the AVX VNNI intrinsics.

…3343)

This reverts commit cf00284 i.e.,
relands ba603b5. It was reverted
because it was subtly wrong: multiplying by an uninitialized zero should not
result in an initialized zero. This reland fixes the issue by using
instrumentation analogous to visitAnd (bitwise AND of an initialized
zero and an uninitialized value results in an initialized value), and
expands a test case.

Original commit message:
The current instrumentation has false positives: if there is a single uninitialized bit in any of the operands, the entire output is poisoned. This does not take into account that multiplying an uninitialized value with zero results in an initialized zero value.

This change allows elements that are zero to clear the corresponding shadow during the multiplication step. The horizontal add step and accumulation step (if any) are modeled using bitwise OR.

Future work can apply this improved handler to the AVX512 equivalent intrinsics (x86_avx512_pmaddw_d_512, x86_avx512_pmaddubs_w_512) and the AVX VNNI intrinsics.
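
Putting the two steps together, a hedged end-to-end model of the 4-lane MMX pmadd.wd shadow semantics (reusing the illustrative Lane and productPoisoned from the sketch above; a model of the rule, not the actual instrumentation):

#include <array>

// out[i] = a[2i]*b[2i] + a[2i+1]*b[2i+1]; the horizontal add ORs the shadows
// of the two products, matching the bitwise-OR model described above.
std::array<bool, 2> pmaddShadow(const std::array<Lane, 4> &a,
                                const std::array<Lane, 4> &b) {
  std::array<bool, 2> out{};
  for (int i = 0; i < 2; ++i)
    out[i] = productPoisoned(a[2 * i], b[2 * i]) ||
             productPoisoned(a[2 * i + 1], b[2 * i + 1]);
  return out;
}

Applied to the test vectors in the patch below, the lane pair {Poisoned(0), 1} x {100, 101} gives a poisoned output, while {0, 3} x {Poisoned(102), 103} stays clean.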
@llvmbot
Member

llvmbot commented Aug 13, 2025

@llvm/pr-subscribers-compiler-rt-sanitizer

Author: Thurston Dang (thurstond)

Changes

This reverts commit cf00284, i.e., relands ba603b5. The original change was reverted because it was subtly wrong: multiplying by an uninitialized zero should not result in an initialized zero.

This reland fixes the issue by using instrumentation analogous to visitAnd (bitwise AND of an initialized zero and an uninitialized value results in an initialized zero). Additionally, this reland expands a test case and fixes the commit message.

The current instrumentation has false positives: it does not take into account that multiplying an initialized zero value by an uninitialized value results in an initialized zero value. This change fixes the issue during the multiplication step. The horizontal add step is modeled using bitwise OR.

Future work can apply this improved handler to the AVX512 equivalent intrinsics (x86_avx512_pmaddw_d_512, x86_avx512_pmaddubs_w_512) and the AVX VNNI intrinsics.


Patch is 72.31 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153353.diff

11 Files Affected:

  • (modified) compiler-rt/lib/msan/tests/msan_test.cpp (+26-1)
  • (modified) llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp (+126-19)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx2-intrinsics-x86.ll (+45-15)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512bw-intrinsics-upgrade.ll (+66-50)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512bw-intrinsics.ll (+66-50)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/mmx-intrinsics.ll (+49-19)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/sse2-intrinsics-x86.ll (+15-5)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/i386/avx2-intrinsics-i386.ll (+46-16)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/i386/mmx-intrinsics.ll (+49-19)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/i386/sse2-intrinsics-i386.ll (+15-5)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/vector_arith.ll (+35-10)
diff --git a/compiler-rt/lib/msan/tests/msan_test.cpp b/compiler-rt/lib/msan/tests/msan_test.cpp
index d1c481483dfad..bcf048f04c5c2 100644
--- a/compiler-rt/lib/msan/tests/msan_test.cpp
+++ b/compiler-rt/lib/msan/tests/msan_test.cpp
@@ -4271,14 +4271,39 @@ TEST(VectorSadTest, sse2_psad_bw) {
 }
 
 TEST(VectorMaddTest, mmx_pmadd_wd) {
-  V4x16 a = {Poisoned<U2>(), 1, 2, 3};
+  V4x16 a = {Poisoned<U2>(0), 1, 2, 3};
   V4x16 b = {100, 101, 102, 103};
   V2x32 c = _mm_madd_pi16(a, b);
+  // Multiply step:
+  //    {Poison * 100, 1 * 101, 2 * 102, 3 * 103}
+  // == {Poison,       1 * 101, 2 * 102, 3 * 103}
+  //    Notice that for the poisoned value, we ignored the concrete zero value.
+  //
+  // Horizontal add step:
+  //    {Poison + 1 * 101, 2 * 102 + 3 * 103}
+  // == {Poison,           2 * 102 + 3 * 103}
 
   EXPECT_POISONED(c[0]);
   EXPECT_NOT_POISONED(c[1]);
 
   EXPECT_EQ((unsigned)(2 * 102 + 3 * 103), c[1]);
+
+  V4x16 d = {Poisoned<U2>(0), 1, 0, 3};
+  V4x16 e = {100, 101, Poisoned<U2>(102), 103};
+  V2x32 f = _mm_madd_pi16(d, e);
+  // Multiply step:
+  //    {Poison * 100, 1 * 101, 0 * Poison, 3 * 103}
+  // == {Poison,       1 * 101, 0         , 3 * 103}
+  //    Notice that 0 * Poison == 0.
+  //
+  // Horizontal add step:
+  //    {Poison + 1 * 101, 0 + 3 * 103}
+  // == {Poison,           3 * 103}
+
+  EXPECT_POISONED(f[0]);
+  EXPECT_NOT_POISONED(f[1]);
+
+  EXPECT_EQ((unsigned)(3 * 103), f[1]);
 }
 
 TEST(VectorCmpTest, mm_cmpneq_ps) {
diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index 21bd4164385ab..d3e686c26f188 100644
--- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -3641,9 +3641,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }
 
-  // Get an MMX-sized vector type.
-  Type *getMMXVectorTy(unsigned EltSizeInBits) {
-    const unsigned X86_MMXSizeInBits = 64;
+  // Get an MMX-sized (64-bit) vector type, or optionally a vector of a
+  // different total size.
+  Type *getMMXVectorTy(unsigned EltSizeInBits,
+                       unsigned X86_MMXSizeInBits = 64) {
     assert(EltSizeInBits != 0 && (X86_MMXSizeInBits % EltSizeInBits) == 0 &&
            "Illegal MMX vector element size");
     return FixedVectorType::get(IntegerType::get(*MS.C, EltSizeInBits),
@@ -3843,20 +3844,99 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }
 
-  // Instrument multiply-add intrinsic.
-  void handleVectorPmaddIntrinsic(IntrinsicInst &I,
-                                  unsigned MMXEltSizeInBits = 0) {
-    Type *ResTy =
-        MMXEltSizeInBits ? getMMXVectorTy(MMXEltSizeInBits * 2) : I.getType();
+  // Instrument multiply-add intrinsics.
+  //
+  // e.g., Two operands:
+  //         <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a, <8 x i16> %b)
+  //         <1 x i64> @llvm.x86.mmx.pmadd.wd(<1 x i64> %a, <1 x i64> %b)
+  //
+  //       Three operands are not implemented yet:
+  //         <4 x i32> @llvm.x86.avx512.vpdpbusd.128
+  //                       (<4 x i32> %s, <4 x i32> %a, <4 x i32> %b)
+  //         (the result of multiply-add'ing %a and %b is accumulated with %s)
+  void handleVectorPmaddIntrinsic(IntrinsicInst &I, unsigned ReductionFactor,
+                                  unsigned EltSizeInBits = 0) {
     IRBuilder<> IRB(&I);
-    auto *Shadow0 = getShadow(&I, 0);
-    auto *Shadow1 = getShadow(&I, 1);
-    Value *S = IRB.CreateOr(Shadow0, Shadow1);
-    S = IRB.CreateBitCast(S, ResTy);
-    S = IRB.CreateSExt(IRB.CreateICmpNE(S, Constant::getNullValue(ResTy)),
-                       ResTy);
-    S = IRB.CreateBitCast(S, getShadowTy(&I));
-    setShadow(&I, S);
+
+    [[maybe_unused]] FixedVectorType *ReturnType =
+        cast<FixedVectorType>(I.getType());
+    assert(isa<FixedVectorType>(ReturnType));
+
+    assert(I.arg_size() == 2);
+
+    // Vectors A and B, and shadows
+    Value *Va = I.getOperand(0);
+    Value *Vb = I.getOperand(1);
+
+    Value *Sa = getShadow(&I, 0);
+    Value *Sb = getShadow(&I, 1);
+
+    FixedVectorType *ParamType =
+        cast<FixedVectorType>(I.getArgOperand(0)->getType());
+    assert(ParamType == I.getArgOperand(1)->getType());
+
+    assert(ParamType->getPrimitiveSizeInBits() ==
+           ReturnType->getPrimitiveSizeInBits());
+
+    // Step 1: instrument multiplication of corresponding vector elements
+    if (EltSizeInBits) {
+      ParamType = cast<FixedVectorType>(
+          getMMXVectorTy(EltSizeInBits, ParamType->getPrimitiveSizeInBits()));
+
+      Va = IRB.CreateBitCast(Va, ParamType);
+      Vb = IRB.CreateBitCast(Vb, ParamType);
+
+      Sa = IRB.CreateBitCast(Sa, getShadowTy(ParamType));
+      Sb = IRB.CreateBitCast(Sb, getShadowTy(ParamType));
+    } else {
+      assert(ParamType->getNumElements() ==
+             ReturnType->getNumElements() * ReductionFactor);
+    }
+
+    // Multiplying an *initialized* zero by an uninitialized element results in
+    // an initialized zero element.
+    //
+    // This is analogous to bitwise AND, where "AND" of 0 and a poisoned value
+    // results in an unpoisoned value. We can therefore adapt the visitAnd()
+    // instrumentation:
+    //   OutShadow =   (SaNonZero & SbNonZero)
+    //               | (VaNonZero & SbNonZero)
+    //               | (SaNonZero & VbNonZero)
+    //   where non-zero is checked on a per-element basis.
+    Value *SZero = Constant::getNullValue(Va->getType());
+    Value *VZero = Constant::getNullValue(Sa->getType());
+    Value *SaNonZero = IRB.CreateICmpNE(Sa, SZero);
+    Value *SbNonZero = IRB.CreateICmpNE(Sb, SZero);
+    Value *VaNonZero = IRB.CreateICmpNE(Va, VZero);
+    Value *VbNonZero = IRB.CreateICmpNE(Vb, VZero);
+
+    Value *SaAndSbNonZero = IRB.CreateAnd(SaNonZero, SbNonZero);
+    Value *VaAndSbNonZero = IRB.CreateAnd(VaNonZero, SbNonZero);
+    Value *SaAndVbNonZero = IRB.CreateAnd(SaNonZero, VbNonZero);
+
+    // Each element of the vector is represented by a single bit (poisoned or
+    // not) e.g., <8 x i1>.
+    Value *ComboNonZero =
+        IRB.CreateOr({SaAndSbNonZero, VaAndSbNonZero, SaAndVbNonZero});
+
+    // Extend <8 x i1> to <8 x i16>.
+    // (The real pmadd intrinsic would have computed intermediate values of
+    // <8 x i32>, but that is irrelevant for our shadow purposes because we
+    // consider each element to be either fully initialized or fully
+    // uninitialized.)
+    ComboNonZero = IRB.CreateSExt(ComboNonZero, Sa->getType());
+
+    // Step 2: instrument horizontal add
+    // e.g., collapse <8 x i16> into <4 x i16> (reduction factor == 2)
+    //                <16 x i8> into <4 x i8>  (reduction factor == 4)
+    Value *OutShadow =
+        horizontalReduce(I, ReductionFactor, ComboNonZero, nullptr);
+
+    // Extend to <4 x i32>.
+    // For MMX, cast it back to <1 x i64>.
+    OutShadow = CreateShadowCast(IRB, OutShadow, getShadowTy(&I));
+
+    setShadow(&I, OutShadow);
     setOriginForNaryOp(I);
   }
 
@@ -5391,19 +5471,46 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
       handleVectorSadIntrinsic(I);
       break;
 
+    // Multiply and Add Packed Words
+    //   < 4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16>, <8 x i16>)
+    //   < 8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16>, <16 x i16>)
+    //   <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16>, <32 x i16>)
+    //
+    // Multiply and Add Packed Signed and Unsigned Bytes
+    //   < 8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8>, <16 x i8>)
+    //   <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8>, <32 x i8>)
+    //   <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8>, <64 x i8>)
+    //
+    // These intrinsics are auto-upgraded into non-masked forms:
+    //   < 4 x i32> @llvm.x86.avx512.mask.pmaddw.d.128
+    //                  (<8 x i16>, <8 x i16>, <4 x i32>, i8)
+    //   < 8 x i32> @llvm.x86.avx512.mask.pmaddw.d.256
+    //                  (<16 x i16>, <16 x i16>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.pmaddw.d.512
+    //                  (<32 x i16>, <32 x i16>, <16 x i32>, i16)
+    //   < 8 x i16> @llvm.x86.avx512.mask.pmaddubs.w.128
+    //                  (<16 x i8>, <16 x i8>, <8 x i16>, i8)
+    //   <16 x i16> @llvm.x86.avx512.mask.pmaddubs.w.256
+    //                  (<32 x i8>, <32 x i8>, <16 x i16>, i16)
+    //   <32 x i16> @llvm.x86.avx512.mask.pmaddubs.w.512
+    //                  (<64 x i8>, <64 x i8>, <32 x i16>, i32)
     case Intrinsic::x86_sse2_pmadd_wd:
     case Intrinsic::x86_avx2_pmadd_wd:
+    case Intrinsic::x86_avx512_pmaddw_d_512:
     case Intrinsic::x86_ssse3_pmadd_ub_sw_128:
     case Intrinsic::x86_avx2_pmadd_ub_sw:
-      handleVectorPmaddIntrinsic(I);
+    case Intrinsic::x86_avx512_pmaddubs_w_512:
+      handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2);
       break;
 
+    // <1 x i64> @llvm.x86.ssse3.pmadd.ub.sw(<1 x i64>, <1 x i64>)
     case Intrinsic::x86_ssse3_pmadd_ub_sw:
-      handleVectorPmaddIntrinsic(I, 8);
+      handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2, /*EltSize=*/8);
       break;
 
+    // <1 x i64> @llvm.x86.mmx.pmadd.wd(<1 x i64>, <1 x i64>)
     case Intrinsic::x86_mmx_pmadd_wd:
-      handleVectorPmaddIntrinsic(I, 16);
+      handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2, /*EltSize=*/16);
       break;
 
     case Intrinsic::x86_sse_cmp_ss:
diff --git a/llvm/test/Instrumentation/MemorySanitizer/X86/avx2-intrinsics-x86.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/avx2-intrinsics-x86.ll
index f916130fe53e5..26a6a3bdb5c0f 100644
--- a/llvm/test/Instrumentation/MemorySanitizer/X86/avx2-intrinsics-x86.ll
+++ b/llvm/test/Instrumentation/MemorySanitizer/X86/avx2-intrinsics-x86.ll
@@ -140,11 +140,21 @@ define <8 x i32> @test_x86_avx2_pmadd_wd(<16 x i16> %a0, <16 x i16> %a1) #0 {
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <16 x i16>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <16 x i16>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP3:%.*]] = or <16 x i16> [[TMP1]], [[TMP2]]
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <16 x i16> [[TMP3]] to <8 x i32>
-; CHECK-NEXT:    [[TMP5:%.*]] = icmp ne <8 x i32> [[TMP4]], zeroinitializer
-; CHECK-NEXT:    [[TMP6:%.*]] = sext <8 x i1> [[TMP5]] to <8 x i32>
-; CHECK-NEXT:    [[RES:%.*]] = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> [[A0:%.*]], <16 x i16> [[A1:%.*]])
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne <16 x i16> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP5:%.*]] = icmp ne <16 x i16> [[TMP2]], zeroinitializer
+; CHECK-NEXT:    [[TMP12:%.*]] = icmp ne <16 x i16> [[A0:%.*]], zeroinitializer
+; CHECK-NEXT:    [[TMP13:%.*]] = icmp ne <16 x i16> [[A1:%.*]], zeroinitializer
+; CHECK-NEXT:    [[TMP11:%.*]] = and <16 x i1> [[TMP4]], [[TMP5]]
+; CHECK-NEXT:    [[TMP14:%.*]] = and <16 x i1> [[TMP12]], [[TMP5]]
+; CHECK-NEXT:    [[TMP15:%.*]] = and <16 x i1> [[TMP4]], [[TMP13]]
+; CHECK-NEXT:    [[TMP16:%.*]] = or <16 x i1> [[TMP11]], [[TMP14]]
+; CHECK-NEXT:    [[TMP17:%.*]] = or <16 x i1> [[TMP16]], [[TMP15]]
+; CHECK-NEXT:    [[TMP7:%.*]] = sext <16 x i1> [[TMP17]] to <16 x i16>
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <16 x i16> [[TMP7]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <16 x i16> [[TMP7]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT:    [[TMP10:%.*]] = or <8 x i16> [[TMP8]], [[TMP9]]
+; CHECK-NEXT:    [[TMP6:%.*]] = zext <8 x i16> [[TMP10]] to <8 x i32>
+; CHECK-NEXT:    [[RES:%.*]] = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> [[A0]], <16 x i16> [[A1]])
 ; CHECK-NEXT:    store <8 x i32> [[TMP6]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <8 x i32> [[RES]]
 ;
@@ -677,11 +687,21 @@ define <16 x i16> @test_x86_avx2_pmadd_ub_sw(<32 x i8> %a0, <32 x i8> %a1) #0 {
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <32 x i8>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <32 x i8>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP3:%.*]] = or <32 x i8> [[TMP1]], [[TMP2]]
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <32 x i8> [[TMP3]] to <16 x i16>
-; CHECK-NEXT:    [[TMP5:%.*]] = icmp ne <16 x i16> [[TMP4]], zeroinitializer
-; CHECK-NEXT:    [[TMP6:%.*]] = sext <16 x i1> [[TMP5]] to <16 x i16>
-; CHECK-NEXT:    [[RES:%.*]] = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> [[A0:%.*]], <32 x i8> [[A1:%.*]])
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne <32 x i8> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP5:%.*]] = icmp ne <32 x i8> [[TMP2]], zeroinitializer
+; CHECK-NEXT:    [[TMP12:%.*]] = icmp ne <32 x i8> [[A0:%.*]], zeroinitializer
+; CHECK-NEXT:    [[TMP13:%.*]] = icmp ne <32 x i8> [[A1:%.*]], zeroinitializer
+; CHECK-NEXT:    [[TMP11:%.*]] = and <32 x i1> [[TMP4]], [[TMP5]]
+; CHECK-NEXT:    [[TMP14:%.*]] = and <32 x i1> [[TMP12]], [[TMP5]]
+; CHECK-NEXT:    [[TMP15:%.*]] = and <32 x i1> [[TMP4]], [[TMP13]]
+; CHECK-NEXT:    [[TMP16:%.*]] = or <32 x i1> [[TMP11]], [[TMP14]]
+; CHECK-NEXT:    [[TMP17:%.*]] = or <32 x i1> [[TMP16]], [[TMP15]]
+; CHECK-NEXT:    [[TMP7:%.*]] = sext <32 x i1> [[TMP17]] to <32 x i8>
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <32 x i8> [[TMP7]], <32 x i8> poison, <16 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14, i32 16, i32 18, i32 20, i32 22, i32 24, i32 26, i32 28, i32 30>
+; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <32 x i8> [[TMP7]], <32 x i8> poison, <16 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15, i32 17, i32 19, i32 21, i32 23, i32 25, i32 27, i32 29, i32 31>
+; CHECK-NEXT:    [[TMP10:%.*]] = or <16 x i8> [[TMP8]], [[TMP9]]
+; CHECK-NEXT:    [[TMP6:%.*]] = zext <16 x i8> [[TMP10]] to <16 x i16>
+; CHECK-NEXT:    [[RES:%.*]] = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> [[A0]], <32 x i8> [[A1]])
 ; CHECK-NEXT:    store <16 x i16> [[TMP6]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <16 x i16> [[RES]]
 ;
@@ -706,11 +726,21 @@ define <16 x i16> @test_x86_avx2_pmadd_ub_sw_load_op0(ptr %ptr, <32 x i8> %a1) #
 ; CHECK-NEXT:    [[TMP6:%.*]] = xor i64 [[TMP5]], 87960930222080
 ; CHECK-NEXT:    [[TMP7:%.*]] = inttoptr i64 [[TMP6]] to ptr
 ; CHECK-NEXT:    [[_MSLD:%.*]] = load <32 x i8>, ptr [[TMP7]], align 32
-; CHECK-NEXT:    [[TMP8:%.*]] = or <32 x i8> [[_MSLD]], [[TMP2]]
-; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <32 x i8> [[TMP8]] to <16 x i16>
-; CHECK-NEXT:    [[TMP10:%.*]] = icmp ne <16 x i16> [[TMP9]], zeroinitializer
-; CHECK-NEXT:    [[TMP11:%.*]] = sext <16 x i1> [[TMP10]] to <16 x i16>
-; CHECK-NEXT:    [[RES:%.*]] = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> [[A0]], <32 x i8> [[A1:%.*]])
+; CHECK-NEXT:    [[TMP9:%.*]] = icmp ne <32 x i8> [[_MSLD]], zeroinitializer
+; CHECK-NEXT:    [[TMP10:%.*]] = icmp ne <32 x i8> [[TMP2]], zeroinitializer
+; CHECK-NEXT:    [[TMP17:%.*]] = icmp ne <32 x i8> [[A0]], zeroinitializer
+; CHECK-NEXT:    [[TMP18:%.*]] = icmp ne <32 x i8> [[A1:%.*]], zeroinitializer
+; CHECK-NEXT:    [[TMP16:%.*]] = and <32 x i1> [[TMP9]], [[TMP10]]
+; CHECK-NEXT:    [[TMP19:%.*]] = and <32 x i1> [[TMP17]], [[TMP10]]
+; CHECK-NEXT:    [[TMP20:%.*]] = and <32 x i1> [[TMP9]], [[TMP18]]
+; CHECK-NEXT:    [[TMP21:%.*]] = or <32 x i1> [[TMP16]], [[TMP19]]
+; CHECK-NEXT:    [[TMP22:%.*]] = or <32 x i1> [[TMP21]], [[TMP20]]
+; CHECK-NEXT:    [[TMP12:%.*]] = sext <32 x i1> [[TMP22]] to <32 x i8>
+; CHECK-NEXT:    [[TMP13:%.*]] = shufflevector <32 x i8> [[TMP12]], <32 x i8> poison, <16 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14, i32 16, i32 18, i32 20, i32 22, i32 24, i32 26, i32 28, i32 30>
+; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <32 x i8> [[TMP12]], <32 x i8> poison, <16 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15, i32 17, i32 19, i32 21, i32 23, i32 25, i32 27, i32 29, i32 31>
+; CHECK-NEXT:    [[TMP15:%.*]] = or <16 x i8> [[TMP13]], [[TMP14]]
+; CHECK-NEXT:    [[TMP11:%.*]] = zext <16 x i8> [[TMP15]] to <16 x i16>
+; CHECK-NEXT:    [[RES:%.*]] = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> [[A0]], <32 x i8> [[A1]])
 ; CHECK-NEXT:    store <16 x i16> [[TMP11]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <16 x i16> [[RES]]
 ;
diff --git a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512bw-intrinsics-upgrade.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512bw-intrinsics-upgrade.ll
index 02df9c49a010b..23a4b952281ae 100644
--- a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512bw-intrinsics-upgrade.ll
+++ b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512bw-intrinsics-upgrade.ll
@@ -4930,18 +4930,22 @@ define <32 x i16> @test_int_x86_avx512_pmaddubs_w_512(<64 x i8> %x0, <64 x i8> %
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <64 x i8>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <64 x i8>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 64) to ptr), align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <64 x i8> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i512 [[TMP3]], 0
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <64 x i8> [[TMP2]] to i512
-; CHECK-NEXT:    [[_MSCMP1:%.*]] = icmp ne i512 [[TMP4]], 0
-; CHECK-NEXT:    [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
-; CHECK-NEXT:    br i1 [[_MSOR]], label [[TMP5:%.*]], label [[TMP6:%.*]], !prof [[PROF1]]
-; CHECK:       5:
-; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR7]]
-; CHECK-NEXT:    unreachable
-; CHECK:       6:
-; CHECK-NEXT:    [[TMP7:%.*]] = call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> [[X0:%.*]], <64 x i8> [[X1:%.*]])
-; CHECK-NEXT:    store <32 x i16> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp ne <64 x i8> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp ne <64 x i8> [[TMP2]], zeroinitializer
+; CHECK-NEXT:    [[TMP5:%.*]] = icmp ne <64 x i8> [[X0:%.*]], zeroinitializer
+; CHECK-NEXT:    [[TMP6:%.*]] = icmp ne <64 x i8> [[X1:%.*]], zeroinitializer
+; CHECK-NEXT:    [[TMP17:%.*]] = and <64 x i1> [[TMP3]], [[TMP4]]
+; CHECK-NEXT:    [[TMP8:%.*]] = and <64 x i1> [[TMP5]], [[TMP4]]
+; CHECK-NEXT:    [[TMP9:%.*]] = and <64 x i1> [[TMP3]], [[TMP6]]
+; CHECK-NEXT:    [[TMP10:%.*]] = or <64 x i1> [[TMP17]], [[TMP8]]
+; CHECK-NEXT:    [[TMP11:%.*]] = or <64 x i1> [[TMP10]], [[TMP9]]
+; CHECK-NEXT:    [[TMP12:%.*]] = sext <64 x i1> [[TMP11]] to <64 x i8>
+; CHECK-NEXT:    [[TMP13:%.*]] = shufflevector <64 x i8> [[TMP12]], <64 x i8> poison, <32 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14, i32 16, i32 18, i32 20, i32 22, i32 24, i32 26, i32 28, i32 30, i32 32, i32 34, i32 36, i32 38, i32 40, i32 42, i32 44, i32 46, i32 48, i32 50, i32 52, i32 54, i32 56, i32 58, i32 60, i32 62>
+; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <64 x i8> [[TMP12]], <64 x i8> poison, <32 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15, i32 17, i32 19, i32 21, i32 23, i32 25, i32 27, i32 29, i32 31, i32 33, i32 35, i32 37, i32 39, i32 41, i32 43, i32 45, i32 47, i32 49, i32 51, i32 53, i32 55, i32 57, i32 59, i32 61, i32 63>
+; CHECK-NEXT:    [[TMP15:%.*]] = or <32 x i8> [[TMP13]], [[TMP14]]
+; CHECK-NEXT:    [[TMP16:%.*]] = zext <32 x i8> [[TMP15]] to <32 x i16>
+; CHECK-NEXT:    [[TMP7:%.*]] = call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> [[X0]], <64 x i8> [[X1]])
+; CHECK-NEXT:    store <32 x i16> [[TMP16]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <32 x i16> [[TMP7]]
 ;
   %res = call <32 x i16> @llvm.x86.avx512.mask.pmaddubs.w.512(<64 x i8> %x0, <64 x i8> %x1, <32 x i16> %x2, i32 -1)
@@ -4955,22 +4959,26 @@ define <32 x i16> @test_int_x86_avx512_mask_pmaddubs_w_512(<64 x i8> %x0, <64 x
 ; CHECK-NEXT:    [[TMP3:%.*]] = load i32, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 192) to ptr), align 8
 ; CHECK-NEXT:    [[TMP4:%.*]] = load <32 x i16>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 128) to ptr), align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP5:%.*]] = bi...
[truncated]

@llvmbot
Member

llvmbot commented Aug 13, 2025

@llvm/pr-subscribers-llvm-transforms



github-actions bot commented Aug 13, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@thurstond thurstond marked this pull request as draft August 13, 2025 16:53
@thurstond thurstond marked this pull request as ready for review August 13, 2025 19:33
@thurstond thurstond merged commit 2b75ff1 into llvm:main Aug 15, 2025
9 checks passed
thurstond added a commit to thurstond/llvm-project that referenced this pull request Aug 16, 2025
This extends the pmadd handler (recently improved in llvm#153353) to three-operand intrinsics (multiply-add-accumulate), and applies it to the AVX Vector Neural Network Instructions.

Updates the tests from llvm#153135
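
Under the same illustrative model, a hypothetical three-operand handler would OR the accumulator's shadow into the reduced multiply-add shadow, matching the commit message's note that the accumulation step is modeled using bitwise OR (pmaddAccShadow is an invented name, not an LLVM API):

// Accumulate step for s + madd(a, b): addition propagates poison, so each
// accumulator lane's shadow is OR'd into the corresponding multiply-add shadow.
std::array<bool, 2> pmaddAccShadow(const std::array<Lane, 4> &a,
                                   const std::array<Lane, 4> &b,
                                   const std::array<bool, 2> &sShadow) {
  std::array<bool, 2> out = pmaddShadow(a, b);
  for (int i = 0; i < 2; ++i)
    out[i] = out[i] || sShadow[i];
  return out;
}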