[mlir][x86vector] AVX Convert/Broadcast F16 to F32 instructions #137917

Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: None (arun-thmn)

Changes: Adds AVX broadcast and conversion from F16 to packed F32 (similar to PR #136830). The instructions that are added:
- VBCSTNESH2PS
- VCVTNEEPH2PS
- VCVTNEOPH2PS
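As a quick illustration, the broadcast variant loads a single f16 element and expands it to a packed f32 vector (types taken from the test cases added below):

```mlir
// Broadcast one f16 element from memory into eight converted f32 lanes.
%dst = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
```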
Full diff: https://github.com/llvm/llvm-project/pull/137917.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 126fa0e352656..75b07f01e70f1 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -527,4 +527,126 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[Me
}
+//----------------------------------------------------------------------------//
+// AVX: Convert packed F16 even-indexed/odd-indexed elements into packed F32
+//----------------------------------------------------------------------------//
+
+def CvtPackedEvenIndexedF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.f16_to_f32", [MemoryEffects<[MemRead]>,
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+ let summary = "AVX: Convert packed F16 even-indexed elements into packed F32 Data.";
+ let description = [{
+
+ #### From the Intel Intrinsics Guide:
+
+ Convert packed F16 (16-bit) floating-point even-indexed elements stored at
+ memory locations starting at location `__A` to packed single-precision
+ (32-bit) floating-point elements, and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ ```
+ }];
+ let arguments = (ins AnyMemRef:$a);
+ let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+ let assemblyFormat =
+ "$a attr-dict`:` type($a)`->` type($dst)";
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getIntrinsicName() {
+ std::string intr = "llvm.x86.vcvtneeph2ps";
+ VectorType vecType = getDst().getType();
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+ intr += std::to_string(opBitWidth);
+ return intr;
+ }
+ }];
+
+ let extraClassDeclaration = [{
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ }];
+}
+
+def CvtPackedOddIndexedF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.f16_to_f32", [MemoryEffects<[MemRead]>,
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+ let summary = "AVX: Convert packed F16 odd-indexed elements into packed F32 Data.";
+ let description = [{
+
+ #### From the Intel Intrinsics Guide:
+
+ Convert packed F16 (16-bit) floating-point odd-indexed elements stored at
+ memory locations starting at location `__A` to packed single-precision
+ (32-bit) floating-point elements, and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ ```
+ }];
+ let arguments = (ins AnyMemRef:$a);
+ let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+ let assemblyFormat =
+ "$a attr-dict`:` type($a)`->` type($dst)";
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getIntrinsicName() {
+ std::string intr = "llvm.x86.vcvtneoph2ps";
+ VectorType vecType = getDst().getType();
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+ intr += std::to_string(opBitWidth);
+ return intr;
+ }
+ }];
+
+ let extraClassDeclaration = [{
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ }];
+}
+
+//----------------------------------------------------------------------------//
+// AVX: Convert F16 to F32 and broadcast into packed F32
+//----------------------------------------------------------------------------//
+
+def BcstF16ToPackedF32Op : AVX_Op<"bcst.f16_to_f32.packed", [MemoryEffects<[MemRead]>,
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+ let summary = "AVX: Broadcasts F16 into packed F32 Data.";
+
+ let description = [{
+
+ #### From the Intel Intrinsics Guide:
+
+ Convert the scalar F16 (16-bit) floating-point element stored at memory
+ location `__A` to a single-precision (32-bit) floating-point element,
+ broadcast it to packed single-precision (32-bit) floating-point elements,
+ and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+ ```
+ }];
+ let arguments = (ins AnyMemRef:$a);
+ let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+ let assemblyFormat =
+ "$a attr-dict`:` type($a)`->` type($dst)";
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getIntrinsicName() {
+ std::string intr = "llvm.x86.vbcstnesh2ps";
+ VectorType vecType = getDst().getType();
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+ intr += std::to_string(opBitWidth);
+ return intr;
+ }
+ }];
+
+ let extraClassDeclaration = [{
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ }];
+
+}
+
#endif // X86VECTOR_OPS
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index f5e5070c74f8f..2e01a11921950 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -112,5 +112,22 @@ x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands(
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
}
+SmallVector<Value>
+x86vector::CvtPackedEvenIndexedF16ToF32Op::getIntrinsicOperands(
+ RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+ return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
+SmallVector<Value>
+x86vector::CvtPackedOddIndexedF16ToF32Op::getIntrinsicOperands(
+ RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+ return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
+SmallVector<Value> x86vector::BcstF16ToPackedF32Op::getIntrinsicOperands(
+ RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+ return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 93b304c44de8e..3888ec05ad866 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -149,6 +149,60 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_256(
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
+ %a: memref<8xf16>) -> vector<4xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps128"
+ %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
+ %a: memref<16xf16>) -> vector<8xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps256"
+ %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
+ %a: memref<8xf16>) -> vector<4xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps128"
+ %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
+ %a: memref<16xf16>) -> vector<8xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps256"
+ %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bsct_f16_to_f32_packed_128
+func.func @avxf16_bsct_f16_to_f32_packed_128(
+ %a: memref<1xf16>) -> vector<4xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps128"
+ %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bsct_f16_to_f32_packed_256
+func.func @avxf16_bsct_f16_to_f32_packed_256(
+ %a: memref<1xf16>) -> vector<8xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps256"
+ %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index b783cc869b981..a2fdb0cf6d457 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -154,6 +154,66 @@ func.func @avxbf16_bcst_bf16_to_f32_256(
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
+ %a: memref<8xf16>) -> vector<4xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.even.indexed.f16_to_f32 {{.*}} :
+ // CHECK-SAME: memref<8xf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
+ %a: memref<16xf16>) -> vector<8xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.even.indexed.f16_to_f32 {{.*}} :
+ // CHECK-SAME: memref<16xf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
+ %a: memref<8xf16>) -> vector<4xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 {{.*}} :
+ // CHECK-SAME: memref<8xf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256
+func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
+ %a: memref<16xf16>) -> vector<8xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 {{.*}} :
+ // CHECK-SAME: memref<16xf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_128
+func.func @avxf16_bcst_f16_to_f32_128(
+ %a: memref<1xf16>) -> vector<4xf32>
+{
+ // CHECK: x86vector.avx.bcst.f16_to_f32.packed {{.*}} :
+ // CHECK-SAME: memref<1xf16> -> vector<4xf32>
+ %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_256
+func.func @avxf16_bcst_f16_to_f32_256(
+ %a: memref<1xf16>) -> vector<8xf32>
+{
+ // CHECK: x86vector.avx.bcst.f16_to_f32.packed {{.*}} :
+ // CHECK-SAME: memref<1xf16> -> vector<8xf32>
+ %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index a8bc180d1d0ac..f474ae281ece3 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -163,6 +163,60 @@ func.func @LLVM_x86_avxbf16_vbcstnebf162ps256(
return %0 : vector<8xf32>
}
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vcvtneeph2ps128
+func.func @LLVM_x86_avxf16_vcvtneeph2ps128(
+ %a: memref<8xf16>) -> vector<4xf32>
+{
+ // CHECK: call <4 x float> @llvm.x86.vcvtneeph2ps128(
+ %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vcvtneeph2ps256
+func.func @LLVM_x86_avxf16_vcvtneeph2ps256(
+ %a: memref<16xf16>) -> vector<8xf32>
+{
+ // CHECK: call <8 x float> @llvm.x86.vcvtneeph2ps256(
+ %0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vcvtneoph2ps128
+func.func @LLVM_x86_avxf16_vcvtneoph2ps128(
+ %a: memref<8xf16>) -> vector<4xf32>
+{
+ // CHECK: call <4 x float> @llvm.x86.vcvtneoph2ps128(
+ %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<8xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vcvtneoph2ps256
+func.func @LLVM_x86_avxf16_vcvtneoph2ps256(
+ %a: memref<16xf16>) -> vector<8xf32>
+{
+ // CHECK: call <8 x float> @llvm.x86.vcvtneoph2ps256(
+ %0 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxf16_vbcstnesh2ps128
+func.func @LLVM_x86_avxf16_vbcstnesh2ps128(
+ %a: memref<1xf16>) -> vector<4xf32>
+{
+ // CHECK: call <4 x float> @llvm.x86.vbcstnesh2ps128(
+ %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxf16_vbcstnesh2ps256
+func.func @LLVM_x86_avxf16_vbcstnesh2ps256(
+ %a: memref<1xf16>) -> vector<8xf32>
+{
+ // CHECK: call <8 x float> @llvm.x86.vbcstnesh2ps256(
+ %0 = x86vector.avx.bcst.f16_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
func.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
{
@adam-smnk @rengolin, please review the PR. This should be straightforward and similar to the bf16 PR (#136830), I think.
AFAIK, these intrinsics are pretty much equivalent to the bf16 versions, with the only difference being the expected f16 element type, right? It could be worth taking a moment to generalize the existing bf16 ops to cover both cases, as in the sketch below. The input memref might need to be constrained to the supported data types (still any rank, or unranked) to help with intrinsic name generation. That'd also help in guiding users toward correct usage.
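A rough sketch of that generalization, assuming the two bf16/f16 ops collapse into one definition. The op name and the bf16 intrinsic stem below are illustrative (the stem follows the naming of the existing bf16 ops), not code from this PR, and for brevity the sketch handles ranked memrefs only:

```tablegen
def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32",
    [MemoryEffects<[MemRead]>,
     DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
  // Constrain the source memref to the supported element types so the
  // intrinsic name can be derived from the element type.
  let arguments = (ins MemRefOf<[BF16, F16]>:$a);
  let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
  let assemblyFormat = "$a attr-dict`:` type($a)`->` type($dst)";

  let extraClassDefinition = [{
    std::string $cppClass::getIntrinsicName() {
      // Pick the bf16 or f16 intrinsic stem from the source element type,
      // then append the result bit width (128/256) as before.
      auto memType = ::llvm::cast<MemRefType>(getA().getType());
      std::string intr = memType.getElementType().isBF16()
                             ? "llvm.x86.vcvtneebf162ps"
                             : "llvm.x86.vcvtneeph2ps";
      VectorType vecType = getDst().getType();
      unsigned opBitWidth =
          vecType.getShape()[0] * vecType.getElementTypeBitWidth();
      intr += std::to_string(opBitWidth);
      return intr;
    }
  }];
}
```

The type constraint would then also reject unsupported element types at verification time, which addresses the usage-guidance point.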
Overall looks fine but it'd be good to collapse the two bf16/f16 variants into one.
Yes, you are correct. The best option is to find a way to generalize them. Let me look into those.
✅ With the latest revision this PR passed the C/C++ code formatter.
@adam-smnk Tried a kind of generalization, please have a look.
…#137917) Adds AVX broadcast and conversion from F16 to packed F32 (similar to PR: llvm#136830). The instructions that are added:
- VBCSTNESH2PS
- VCVTNEEPH2PS
- VCVTNEOPH2PS
Adds AVX broadcast and conversion from F16 to packed F32 (similar to PR: #136830). The instructions that are added:
- VBCSTNESH2PS
- VCVTNEEPH2PS
- VCVTNEOPH2PS
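For reference, the even/odd conversion pair as exercised by the new tests (both ops read the same source buffer; types match the added test cases):

```mlir
// Convert the even-indexed f16 elements of a 16-element buffer to f32.
%even = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
// Convert the odd-indexed elements of the same buffer.
%odd = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
```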