diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td index 126fa0e352656..4f8301f9380b8 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -408,34 +408,41 @@ def DotOp : AVX_LowOp<"dot", [Pure, }]; } - //----------------------------------------------------------------------------// -// AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32 +// AVX: Convert BF16/F16 to F32 and broadcast into packed F32 //----------------------------------------------------------------------------// -def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>, +def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>, DeclareOpInterfaceMethods]> { - let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data."; + let summary = "AVX: Broadcasts BF16/F16 into packed F32 Data."; let description = [{ #### From the Intel Intrinsics Guide: - Convert packed BF16 (16-bit) floating-point even-indexed elements stored at - memory locations starting at location `__A` to packed single-precision - (32-bit) floating-point elements, and store the results in `dst`. + Convert scalar BF16 or F16 (16-bit) floating-point element stored at memory locations + starting at location `__A` to a single-precision (32-bit) floating-point, + broadcast it to packed single-precision (32-bit) floating-point elements, + and store the results in `dst`. Example: ```mlir - %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32> + %dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32> + %dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32> ``` }]; - let arguments = (ins AnyMemRef:$a); + let arguments = (ins MemRefOf<[BF16, F16]>:$a); let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst); let assemblyFormat = "$a attr-dict`:` type($a)`->` type($dst)"; let extraClassDefinition = [{ std::string $cppClass::getIntrinsicName() { - std::string intr = "llvm.x86.vcvtneebf162ps"; + auto elementType = + getA().getType().getElementType(); + std::string intr = "llvm.x86."; + if (elementType.isBF16()) + intr += "vbcstnebf162ps"; + if (elementType.isF16()) + intr += "vbcstnesh2ps"; VectorType vecType = getDst().getType(); unsigned elemBitWidth = vecType.getElementTypeBitWidth(); unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth; @@ -447,31 +454,43 @@ def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f3 let extraClassDeclaration = [{ SmallVector getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&); }]; + } -def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>, +//------------------------------------------------------------------------------// +// AVX: Convert packed BF16/F16 even-indexed/odd-indexed elements into packed F32 +//------------------------------------------------------------------------------// + +def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [MemoryEffects<[MemRead]>, DeclareOpInterfaceMethods]> { - let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data."; + let summary = "AVX: Convert packed BF16/F16 even-indexed elements into packed F32 Data."; let description = [{ #### From the Intel Intrinsics Guide: - Convert packed BF16 (16-bit) floating-point odd-indexed elements stored at + Convert packed BF16 or F16 (16-bit) floating-point even-indexed elements stored at memory locations starting at location `__A` to packed single-precision (32-bit) floating-point elements, and store the results in `dst`. Example: ```mlir - %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32> + %dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32> + %dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32> ``` }]; - let arguments = (ins AnyMemRef:$a); + let arguments = (ins MemRefOf<[BF16, F16]>:$a); let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst); let assemblyFormat = "$a attr-dict`:` type($a)`->` type($dst)"; let extraClassDefinition = [{ std::string $cppClass::getIntrinsicName() { - std::string intr = "llvm.x86.vcvtneobf162ps"; + auto elementType = + getA().getType().getElementType(); + std::string intr = "llvm.x86."; + if (elementType.isBF16()) + intr += "vcvtneebf162ps"; + if (elementType.isF16()) + intr += "vcvtneeph2ps"; VectorType vecType = getDst().getType(); unsigned elemBitWidth = vecType.getElementTypeBitWidth(); unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth; @@ -485,34 +504,36 @@ def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32" }]; } -//----------------------------------------------------------------------------// -// AVX: Convert BF16 to F32 and broadcast into packed F32 -//----------------------------------------------------------------------------// - -def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[MemRead]>, +def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [MemoryEffects<[MemRead]>, DeclareOpInterfaceMethods]> { - let summary = "AVX: Broadcasts BF16 into packed F32 Data."; + let summary = "AVX: Convert packed BF16/F16 odd-indexed elements into packed F32 Data."; let description = [{ #### From the Intel Intrinsics Guide: - Convert scalar BF16 (16-bit) floating-point element stored at memory locations - starting at location `__A` to a single-precision (32-bit) floating-point, - broadcast it to packed single-precision (32-bit) floating-point elements, - and store the results in `dst`. + Convert packed BF16 or F16 (16-bit) floating-point odd-indexed elements stored at + memory locations starting at location `__A` to packed single-precision + (32-bit) floating-point elements, and store the results in `dst`. Example: ```mlir - %dst = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32> + %dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32> + %dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32> ``` }]; - let arguments = (ins AnyMemRef:$a); + let arguments = (ins MemRefOf<[BF16, F16]>:$a); let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst); let assemblyFormat = "$a attr-dict`:` type($a)`->` type($dst)"; let extraClassDefinition = [{ std::string $cppClass::getIntrinsicName() { - std::string intr = "llvm.x86.vbcstnebf162ps"; + auto elementType = + getA().getType().getElementType(); + std::string intr = "llvm.x86."; + if (elementType.isBF16()) + intr += "vcvtneobf162ps"; + if (elementType.isF16()) + intr += "vcvtneoph2ps"; VectorType vecType = getDst().getType(); unsigned elemBitWidth = vecType.getElementTypeBitWidth(); unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth; @@ -521,10 +542,8 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[Me } }]; - let extraClassDeclaration = [{ + let extraClassDeclaration = [{ SmallVector getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&); }]; - } - #endif // X86VECTOR_OPS diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp index f5e5070c74f8f..8d383b1f8103b 100644 --- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp +++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp @@ -95,19 +95,17 @@ x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter, return operands; } -SmallVector x86vector::BcstBF16ToPackedF32Op::getIntrinsicOperands( +SmallVector x86vector::BcstToPackedF32Op::getIntrinsicOperands( RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) { return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter); } -SmallVector -x86vector::CvtPackedOddIndexedBF16ToF32Op::getIntrinsicOperands( +SmallVector x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands( RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) { return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter); } -SmallVector -x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands( +SmallVector x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands( RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) { return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter); } diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp index d2297554a1012..9ee44a63ba2e4 100644 --- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp @@ -114,8 +114,8 @@ void mlir::populateX86VectorLegalizeForLLVMExportPatterns( void mlir::configureX86VectorLegalizeForExportTarget( LLVMConversionTarget &target) { - target.addIllegalOp< - MaskCompressOp, MaskRndScaleOp, MaskScaleFOp, Vp2IntersectOp, DotBF16Op, - CvtPackedF32ToBF16Op, CvtPackedEvenIndexedBF16ToF32Op, - CvtPackedOddIndexedBF16ToF32Op, BcstBF16ToPackedF32Op, RsqrtOp, DotOp>(); + target.addIllegalOp(); } diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir index 93b304c44de8e..63f06624ef897 100644 --- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir @@ -100,7 +100,7 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128( %a: memref<8xbf16>) -> vector<4xf32> { // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128" - %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32> + %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32> return %0 : vector<4xf32> } @@ -109,7 +109,7 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256( %a: memref<16xbf16>) -> vector<8xf32> { // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256" - %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32> + %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32> return %0 : vector<8xf32> } @@ -118,7 +118,7 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128( %a: memref<8xbf16>) -> vector<4xf32> { // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128" - %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32> + %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32> return %0 : vector<4xf32> } @@ -127,7 +127,7 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256( %a: memref<16xbf16>) -> vector<8xf32> { // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256" - %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32> + %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32> return %0 : vector<8xf32> } @@ -136,7 +136,7 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_128( %a: memref<1xbf16>) -> vector<4xf32> { // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128" - %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32> + %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32> return %0 : vector<4xf32> } @@ -145,7 +145,61 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_256( %a: memref<1xbf16>) -> vector<8xf32> { // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256" - %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32> + %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_128 +func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128( + %a: memref<8xf16>) -> vector<4xf32> +{ + // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps128" + %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_256 +func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256( + %a: memref<16xf16>) -> vector<8xf32> +{ + // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps256" + %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128 +func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128( + %a: memref<8xf16>) -> vector<4xf32> +{ + // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps128" + %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256 +func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256( + %a: memref<16xf16>) -> vector<8xf32> +{ + // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps256" + %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @avxf16_bsct_f16_to_f32_packed_128 +func.func @avxf16_bsct_f16_to_f32_packed_128( + %a: memref<1xf16>) -> vector<4xf32> +{ + // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps128" + %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @avxf16_bsct_f16_to_f32_packed_256 +func.func @avxf16_bsct_f16_to_f32_packed_256( + %a: memref<1xf16>) -> vector<8xf32> +{ + // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps256" + %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32> return %0 : vector<8xf32> } diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir index b783cc869b981..7dcab3eb4dcb8 100644 --- a/mlir/test/Dialect/X86Vector/roundtrip.mlir +++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir @@ -98,9 +98,9 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512( func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128( %a: memref<8xbf16>) -> vector<4xf32> { - // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} : + // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : // CHECK-SAME: memref<8xbf16> -> vector<4xf32> - %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32> + %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32> return %0 : vector<4xf32> } @@ -108,9 +108,9 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128( func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256( %a: memref<16xbf16>) -> vector<8xf32> { - // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} : + // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : // CHECK-SAME: memref<16xbf16> -> vector<8xf32> - %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32> + %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32> return %0 : vector<8xf32> } @@ -118,9 +118,9 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256( func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128( %a: memref<8xbf16>) -> vector<4xf32> { - // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} : + // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : // CHECK-SAME: memref<8xbf16> -> vector<4xf32> - %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32> + %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32> return %0 : vector<4xf32> } @@ -128,9 +128,9 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128( func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256( %a: memref<16xbf16>) -> vector<8xf32> { - // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} : + // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : // CHECK-SAME: memref<16xbf16> -> vector<8xf32> - %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32> + %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32> return %0 : vector<8xf32> } @@ -138,9 +138,9 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256( func.func @avxbf16_bcst_bf16_to_f32_128( %a: memref<1xbf16>) -> vector<4xf32> { - // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} : + // CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : // CHECK-SAME: memref<1xbf16> -> vector<4xf32> - %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32> + %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32> return %0 : vector<4xf32> } @@ -148,9 +148,69 @@ func.func @avxbf16_bcst_bf16_to_f32_128( func.func @avxbf16_bcst_bf16_to_f32_256( %a: memref<1xbf16>) -> vector<8xf32> { - // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} : + // CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : // CHECK-SAME: memref<1xbf16> -> vector<8xf32> - %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32> + %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_128 +func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128( + %a: memref<8xf16>) -> vector<4xf32> +{ + // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : + // CHECK-SAME: memref<8xf16> -> vector<4xf32> + %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_256 +func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256( + %a: memref<16xf16>) -> vector<8xf32> +{ + // CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 {{.*}} : + // CHECK-SAME: memref<16xf16> -> vector<8xf32> + %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128 +func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128( + %a: memref<8xf16>) -> vector<4xf32> +{ + // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : + // CHECK-SAME: memref<8xf16> -> vector<4xf32> + %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256 +func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256( + %a: memref<16xf16>) -> vector<8xf32> +{ + // CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 {{.*}} : + // CHECK-SAME: memref<16xf16> -> vector<8xf32> + %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_128 +func.func @avxf16_bcst_f16_to_f32_128( + %a: memref<1xf16>) -> vector<4xf32> +{ + // CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : + // CHECK-SAME: memref<1xf16> -> vector<4xf32> + %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_256 +func.func @avxf16_bcst_f16_to_f32_256( + %a: memref<1xf16>) -> vector<8xf32> +{ + // CHECK: x86vector.avx.bcst_to_f32.packed {{.*}} : + // CHECK-SAME: memref<1xf16> -> vector<8xf32> + %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32> return %0 : vector<8xf32> } diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir index a8bc180d1d0ac..d11dc89bdc7c9 100644 --- a/mlir/test/Target/LLVMIR/x86vector.mlir +++ b/mlir/test/Target/LLVMIR/x86vector.mlir @@ -114,7 +114,7 @@ func.func @LLVM_x86_avxbf16_vcvtneebf162ps128( %a: memref<8xbf16>) -> vector<4xf32> { // CHECK: call <4 x float> @llvm.x86.vcvtneebf162ps128( - %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32> + %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32> return %0 : vector<4xf32> } @@ -123,7 +123,7 @@ func.func @LLVM_x86_avxbf16_vcvtneebf162ps256( %a: memref<16xbf16>) -> vector<8xf32> { // CHECK: call <8 x float> @llvm.x86.vcvtneebf162ps256( - %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32> + %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32> return %0 : vector<8xf32> } @@ -132,7 +132,7 @@ func.func @LLVM_x86_avxbf16_vcvtneobf162ps128( %a: memref<8xbf16>) -> vector<4xf32> { // CHECK: call <4 x float> @llvm.x86.vcvtneobf162ps128( - %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32> + %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32> return %0 : vector<4xf32> } @@ -141,7 +141,7 @@ func.func @LLVM_x86_avxbf16_vcvtneobf162ps256( %a: memref<16xbf16>) -> vector<8xf32> { // CHECK: call <8 x float> @llvm.x86.vcvtneobf162ps256( - %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32> + %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32> return %0 : vector<8xf32> } @@ -150,7 +150,7 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps128( %a: memref<1xbf16>) -> vector<4xf32> { // CHECK: call <4 x float> @llvm.x86.vbcstnebf162ps128( - %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32> + %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32> return %0 : vector<4xf32> } @@ -159,7 +159,61 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps256( %a: memref<1xbf16>) -> vector<8xf32> { // CHECK: call <8 x float> @llvm.x86.vbcstnebf162ps256( - %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32> + %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vcvtneeph2ps128 +func.func @LLVM_x86_avxf16_vcvtneeph2ps128( + %a: memref<8xf16>) -> vector<4xf32> +{ + // CHECK: call <4 x float> @llvm.x86.vcvtneeph2ps128( + %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vcvtneeph2ps256 +func.func @LLVM_x86_avxf16_vcvtneeph2ps256( + %a: memref<16xf16>) -> vector<8xf32> +{ + // CHECK: call <8 x float> @llvm.x86.vcvtneeph2ps256( + %0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vcvtneoph2ps128 +func.func @LLVM_x86_avxf16_vcvtneoph2ps128( + %a: memref<8xf16>) -> vector<4xf32> +{ + // CHECK: call <4 x float> @llvm.x86.vcvtneoph2ps128( + %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vcvtneoph2ps256 +func.func @LLVM_x86_avxf16_vcvtneoph2ps256( + %a: memref<16xf16>) -> vector<8xf32> +{ + // CHECK: call <8 x float> @llvm.x86.vcvtneoph2ps256( + %0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vbcstnesh2ps128 +func.func @LLVM_x86_avxf16_vbcstnesh2ps128( + %a: memref<1xf16>) -> vector<4xf32> +{ + // CHECK: call <4 x float> @llvm.x86.vbcstnesh2ps128( + %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vbcstnesh2ps256 +func.func @LLVM_x86_avxf16_vbcstnesh2ps256( + %a: memref<1xf16>) -> vector<8xf32> +{ + // CHECK: call <8 x float> @llvm.x86.vbcstnesh2ps256( + %0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32> return %0 : vector<8xf32> }