Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 51 additions & 32 deletions mlir/include/mlir/Dialect/X86Vector/X86Vector.td
Original file line number Diff line number Diff line change
Expand Up @@ -408,34 +408,41 @@ def DotOp : AVX_LowOp<"dot", [Pure,
}];
}


//----------------------------------------------------------------------------//
// AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32
// AVX: Convert BF16/F16 to F32 and broadcast into packed F32
//----------------------------------------------------------------------------//

def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data.";
let summary = "AVX: Broadcasts BF16/F16 into packed F32 Data.";
let description = [{
#### From the Intel Intrinsics Guide:

Convert packed BF16 (16-bit) floating-point even-indexed elements stored at
memory locations starting at location `__A` to packed single-precision
(32-bit) floating-point elements, and store the results in `dst`.
Convert scalar BF16 or F16 (16-bit) floating-point element stored at memory locations
starting at location `__A` to a single-precision (32-bit) floating-point,
broadcast it to packed single-precision (32-bit) floating-point elements,
and store the results in `dst`.

Example:
```mlir
%dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
%dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
%dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
```
}];
let arguments = (ins AnyMemRef:$a);
let arguments = (ins MemRefOf<[BF16, F16]>:$a);
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
let assemblyFormat =
"$a attr-dict`:` type($a)`->` type($dst)";

