[msan] Handle AVX Vector Neural Network Instructions (VNNI) #153927

Open
wants to merge 1 commit into base: main

Conversation

thurstond
Contributor

This extends the pmadd handler (recently improved in #153353) to three-operand intrinsics (multiply-add-accumulate), and applies it to the AVX Vector Neural Network Instructions.

Updates the tests from #153135
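
For readers unfamiliar with the VNNI family: the byte variants compute, per 32-bit lane, a dot product of four byte pairs which is then added to the corresponding accumulator lane. The sketch below is a rough, standalone C++ model of one dword lane of VPDPBUSD; it paraphrases the documented semantics (unsigned bytes from the first source multiplied by signed bytes from the second, four products summed into the accumulator) and is written here for illustration only, not taken from this patch; all names in it are hypothetical.

    // Rough reference model of one dword lane of VPDPBUSD (illustrative only,
    // not from this patch). The saturating vpdpbusds variants clamp the final
    // sum to [INT32_MIN, INT32_MAX] instead of wrapping.
    #include <cstdint>
    #include <cstdio>

    int32_t vpdpbusd_lane(int32_t acc, const uint8_t a[4], const int8_t b[4]) {
      int64_t sum = acc;
      for (int j = 0; j < 4; ++j)
        sum += static_cast<int64_t>(a[j]) * b[j];  // unsigned byte x signed byte
      return static_cast<int32_t>(sum);            // non-saturating variant wraps
    }

    int main() {
      uint8_t a[4] = {1, 2, 3, 4};
      int8_t b[4] = {10, -10, 10, -10};
      std::printf("%d\n", vpdpbusd_lane(100, a, b));  // 100 + (10-20+30-40) = 80
      return 0;
    }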

@llvmbot
Member

llvmbot commented Aug 16, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Thurston Dang (thurstond)

Changes

This extends the pmadd handler (recently improved in #153353) to three-operand intrinsics (multiply-add-accumulate), and applies it to the AVX Vector Neural Network Instructions.

Updates the tests from #153135
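
As a companion to the handler changes in the diff below, here is a minimal, standalone C++ sketch (not LLVM code; an assumption-labeled illustration) of the net shadow propagation for one dword lane of a three-operand byte intrinsic: Step 1 marks a product as poisoned unless one operand is a fully-initialized zero, Step 2 lets any poisoned product poison the whole output lane, and Step 3 ORs in the shadow of the accumulator operand.

    // Standalone model of the per-lane shadow propagation; names are
    // illustrative, and the shadow convention is "any set bit may be
    // uninitialized".
    #include <cstdint>
    #include <cstdio>

    struct LaneInputs {
      uint8_t va[4], vb[4];  // the 4 byte values of %a and %b in this dword lane
      uint8_t sa[4], sb[4];  // their shadows
      uint32_t sacc;         // shadow of the accumulator %s for this lane
    };

    uint32_t dotProductLaneShadow(const LaneInputs &L) {
      bool anyPoisonedProduct = false;
      for (int j = 0; j < 4; ++j) {
        bool saNZ = L.sa[j] != 0, sbNZ = L.sb[j] != 0;
        bool vaNZ = L.va[j] != 0, vbNZ = L.vb[j] != 0;
        // Step 1: a product is poisoned unless one operand is a
        // fully-initialized zero (such a zero yields a clean 0 product).
        anyPoisonedProduct |= (saNZ && sbNZ) || (vaNZ && sbNZ) || (saNZ && vbNZ);
      }
      // Step 2: horizontal add; any poisoned product poisons the whole lane.
      uint32_t s = anyPoisonedProduct ? 0xFFFFFFFFu : 0u;
      // Step 3 (the new part): OR in the shadow of the accumulator operand.
      return s | L.sacc;
    }

    int main() {
      // One byte of %a is poisoned, but the matching byte of %b is an
      // initialized zero and the accumulator is clean: the lane stays clean.
      LaneInputs L = {{1, 2, 3, 4}, {0, 5, 6, 7}, {0xFF, 0, 0, 0}, {0, 0, 0, 0}, 0};
      std::printf("lane shadow = 0x%08x\n", dotProductLaneShadow(L));
      return 0;
    }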


Patch is 266.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153927.diff

9 Files Affected:

  • (modified) llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp (+171-12)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx10_2_512ni-intrinsics.ll (+64-22)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx10_2ni-intrinsics.ll (+85-37)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl_vnni-intrinsics-upgrade.ll (+473-73)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl_vnni-intrinsics.ll (+473-73)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512vnni-intrinsics-upgrade.ll (+237-37)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512vnni-intrinsics.ll (+237-37)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx_vnni-intrinsics.ll (+161-33)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avxvnniint8-intrinsics.ll (+165-33)
diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index 6b394f5338687..f3a7fc1e692b8 100644
--- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -3846,7 +3846,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }
 
-  // Instrument multiply-add intrinsics.
+  // Instrument multiply-add(-accumulate)? intrinsics.
   //
   // e.g., Two operands:
   //         <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a, <8 x i16> %b)
@@ -3854,7 +3854,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
   //       Two operands which require an EltSizeInBits override:
   //         <1 x i64> @llvm.x86.mmx.pmadd.wd(<1 x i64> %a, <1 x i64> %b)
   //
-  //       Three operands are not implemented yet:
+  //       Three operands:
   //         <4 x i32> @llvm.x86.avx512.vpdpbusd.128
   //                       (<4 x i32> %s, <4 x i32> %a, <4 x i32> %b)
   //         (the result of multiply-add'ing %a and %b is accumulated with %s)
@@ -3866,22 +3866,40 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
         cast<FixedVectorType>(I.getType());
     assert(isa<FixedVectorType>(ReturnType));
 
-    assert(I.arg_size() == 2);
-
     // Vectors A and B, and shadows
-    Value *Va = I.getOperand(0);
-    Value *Vb = I.getOperand(1);
+    Value *Va = nullptr;
+    Value *Vb = nullptr;
+    Value *Sa = nullptr;
+    Value *Sb = nullptr;
 
-    Value *Sa = getShadow(&I, 0);
-    Value *Sb = getShadow(&I, 1);
+    if (I.arg_size() == 2) {
+      Va = I.getOperand(0);
+      Vb = I.getOperand(1);
+
+      Sa = getShadow(&I, 0);
+      Sb = getShadow(&I, 1);
+    } else if (I.arg_size() == 3) {
+      // Operand 0 is the accumulator. We will deal with that below.
+      Va = I.getOperand(1);
+      Vb = I.getOperand(2);
+
+      Sa = getShadow(&I, 1);
+      Sb = getShadow(&I, 2);
+    } else {
+      assert(I.arg_size() == 2 || I.arg_size() == 3);
+    }
 
-    FixedVectorType *ParamType =
-        cast<FixedVectorType>(I.getArgOperand(0)->getType());
-    assert(ParamType == I.getArgOperand(1)->getType());
+    FixedVectorType *ParamType = cast<FixedVectorType>(Va->getType());
+    assert(ParamType == Vb->getType());
 
     assert(ParamType->getPrimitiveSizeInBits() ==
            ReturnType->getPrimitiveSizeInBits());
 
+    if (I.arg_size() == 3) {
+      assert(ParamType == ReturnType);
+      assert(ParamType == I.getArgOperand(0)->getType());
+    }
+
     FixedVectorType *ImplicitReturnType = ReturnType;
     // Step 1: instrument multiplication of corresponding vector elements
     if (EltSizeInBits) {
@@ -3944,10 +3962,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
                          Constant::getNullValue(Horizontal->getType())),
         ImplicitReturnType);
 
-    // For MMX, cast it back to the required fake return type (<1 x i64>).
+    // Cast it back to the required fake return type (<1 x i64>).
     if (EltSizeInBits)
       OutShadow = CreateShadowCast(IRB, OutShadow, getShadowTy(&I));
 
+    // Step 3 (if applicable): instrument accumulator
+    if (I.arg_size() == 3)
+      OutShadow = IRB.CreateOr(OutShadow, getShadow(&I, 0));
+
     setShadow(&I, OutShadow);
     setOriginForNaryOp(I);
   }
@@ -5507,6 +5529,143 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
       handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2, /*EltSize=*/16);
       break;
 
+    // AVX Vector Neural Network Instructions: bytes
+    //
+    // Multiply and Add Packed Signed and Unsigned Bytes
+    //   < 4 x i32> @llvm.x86.avx512.vpdpbusd.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpbusd.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpbusd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // Multiply and Add Unsigned and Signed Bytes With Saturation
+    //   < 4 x i32> @llvm.x86.avx512.vpdpbusds.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpbusds.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpbusds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    //   < 4 x i32> @llvm.x86.avx2.vpdpbssd.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx2.vpdpbssd.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //
+    //   < 4 x i32> @llvm.x86.avx2.vpdpbssds.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx2.vpdpbssds.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //
+    //   <16 x i32> @llvm.x86.avx10.vpdpbssd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //   <16 x i32> @llvm.x86.avx10.vpdpbssds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // These intrinsics are auto-upgraded into non-masked forms:
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpbusd.128
+    //                  (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpbusd.128
+    //                  (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpbusd.256
+    //                  (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpbusd.256
+    //                  (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpbusd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpbusd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpbusds.128
+    //                  (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpbusds.128
+    //                  (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpbusds.256
+    //                  (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpbusds.256
+    //                  (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpbusds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpbusds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    case Intrinsic::x86_avx512_vpdpbusd_128:
+    case Intrinsic::x86_avx512_vpdpbusd_256:
+    case Intrinsic::x86_avx512_vpdpbusd_512:
+    case Intrinsic::x86_avx512_vpdpbusds_128:
+    case Intrinsic::x86_avx512_vpdpbusds_256:
+    case Intrinsic::x86_avx512_vpdpbusds_512:
+    case Intrinsic::x86_avx2_vpdpbssd_128:
+    case Intrinsic::x86_avx2_vpdpbssd_256:
+    case Intrinsic::x86_avx2_vpdpbssds_128:
+    case Intrinsic::x86_avx2_vpdpbssds_256:
+    case Intrinsic::x86_avx10_vpdpbssd_512:
+    case Intrinsic::x86_avx10_vpdpbssds_512:
+      handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/4, /*EltSize=*/8);
+      break;
+
+    // AVX Vector Neural Network Instructions: words
+    //
+    // Multiply and Add Signed Word Integers
+    //   < 4 x i32> @llvm.x86.avx512.vpdpwssd.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpwssd.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpwssd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // Multiply and Add Signed Word Integers With Saturation
+    //   < 4 x i32> @llvm.x86.avx512.vpdpwssds.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpwssds.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpwssds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // These intrinsics are auto-upgraded into non-masked forms:
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpwssd.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpwssd.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpwssd.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpwssd.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpwssd.512
+    //                 (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpwssd.512
+    //                 (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpwssds.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpwssds.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpwssds.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpwssds.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpwssds.512
+    //                 (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpwssds.512
+    //                 (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    case Intrinsic::x86_avx512_vpdpwssd_128:
+    case Intrinsic::x86_avx512_vpdpwssd_256:
+    case Intrinsic::x86_avx512_vpdpwssd_512:
+    case Intrinsic::x86_avx512_vpdpwssds_128:
+    case Intrinsic::x86_avx512_vpdpwssds_256:
+    case Intrinsic::x86_avx512_vpdpwssds_512:
+      handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2, /*EltSize=*/16);
+      break;
+
+      // TODO: Dot Product of BF16 Pairs Accumulated Into Packed Single
+      // Precision
+      //   <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128
+      //                   (<4 x float>, <8 x bfloat>, <8 x bfloat>)
+      //   <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256
+      //                   (<8 x float>, <16 x bfloat>, <16 x bfloat>)
+      //   <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512
+      //                   (<16 x float>, <32 x bfloat>, <32 x bfloat>)
+      // handleVectorPmaddIntrinsic() currently only handles integer types.
+
     case Intrinsic::x86_sse_cmp_ss:
     case Intrinsic::x86_sse2_cmp_sd:
     case Intrinsic::x86_sse_comieq_ss:
diff --git a/llvm/test/Instrumentation/MemorySanitizer/X86/avx10_2_512ni-intrinsics.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/avx10_2_512ni-intrinsics.ll
index 7af8f34d403a0..298dc4b2c853a 100644
--- a/llvm/test/Instrumentation/MemorySanitizer/X86/avx10_2_512ni-intrinsics.ll
+++ b/llvm/test/Instrumentation/MemorySanitizer/X86/avx10_2_512ni-intrinsics.ll
@@ -7,19 +7,7 @@
 ; - llvm.x86.avx10.vdpphps.512
 ; - llvm.x86.avx10.vmpsadbw.512
 ;
-; Handled heuristically:
-; - llvm.x86.avx10.vpdpbssd.512
-; - llvm.x86.avx10.vpdpbssds.512
-; - llvm.x86.avx10.vpdpbsud.512
-; - llvm.x86.avx10.vpdpbsuds.512
-; - llvm.x86.avx10.vpdpbuud.512
-; - llvm.x86.avx10.vpdpbuuds.512
-; - llvm.x86.avx10.vpdpwsud.512
-; - llvm.x86.avx10.vpdpwsuds.512
-; - llvm.x86.avx10.vpdpwusd.512
-; - llvm.x86.avx10.vpdpwusds.512
-; - llvm.x86.avx10.vpdpwuud.512
-; - llvm.x86.avx10.vpdpwuuds.512
+; Handled heuristically: (none)
 
 target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-linux-gnu"
@@ -140,8 +128,8 @@ define <16 x i32> @test_mm512_dpbssd_epi32(<16 x i32> %__W, <16 x i32> %__A, ptr
 ; CHECK-LABEL: define <16 x i32> @test_mm512_dpbssd_epi32(
 ; CHECK-SAME: <16 x i32> [[__W:%.*]], <16 x i32> [[__A:%.*]], ptr [[PB:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:    [[TMP1:%.*]] = load i64, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 128) to ptr), align 8
-; CHECK-NEXT:    [[TMP2:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    [[TMP3:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 64) to ptr), align 8
+; CHECK-NEXT:    [[TMP4:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
 ; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i64 [[TMP1]], 0
 ; CHECK-NEXT:    br i1 [[_MSCMP]], label %[[BB4:.*]], label %[[BB5:.*]], !prof [[PROF1]]
@@ -154,8 +142,26 @@ define <16 x i32> @test_mm512_dpbssd_epi32(<16 x i32> %__W, <16 x i32> %__A, ptr
 ; CHECK-NEXT:    [[TMP7:%.*]] = xor i64 [[TMP6]], 87960930222080
 ; CHECK-NEXT:    [[TMP8:%.*]] = inttoptr i64 [[TMP7]] to ptr
 ; CHECK-NEXT:    [[_MSLD:%.*]] = load <16 x i32>, ptr [[TMP8]], align 64
-; CHECK-NEXT:    [[_MSPROP:%.*]] = or <16 x i32> [[TMP2]], [[TMP3]]
-; CHECK-NEXT:    [[_MSPROP1:%.*]] = or <16 x i32> [[_MSPROP]], [[_MSLD]]
+; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <16 x i32> [[__A]] to <64 x i8>
+; CHECK-NEXT:    [[TMP10:%.*]] = bitcast <16 x i32> [[__B]] to <64 x i8>
+; CHECK-NEXT:    [[TMP11:%.*]] = bitcast <16 x i32> [[TMP3]] to <64 x i8>
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <16 x i32> [[_MSLD]] to <64 x i8>
+; CHECK-NEXT:    [[TMP13:%.*]] = icmp ne <64 x i8> [[TMP11]], zeroinitializer
+; CHECK-NEXT:    [[TMP14:%.*]] = icmp ne <64 x i8> [[TMP12]], zeroinitializer
+; CHECK-NEXT:    [[TMP15:%.*]] = icmp ne <64 x i8> [[TMP9]], zeroinitializer
+; CHECK-NEXT:    [[TMP16:%.*]] = icmp ne <64 x i8> [[TMP10]], zeroinitializer
+; CHECK-NEXT:    [[TMP17:%.*]] = and <64 x i1> [[TMP13]], [[TMP14]]
+; CHECK-NEXT:    [[TMP18:%.*]] = and <64 x i1> [[TMP15]], [[TMP14]]
+; CHECK-NEXT:    [[TMP19:%.*]] = and <64 x i1> [[TMP13]], [[TMP16]]
+; CHECK-NEXT:    [[TMP20:%.*]] = or <64 x i1> [[TMP17]], [[TMP18]]
+; CHECK-NEXT:    [[TMP21:%.*]] = or <64 x i1> [[TMP20]], [[TMP19]]
+; CHECK-NEXT:    [[TMP22:%.*]] = sext <64 x i1> [[TMP21]] to <64 x i8>
+; CHECK-NEXT:    [[TMP23:%.*]] = bitcast <64 x i8> [[TMP22]] to <32 x i16>
+; CHECK-NEXT:    [[TMP24:%.*]] = icmp ne <32 x i16> [[TMP23]], zeroinitializer
+; CHECK-NEXT:    [[TMP25:%.*]] = sext <32 x i1> [[TMP24]] to <32 x i16>
+; CHECK-NEXT:    [[TMP26:%.*]] = bitcast <32 x i16> [[TMP25]] to i512
+; CHECK-NEXT:    [[TMP27:%.*]] = bitcast i512 [[TMP26]] to <16 x i32>
+; CHECK-NEXT:    [[_MSPROP1:%.*]] = or <16 x i32> [[TMP27]], [[TMP4]]
 ; CHECK-NEXT:    [[RES:%.*]] = tail call <16 x i32> @llvm.x86.avx10.vpdpbssd.512(<16 x i32> [[__W]], <16 x i32> [[__A]], <16 x i32> [[__B]])
 ; CHECK-NEXT:    store <16 x i32> [[_MSPROP1]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <16 x i32> [[RES]]
@@ -168,13 +174,31 @@ define <16 x i32> @test_mm512_dpbssd_epi32(<16 x i32> %__W, <16 x i32> %__A, ptr
 define <16 x i32> @test_mm512_mask_dpbssds_epi32(<16 x i32> %__W, i16 zeroext %__U, <16 x i32> %__A, <16 x i32> %__B) sanitize_memory {
 ; CHECK-LABEL: define <16 x i32> @test_mm512_mask_dpbssds_epi32(
 ; CHECK-SAME: <16 x i32> [[__W:%.*]], i16 zeroext [[__U:%.*]], <16 x i32> [[__A:%.*]], <16 x i32> [[__B:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:    [[TMP1:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 72) to ptr), align 8
 ; CHECK-NEXT:    [[TMP3:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 136) to ptr), align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    [[TMP4:%.*]] = load i16, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 64) to ptr), align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[_MSPROP:%.*]] = or <16 x i32> [[TMP1]], [[TMP2]]
-; CHECK-NEXT:    [[_MSPROP1:%.*]] = or <16 x i32> [[_MSPROP]], [[TMP3]]
+; CHECK-NEXT:    [[TMP24:%.*]] = bitcast <16 x i32> [[__A]] to <64 x i8>
+; CHECK-NEXT:    [[TMP25:%.*]] = bitcast <16 x i32> [[__B]] to <64 x i8>
+; CHECK-NEXT:    [[TMP26:%.*]] = bitcast <16 x i32> [[TMP2]] to <64 x i8>
+; CHECK-NEXT:    [[TMP27:%.*]] = bitcast <16 x i32> [[TMP3]] to <64 x i8>
+; CHECK-NEXT:    [[TMP28:%.*]] = icmp ne <64 x i8> [[TMP26]], zeroinitializer
+; CHECK-NEXT:    [[TMP10:%.*]] = icmp ne <64 x i8> [[TMP27]], zeroinitializer
+; CHECK-NEXT:    [[TMP11:%.*]] = icmp ne <64 x i8> [[TMP24]], zeroinitializer
+; CHECK-NEXT:    [[TMP12:%.*]] = icmp ne <64 x i8> [[TMP25]], zeroinitializer
+; CHECK-NEXT:    [[TMP13:%.*]] = and <64 x i1> [[TMP28]], [[TMP10]]
+; CHECK-NEXT:    [[TMP14:%.*]] = and <64 x i1> [[TMP11]], [[TMP10]]
+; CHECK-NEXT:    [[TMP15:%.*]] = and <64 x i1> [[TMP28]], [[TMP12]]
+; CHECK-NEXT:    [[TMP16:%.*]] = or <64 x i1> [[TMP13]], [[TMP14]]
+; CHECK-NEXT:    [[TMP17:%.*]] = or <64 x i1> [[TMP16]], [[TMP15]]
+; CHECK-NEXT:    [[TMP18:%.*]] = sext <64 x i1> [[TMP17]] to <64 x i8>
+; CHECK-NEXT:    [[TMP19:%.*]] = bitcast <64 x i8> [[TMP18]] to <32 x i16>
+; CHECK-NEXT:    [[TMP20:%.*]] = icmp ne <32 x i16> [[TMP19]], zeroinitializer
+; CHECK-NEXT:    [[TMP21:%.*]] = sext <32 x i1> [[TMP20]] to <32 x i16>
+; CHECK-NEXT:    [[TMP22:%.*]] = bitcast <32 x i16> [[TMP21]] to i512
+; CHECK-NEXT:    [[TMP23:%.*]] = bitcast i512 [[TMP22]] to <16 x i32>
+; CHECK-NEXT:    [[_MSPROP1:%.*]] = or <16 x i32> [[TMP23]], [[TMP1]]
 ; CHECK-NEXT:    [[DPI:%.*]] = tail call <16 x i32> @llvm.x86.avx10.vpdpbssds.512(<16 x i32> [[__W]], <16 x i32> [[__A]], <16 x i32> [[__B]])
 ; CHECK-NEXT:    [[TMP5:%.*]] = bitcast i16 [[TMP4]] to <16 x i1>
 ; CHECK-NEXT:    [[BST:%.*]] = bitcast i16 [[__U]] to <16 x i1>
@@ -196,13 +220,31 @@ define <16 x i32> @test_mm512_mask_dpbssds_epi32(<16 x i32> %__W, i16 zeroext %_
 define <16 x i32> @test_mm512_maskz_dpbssd_epi32(i16 zeroext %__U, <16 x i32> %__W, <16 x i32> %__A, <16 x i32> %__B) sanitize_memory {
 ; CHECK-LABEL: define <16 x i32> @test_mm512_maskz_dpbssd_epi32(
 ; CHECK-SAME: i16 zeroext [[__U:%.*]], <16 x i32> [[__W:%.*]], <16 x i32> [[__A:%.*]], <16 x i32> [[__B:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:    [[TMP1:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 8) to ptr), align 8
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 72) to ptr), align 8
 ; CHECK-NEXT:    [[TMP3:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 136) to ptr), align 8
+; CHECK-NEXT:    [[TMP24:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 8) to ptr), align 8
 ; CHECK-NEXT:    [[TMP4:%.*]] = load i16, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[_MSPROP:%.*]] = or <16 x i32> [[TMP1]], [[TMP2]]
-; CHECK-NEXT:    [[_MSPROP1:%.*]] = or <16 x i32> [[_MSPROP]], [[TMP3]]
+; CHECK-NEXT:    [[TMP25:%.*]] = bitcast <16 x i32> [[__A]] to <64 x i8>
+; CHECK-NEXT:    [[TMP26:%.*]] = bitcast <16 x i32> [[__B]] to <64 x i8>
+; CHECK-NEXT:    [[TMP27:%.*]] = bitcast <16 x i32> [[TMP2]] to <64 x i8>
+; CHECK-NEXT:    [[TMP28:%.*]] = bitcast <16 x i32> [[TMP3]] to <64 x i8>
+; CHECK-NEXT:    [[TMP29:%.*]] = icmp ne <64 x i8> [[TMP27]], zeroinitializer
+; CHECK-NEXT:    [[TMP10:%.*]] = icmp ne <64 x i8> [[TMP28]], zeroinitializer
+; CHECK-NEXT:    [[TMP11:%.*]] = icmp ne <64 x i8> [[TMP25]], zeroinitializer
+; CHECK-NEXT:    [[TMP12:%.*]] = icmp ne <64 x i8> [[TMP26]], zeroinitializer
+; CHECK-NEXT:    [[TMP13:%.*]] = and <64 x i1> [[TMP29]], [[TMP10]]
+; CHECK-NEXT:    [[TMP14:%.*]] = and <64 x i1> [[TMP11]], [[TMP10]]
+; CHECK-NEXT:    [[TMP15:%.*]] = and <64 x i1> [[TMP29]], [[TMP12]]
+; CHECK-NEXT:    [[TMP16:%.*]] = or <64 x i1> [[TMP13]], [[TMP14]]
+; CHECK-NEXT:    [[TMP17:%.*]] = or <64 x i1> [[TMP16]], [[TMP15]]
+; CHECK-NEXT:    [[TMP18:%.*]] = sext <64 x i1> [[TMP17]] to <64 x i8>
+; CHECK-NEXT:    [[TMP19:%.*]] = bitcast <64 x i8> [[TMP18]] to <32 x i16>
+; CHECK-NEXT:    [[TMP20:%.*]] = icmp ne <32 x i16> [[TMP19]], zeroinitializer
+; CHECK-NEXT:    [[TMP21:%.*]] = sext <32 x i1> [[TMP20]] to <32 x i16>
+; CHECK-NEXT:    [[TMP22:%.*]] = bitcast <32 x i16> [[TMP21]] to i512
...
[truncated]

@llvmbot
Member

llvmbot commented Aug 16, 2025

@llvm/pr-subscribers-compiler-rt-sanitizer

Author: Thurston Dang (thurstond)