let extraClassDefinition = [{
std::string $cppClass::getIntrinsicName() {
std::string intr = "llvm.x86.vcvtneebf162ps";
auto elementType =
getA().getType().getElementType();
std::string intr = "llvm.x86.";
if (elementType.isBF16())
intr += "vbcstnebf162ps";
if (elementType.isF16())
intr += "vbcstnesh2ps";
VectorType vecType = getDst().getType();
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
Expand All @@ -447,31 +454,43 @@ def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f3
let extraClassDeclaration = [{
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
}];

}

def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
//------------------------------------------------------------------------------//
// AVX: Convert packed BF16/F16 even-indexed/odd-indexed elements into packed F32
//------------------------------------------------------------------------------//

def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [MemoryEffects<[MemRead]>,
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data.";
let summary = "AVX: Convert packed BF16/F16 even-indexed elements into packed F32 Data.";
let description = [{
#### From the Intel Intrinsics Guide:

Convert packed BF16 (16-bit) floating-point odd-indexed elements stored at
Convert packed BF16 or F16 (16-bit) floating-point even-indexed elements stored at
memory locations starting at location `__A` to packed single-precision
(32-bit) floating-point elements, and store the results in `dst`.

Example:
```mlir
%dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
%dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
%dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
```
}];
let arguments = (ins AnyMemRef:$a);
let arguments = (ins MemRefOf<[BF16, F16]>:$a);
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
let assemblyFormat =
"$a attr-dict`:` type($a)`->` type($dst)";

let extraClassDefinition = [{
std::string $cppClass::getIntrinsicName() {
std::string intr = "llvm.x86.vcvtneobf162ps";
auto elementType =
getA().getType().getElementType();
std::string intr = "llvm.x86.";
if (elementType.isBF16())
intr += "vcvtneebf162ps";
if (elementType.isF16())
intr += "vcvtneeph2ps";
VectorType vecType = getDst().getType();
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
Expand All @@ -485,34 +504,36 @@ def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32"
}];
}

//----------------------------------------------------------------------------//
// AVX: Convert BF16 to F32 and broadcast into packed F32
//----------------------------------------------------------------------------//

def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[MemRead]>,
def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [MemoryEffects<[MemRead]>,
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
let summary = "AVX: Broadcasts BF16 into packed F32 Data.";
let summary = "AVX: Convert packed BF16/F16 odd-indexed elements into packed F32 Data.";
let description = [{
#### From the Intel Intrinsics Guide:

Convert scalar BF16 (16-bit) floating-point element stored at memory locations
starting at location `__A` to a single-precision (32-bit) floating-point,
broadcast it to packed single-precision (32-bit) floating-point elements,
and store the results in `dst`.
Convert packed BF16 or F16 (16-bit) floating-point odd-indexed elements stored at
memory locations starting at location `__A` to packed single-precision
(32-bit) floating-point elements, and store the results in `dst`.

Example:
```mlir
%dst = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
%dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
%dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
```
}];
let arguments = (ins AnyMemRef:$a);
let arguments = (ins MemRefOf<[BF16, F16]>:$a);
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
let assemblyFormat =
"$a attr-dict`:` type($a)`->` type($dst)";

let extraClassDefinition = [{
std::string $cppClass::getIntrinsicName() {
std::string intr = "llvm.x86.vbcstnebf162ps";
auto elementType =
getA().getType().getElementType();
std::string intr = "llvm.x86.";
if (elementType.isBF16())
intr += "vcvtneobf162ps";
if (elementType.isF16())
intr += "vcvtneoph2ps";
VectorType vecType = getDst().getType();
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
Expand All @@ -521,10 +542,8 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[Me
}
}];

let extraClassDeclaration = [{
let extraClassDeclaration = [{
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
}];

}

#endif // X86VECTOR_OPS
8 changes: 3 additions & 5 deletions mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,17 @@ x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
return operands;
}

SmallVector<Value> x86vector::BcstBF16ToPackedF32Op::getIntrinsicOperands(
SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
}

SmallVector<Value>
x86vector::CvtPackedOddIndexedBF16ToF32Op::getIntrinsicOperands(
SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
}

SmallVector<Value>
x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands(
SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ void mlir::populateX86VectorLegalizeForLLVMExportPatterns(

void mlir::configureX86VectorLegalizeForExportTarget(
LLVMConversionTarget &target) {
target.addIllegalOp<
MaskCompressOp, MaskRndScaleOp, MaskScaleFOp, Vp2IntersectOp, DotBF16Op,
CvtPackedF32ToBF16Op, CvtPackedEvenIndexedBF16ToF32Op,
CvtPackedOddIndexedBF16ToF32Op, BcstBF16ToPackedF32Op, RsqrtOp, DotOp>();
target.addIllegalOp<MaskCompressOp, MaskRndScaleOp, MaskScaleFOp,
Vp2IntersectOp, DotBF16Op, CvtPackedF32ToBF16Op,
CvtPackedEvenIndexedToF32Op, CvtPackedOddIndexedToF32Op,
BcstToPackedF32Op, RsqrtOp, DotOp>();
}
66 changes: 60 additions & 6 deletions mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
%a: memref<8xbf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128"
%0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}

Expand All @@ -109,7 +109,7 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
%a: memref<16xbf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"
%0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}

Expand All @@ -118,7 +118,7 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
%a: memref<8xbf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128"
%0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}

Expand All @@ -127,7 +127,7 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
%a: memref<16xbf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256"
%0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}

Expand All @@ -136,7 +136,7 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_128(
%a: memref<1xbf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128"
%0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
return %0 : vector<4xf32>
}

Expand All @@ -145,7 +145,61 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_256(
%a: memref<1xbf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256"
%0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
return %0 : vector<8xf32>
}

// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_128
func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
%a: memref<8xf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps128"
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}

// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_256
func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
%a: memref<16xf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps256"
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}

// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128
func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
%a: memref<8xf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps128"
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}

// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256
func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
%a: memref<16xf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps256"
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}

// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_packed_128
func.func @avxf16_bcst_f16_to_f32_packed_128(
%a: memref<1xf16>) -> vector<4xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps128"
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
return %0 : vector<4xf32>
}

// CHECK-LABEL: func @avxf16_bcst_f16_to_f32_packed_256
func.func @avxf16_bcst_f16_to_f32_packed_256(
%a: memref<1xf16>) -> vector<8xf32>
{
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps256"
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
return %0 : vector<8xf32>
}

Expand Down
Loading